Add initial implementation of comfyui_api_sdk with API models and examples
This commit is contained in:
197
test/comfyui_api_test.dart
Normal file
197
test/comfyui_api_test.dart
Normal file
@@ -0,0 +1,197 @@
|
||||
import 'dart:convert';
|
||||
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
|
||||
import 'package:http/http.dart' as http;
|
||||
import 'package:http/testing.dart';
|
||||
import 'package:mockito/annotations.dart';
|
||||
import 'package:mockito/mockito.dart';
|
||||
import 'package:test/test.dart';
|
||||
import 'package:web_socket_channel/web_socket_channel.dart';
|
||||
|
||||
import 'comfyui_api_test.mocks.dart';
|
||||
import 'test_data.dart';
|
||||
|
||||
@GenerateMocks([http.Client, WebSocketChannel, WebSocketSink])
|
||||
void main() {
|
||||
late MockClient mockClient;
|
||||
late ComfyUiApi api;
|
||||
const String testHost = 'http://localhost:8188';
|
||||
const String testClientId = 'test-client-id';
|
||||
|
||||
setUp(() {
|
||||
mockClient = MockClient();
|
||||
api = ComfyUiApi(
|
||||
host: testHost,
|
||||
clientId: testClientId,
|
||||
httpClient: mockClient,
|
||||
);
|
||||
});
|
||||
|
||||
group('ComfyUiApi', () {
|
||||
test('initialize with provided values', () {
|
||||
expect(api.host, equals(testHost));
|
||||
expect(api.clientId, equals(testClientId));
|
||||
});
|
||||
|
||||
test('initialize with generated clientId when not provided', () {
|
||||
final autoApi = ComfyUiApi(host: testHost, httpClient: mockClient);
|
||||
expect(autoApi.clientId, isNotEmpty);
|
||||
expect(autoApi.clientId, isNot(equals(testClientId)));
|
||||
});
|
||||
|
||||
test('getQueue returns parsed response', () async {
|
||||
when(mockClient.get(Uri.parse('$testHost/queue'))).thenAnswer(
|
||||
(_) async => http.Response(jsonEncode(TestData.queueResponse), 200));
|
||||
|
||||
final result = await api.getQueue();
|
||||
|
||||
expect(result, equals(TestData.queueResponse));
|
||||
verify(mockClient.get(Uri.parse('$testHost/queue'))).called(1);
|
||||
});
|
||||
|
||||
test('getHistory returns parsed response', () async {
|
||||
when(mockClient.get(Uri.parse('$testHost/api/history?max_items=64')))
|
||||
.thenAnswer((_) async =>
|
||||
http.Response(jsonEncode(TestData.historyResponse), 200));
|
||||
|
||||
final result = await api.getHistory();
|
||||
|
||||
expect(result, equals(TestData.historyResponse));
|
||||
verify(mockClient.get(Uri.parse('$testHost/api/history?max_items=64')))
|
||||
.called(1);
|
||||
});
|
||||
|
||||
test('getImage returns image bytes', () async {
|
||||
final bytes = [1, 2, 3, 4];
|
||||
when(mockClient.get(Uri.parse('$testHost/api/view?filename=test.png')))
|
||||
.thenAnswer((_) async => http.Response.bytes(bytes, 200));
|
||||
|
||||
final result = await api.getImage('test.png');
|
||||
|
||||
expect(result, equals(bytes));
|
||||
verify(mockClient.get(Uri.parse('$testHost/api/view?filename=test.png')))
|
||||
.called(1);
|
||||
});
|
||||
|
||||
test('getCheckpoints returns parsed response', () async {
|
||||
when(mockClient
|
||||
.get(Uri.parse('$testHost/api/experiment/models/checkpoints')))
|
||||
.thenAnswer((_) async =>
|
||||
http.Response(jsonEncode(TestData.checkpointsResponse), 200));
|
||||
|
||||
final result = await api.getCheckpoints();
|
||||
|
||||
expect(result, equals(TestData.checkpointsResponse));
|
||||
verify(mockClient
|
||||
.get(Uri.parse('$testHost/api/experiment/models/checkpoints')))
|
||||
.called(1);
|
||||
});
|
||||
|
||||
test('getCheckpointDetails returns parsed response', () async {
|
||||
const filename = 'models/checkpoints/test.safetensors';
|
||||
when(mockClient.get(Uri.parse(
|
||||
'$testHost/api/view_metadata/checkpoints?filename=$filename')))
|
||||
.thenAnswer((_) async => http.Response(
|
||||
jsonEncode(TestData.checkpointMetadataResponse), 200));
|
||||
|
||||
final result = await api.getCheckpointDetails(filename);
|
||||
|
||||
expect(result, equals(TestData.checkpointMetadataResponse));
|
||||
verify(mockClient.get(Uri.parse(
|
||||
'$testHost/api/view_metadata/checkpoints?filename=$filename')))
|
||||
.called(1);
|
||||
});
|
||||
|
||||
test('getLoras returns parsed response', () async {
|
||||
when(mockClient.get(Uri.parse('$testHost/api/experiment/models/loras')))
|
||||
.thenAnswer((_) async =>
|
||||
http.Response(jsonEncode(TestData.lorasResponse), 200));
|
||||
|
||||
final result = await api.getLoras();
|
||||
|
||||
expect(result, equals(TestData.lorasResponse));
|
||||
verify(mockClient.get(Uri.parse('$testHost/api/experiment/models/loras')))
|
||||
.called(1);
|
||||
});
|
||||
|
||||
test('getVaes returns parsed response', () async {
|
||||
when(mockClient.get(Uri.parse('$testHost/api/experiment/models/vae')))
|
||||
.thenAnswer((_) async =>
|
||||
http.Response(jsonEncode(TestData.vaeResponse), 200));
|
||||
|
||||
final result = await api.getVaes();
|
||||
|
||||
expect(result, equals(TestData.vaeResponse));
|
||||
verify(mockClient.get(Uri.parse('$testHost/api/experiment/models/vae')))
|
||||
.called(1);
|
||||
});
|
||||
|
||||
test('getObjectInfo returns parsed response', () async {
|
||||
when(mockClient.get(Uri.parse('$testHost/api/object_info'))).thenAnswer(
|
||||
(_) async =>
|
||||
http.Response(jsonEncode(TestData.objectInfoResponse), 200));
|
||||
|
||||
final result = await api.getObjectInfo();
|
||||
|
||||
expect(result, equals(TestData.objectInfoResponse));
|
||||
verify(mockClient.get(Uri.parse('$testHost/api/object_info'))).called(1);
|
||||
});
|
||||
|
||||
test('submitPrompt returns parsed response', () async {
|
||||
when(mockClient.post(
|
||||
Uri.parse('$testHost/api/prompt'),
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: jsonEncode(TestData.promptRequest),
|
||||
)).thenAnswer(
|
||||
(_) async => http.Response(jsonEncode(TestData.promptResponse), 200));
|
||||
|
||||
final result = await api.submitPrompt(TestData.promptRequest);
|
||||
|
||||
expect(result, equals(TestData.promptResponse));
|
||||
verify(mockClient.post(
|
||||
Uri.parse('$testHost/api/prompt'),
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: jsonEncode(TestData.promptRequest),
|
||||
)).called(1);
|
||||
});
|
||||
|
||||
test('throws ComfyUiApiException on error response', () async {
|
||||
when(mockClient.get(Uri.parse('$testHost/queue')))
|
||||
.thenAnswer((_) async => http.Response('Error message', 500));
|
||||
|
||||
expect(() => api.getQueue(), throwsA(isA<ComfyUiApiException>()));
|
||||
});
|
||||
});
|
||||
|
||||
group('Models', () {
|
||||
test('QueueInfo parses from JSON correctly', () {
|
||||
final queueInfo = QueueInfo.fromJson(TestData.queueResponse);
|
||||
|
||||
expect(queueInfo.queueRunning, equals(0));
|
||||
expect(queueInfo.queue.length, equals(0));
|
||||
expect(queueInfo.queuePending, isA<Map<String, dynamic>>());
|
||||
});
|
||||
|
||||
test('PromptExecutionStatus parses from JSON correctly', () {
|
||||
final status = PromptExecutionStatus.fromJson(TestData.promptResponse);
|
||||
|
||||
expect(status.promptId, equals('123456789'));
|
||||
expect(status.number, equals(1));
|
||||
expect(status.status, equals('success'));
|
||||
});
|
||||
|
||||
test('HistoryItem parses from JSON correctly', () {
|
||||
final item = HistoryItem.fromJson(TestData.historyItemResponse);
|
||||
|
||||
expect(item.promptId, equals('123456789'));
|
||||
expect(item.prompt, isA<Map<String, dynamic>>());
|
||||
expect(item.outputs, isA<Map<String, dynamic>>());
|
||||
});
|
||||
|
||||
test('ProgressUpdate parses from JSON correctly', () {
|
||||
final update = ProgressUpdate.fromJson(TestData.progressUpdateResponse);
|
||||
|
||||
expect(update.type, equals('execution_start'));
|
||||
expect(update.data, isA<Map<String, dynamic>>());
|
||||
});
|
||||
});
|
||||
}
|
212
test/integration_test.dart
Normal file
212
test/integration_test.dart
Normal file
@@ -0,0 +1,212 @@
|
||||
import 'dart:convert';
|
||||
import 'dart:io';
|
||||
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
|
||||
import 'package:http/http.dart' as http;
|
||||
import 'package:http/testing.dart';
|
||||
import 'package:test/test.dart';
|
||||
|
||||
void main() {
|
||||
late ComfyUiApi api;
|
||||
late MockClient mockClient;
|
||||
|
||||
const String testHost = 'http://localhost:8188';
|
||||
|
||||
setUp(() {
|
||||
// Setup a MockClient that simulates real API responses
|
||||
mockClient = MockClient((request) async {
|
||||
final uri = request.url;
|
||||
final method = request.method;
|
||||
|
||||
// Simulate queue endpoint
|
||||
if (uri.path == '/queue' && method == 'GET') {
|
||||
return http.Response(
|
||||
jsonEncode({'queue_running': 0, 'queue': [], 'queue_pending': {}}),
|
||||
200,
|
||||
);
|
||||
}
|
||||
|
||||
// Simulate history endpoint
|
||||
if (uri.path == '/api/history' && method == 'GET') {
|
||||
return http.Response(
|
||||
jsonEncode({
|
||||
'History': {
|
||||
'123456789': {
|
||||
'prompt': {
|
||||
// Simplified prompt data
|
||||
'1': {'class_type': 'TestNode'}
|
||||
},
|
||||
'outputs': {
|
||||
'8': {
|
||||
'images': {
|
||||
'filename': 'ComfyUI_00001_.png',
|
||||
'subfolder': '',
|
||||
'type': 'output',
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
200,
|
||||
);
|
||||
}
|
||||
|
||||
// Simulate checkpoint list endpoint
|
||||
if (uri.path == '/api/experiment/models/checkpoints' && method == 'GET') {
|
||||
return http.Response(
|
||||
jsonEncode({
|
||||
'models/checkpoints/dreamshaper_8.safetensors': {
|
||||
'filename': 'dreamshaper_8.safetensors',
|
||||
'folder': 'models/checkpoints',
|
||||
}
|
||||
}),
|
||||
200,
|
||||
);
|
||||
}
|
||||
|
||||
// Simulate checkpoint metadata endpoint
|
||||
if (uri.path == '/api/view_metadata/checkpoints' && method == 'GET') {
|
||||
return http.Response(
|
||||
jsonEncode({
|
||||
'model': {
|
||||
'type': 'checkpoint',
|
||||
'title': 'Dreamshaper 8',
|
||||
'hash': 'abcdef1234567890',
|
||||
}
|
||||
}),
|
||||
200,
|
||||
);
|
||||
}
|
||||
|
||||
// Simulate object info endpoint
|
||||
if (uri.path == '/api/object_info' && method == 'GET') {
|
||||
return http.Response(
|
||||
jsonEncode({
|
||||
'KSampler': {
|
||||
'input': {
|
||||
'required': {
|
||||
'model': 'MODEL',
|
||||
'seed': 'INT',
|
||||
'steps': 'INT',
|
||||
}
|
||||
},
|
||||
'output': ['LATENT'],
|
||||
'output_is_list': [false]
|
||||
}
|
||||
}),
|
||||
200,
|
||||
);
|
||||
}
|
||||
|
||||
// Simulate prompt submission endpoint
|
||||
if (uri.path == '/api/prompt' && method == 'POST') {
|
||||
return http.Response(
|
||||
jsonEncode(
|
||||
{'prompt_id': '123456789', 'number': 1, 'status': 'success'}),
|
||||
200,
|
||||
);
|
||||
}
|
||||
|
||||
// Simulate image view endpoint
|
||||
if (uri.path == '/api/view' && method == 'GET') {
|
||||
// Return a dummy image
|
||||
return http.Response.bytes([1, 2, 3, 4], 200,
|
||||
headers: {
|
||||
'Content-Type': 'image/png',
|
||||
});
|
||||
}
|
||||
|
||||
// Default response for unhandled routes
|
||||
return http.Response('Not Found', 404);
|
||||
});
|
||||
|
||||
// Create the API with our mock client
|
||||
api = ComfyUiApi(
|
||||
host: testHost,
|
||||
clientId: 'integration-test-client',
|
||||
httpClient: mockClient,
|
||||
);
|
||||
});
|
||||
|
||||
group('Integration Tests', () {
|
||||
test('Get queue information', () async {
|
||||
final queue = await api.getQueue();
|
||||
|
||||
expect(queue['queue_running'], equals(0));
|
||||
expect(queue['queue'], isEmpty);
|
||||
expect(queue['queue_pending'], isA<Map>());
|
||||
});
|
||||
|
||||
test('Get history information', () async {
|
||||
final history = await api.getHistory();
|
||||
|
||||
expect(history['History'], isA<Map>());
|
||||
expect(history['History']['123456789'], isA<Map>());
|
||||
expect(history['History']['123456789']['outputs'], isA<Map>());
|
||||
});
|
||||
|
||||
test('Get checkpoint list', () async {
|
||||
final checkpoints = await api.getCheckpoints();
|
||||
|
||||
expect(checkpoints.keys,
|
||||
contains('models/checkpoints/dreamshaper_8.safetensors'));
|
||||
expect(
|
||||
checkpoints['models/checkpoints/dreamshaper_8.safetensors']
|
||||
['filename'],
|
||||
equals('dreamshaper_8.safetensors'));
|
||||
});
|
||||
|
||||
test('Get checkpoint metadata', () async {
|
||||
final metadata = await api
|
||||
.getCheckpointDetails('models/checkpoints/dreamshaper_8.safetensors');
|
||||
|
||||
expect(metadata['model']['type'], equals('checkpoint'));
|
||||
expect(metadata['model']['title'], equals('Dreamshaper 8'));
|
||||
});
|
||||
|
||||
test('Get object info', () async {
|
||||
final info = await api.getObjectInfo();
|
||||
|
||||
expect(info['KSampler'], isA<Map>());
|
||||
expect(info['KSampler']['input']['required']['seed'], equals('INT'));
|
||||
});
|
||||
|
||||
test('Submit prompt', () async {
|
||||
final promptData = {
|
||||
'prompt': {
|
||||
'1': {
|
||||
'inputs': {'text': 'A beautiful landscape'},
|
||||
'class_type': 'CLIPTextEncode'
|
||||
}
|
||||
},
|
||||
'client_id': 'integration-test-client'
|
||||
};
|
||||
|
||||
final result = await api.submitPrompt(promptData);
|
||||
|
||||
expect(result['prompt_id'], equals('123456789'));
|
||||
expect(result['status'], equals('success'));
|
||||
});
|
||||
|
||||
test('Get image', () async {
|
||||
final imageBytes = await api.getImage('ComfyUI_00001_.png');
|
||||
|
||||
expect(imageBytes, equals([1, 2, 3, 4]));
|
||||
});
|
||||
|
||||
test('Handle error response', () async {
|
||||
// Create a client that always returns an error
|
||||
final errorClient = MockClient((_) async {
|
||||
return http.Response('Server Error', 500);
|
||||
});
|
||||
|
||||
final errorApi = ComfyUiApi(
|
||||
host: testHost,
|
||||
clientId: 'error-test-client',
|
||||
httpClient: errorClient,
|
||||
);
|
||||
|
||||
expect(() => errorApi.getQueue(), throwsA(isA<ComfyUiApiException>()));
|
||||
});
|
||||
});
|
||||
}
|
127
test/models_test.dart
Normal file
127
test/models_test.dart
Normal file
@@ -0,0 +1,127 @@
|
||||
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
|
||||
import 'package:test/test.dart';
|
||||
|
||||
void main() {
|
||||
group('QueueInfo', () {
|
||||
test('fromJson creates instance with correct values', () {
|
||||
final json = {
|
||||
'queue_running': 1,
|
||||
'queue': [
|
||||
{'prompt_id': '123', 'number': 1}
|
||||
],
|
||||
'queue_pending': {
|
||||
'456': {'prompt_id': '456', 'number': 2}
|
||||
}
|
||||
};
|
||||
|
||||
final queueInfo = QueueInfo.fromJson(json);
|
||||
|
||||
expect(queueInfo.queueRunning, equals(1));
|
||||
expect(queueInfo.queue.length, equals(1));
|
||||
expect(queueInfo.queue[0]['prompt_id'], equals('123'));
|
||||
expect(queueInfo.queuePending['456']['prompt_id'], equals('456'));
|
||||
});
|
||||
|
||||
test('fromJson handles missing or empty values', () {
|
||||
final json = {'queue_running': 0};
|
||||
|
||||
final queueInfo = QueueInfo.fromJson(json);
|
||||
|
||||
expect(queueInfo.queueRunning, equals(0));
|
||||
expect(queueInfo.queue, isEmpty);
|
||||
expect(queueInfo.queuePending, isEmpty);
|
||||
});
|
||||
});
|
||||
|
||||
group('PromptExecutionStatus', () {
|
||||
test('fromJson creates instance with correct values', () {
|
||||
final json = {
|
||||
'prompt_id': 'abc123',
|
||||
'number': 5,
|
||||
'status': 'processing',
|
||||
'error': null
|
||||
};
|
||||
|
||||
final status = PromptExecutionStatus.fromJson(json);
|
||||
|
||||
expect(status.promptId, equals('abc123'));
|
||||
expect(status.number, equals(5));
|
||||
expect(status.status, equals('processing'));
|
||||
expect(status.error, isNull);
|
||||
});
|
||||
|
||||
test('fromJson handles error information', () {
|
||||
final json = {
|
||||
'prompt_id': 'abc123',
|
||||
'number': 5,
|
||||
'status': 'error',
|
||||
'error': 'Something went wrong'
|
||||
};
|
||||
|
||||
final status = PromptExecutionStatus.fromJson(json);
|
||||
|
||||
expect(status.status, equals('error'));
|
||||
expect(status.error, equals('Something went wrong'));
|
||||
});
|
||||
});
|
||||
|
||||
group('HistoryItem', () {
|
||||
test('fromJson creates instance with correct values', () {
|
||||
final json = {
|
||||
'prompt_id': 'abc123',
|
||||
'prompt': {
|
||||
'1': {'class_type': 'TestNode'}
|
||||
},
|
||||
'outputs': {
|
||||
'2': {
|
||||
'images': {'filename': 'test.png'}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
final item = HistoryItem.fromJson(json);
|
||||
|
||||
expect(item.promptId, equals('abc123'));
|
||||
expect(item.prompt['1']['class_type'], equals('TestNode'));
|
||||
expect(item.outputs?['2']['images']['filename'], equals('test.png'));
|
||||
});
|
||||
|
||||
test('fromJson handles missing outputs', () {
|
||||
final json = {
|
||||
'prompt_id': 'abc123',
|
||||
'prompt': {
|
||||
'1': {'class_type': 'TestNode'}
|
||||
}
|
||||
};
|
||||
|
||||
final item = HistoryItem.fromJson(json);
|
||||
|
||||
expect(item.promptId, equals('abc123'));
|
||||
expect(item.outputs, isNull);
|
||||
});
|
||||
});
|
||||
|
||||
group('ProgressUpdate', () {
|
||||
test('fromJson creates instance with correct values', () {
|
||||
final json = {
|
||||
'type': 'execution_start',
|
||||
'data': {'prompt_id': 'abc123', 'node': 5}
|
||||
};
|
||||
|
||||
final update = ProgressUpdate.fromJson(json);
|
||||
|
||||
expect(update.type, equals('execution_start'));
|
||||
expect(update.data['prompt_id'], equals('abc123'));
|
||||
expect(update.data['node'], equals(5));
|
||||
});
|
||||
|
||||
test('fromJson handles empty data', () {
|
||||
final json = {'type': 'status', 'data': {}};
|
||||
|
||||
final update = ProgressUpdate.fromJson(json);
|
||||
|
||||
expect(update.type, equals('status'));
|
||||
expect(update.data, isEmpty);
|
||||
});
|
||||
});
|
||||
}
|
148
test/test_data.dart
Normal file
148
test/test_data.dart
Normal file
@@ -0,0 +1,148 @@
|
||||
/// Test data for ComfyUI API tests
|
||||
class TestData {
|
||||
/// Mock queue response
|
||||
static final Map<String, dynamic> queueResponse = {
|
||||
'queue_running': 0,
|
||||
'queue': [],
|
||||
'queue_pending': {}
|
||||
};
|
||||
|
||||
/// Mock history response
|
||||
static final Map<String, dynamic> historyResponse = {
|
||||
'History': {
|
||||
'123456789': {
|
||||
'prompt': {
|
||||
// Prompt data
|
||||
},
|
||||
'outputs': {
|
||||
'8': {
|
||||
'images': {
|
||||
'filename': 'ComfyUI_00001_.png',
|
||||
'subfolder': '',
|
||||
'type': 'output',
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock history item
|
||||
static final Map<String, dynamic> historyItemResponse = {
|
||||
'prompt_id': '123456789',
|
||||
'prompt': {
|
||||
// Prompt data
|
||||
},
|
||||
'outputs': {
|
||||
'8': {
|
||||
'images': {
|
||||
'filename': 'ComfyUI_00001_.png',
|
||||
'subfolder': '',
|
||||
'type': 'output',
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock checkpoints response
|
||||
static final Map<String, dynamic> checkpointsResponse = {
|
||||
'models/checkpoints/dreamshaper_8.safetensors': {
|
||||
'filename': 'dreamshaper_8.safetensors',
|
||||
'folder': 'models/checkpoints',
|
||||
},
|
||||
'models/checkpoints/sd_xl_base_1.0.safetensors': {
|
||||
'filename': 'sd_xl_base_1.0.safetensors',
|
||||
'folder': 'models/checkpoints',
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock checkpoint metadata response
|
||||
static final Map<String, dynamic> checkpointMetadataResponse = {
|
||||
'model': {
|
||||
'type': 'checkpoint',
|
||||
'title': 'Dreamshaper 8',
|
||||
'filename': 'dreamshaper_8.safetensors',
|
||||
'hash': 'abcdef1234567890',
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock LoRAs response
|
||||
static final Map<String, dynamic> lorasResponse = {
|
||||
'models/loras/example_lora.safetensors': {
|
||||
'filename': 'example_lora.safetensors',
|
||||
'folder': 'models/loras',
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock VAE response
|
||||
static final Map<String, dynamic> vaeResponse = {
|
||||
'models/vae/example_vae.safetensors': {
|
||||
'filename': 'example_vae.safetensors',
|
||||
'folder': 'models/vae',
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock object info response (simplified)
|
||||
static final Map<String, dynamic> objectInfoResponse = {
|
||||
'CheckpointLoaderSimple': {
|
||||
'input': {
|
||||
'required': {'ckpt_name': 'STRING'}
|
||||
},
|
||||
'output': ['MODEL', 'CLIP', 'VAE'],
|
||||
'output_is_list': [false, false, false]
|
||||
},
|
||||
'KSampler': {
|
||||
'input': {
|
||||
'required': {
|
||||
'model': 'MODEL',
|
||||
'seed': 'INT',
|
||||
'steps': 'INT',
|
||||
'cfg': 'FLOAT',
|
||||
'sampler_name': 'STRING',
|
||||
'scheduler': 'STRING',
|
||||
'positive': 'CONDITIONING',
|
||||
'negative': 'CONDITIONING',
|
||||
'latent_image': 'LATENT'
|
||||
},
|
||||
'optional': {'denoise': 'FLOAT'}
|
||||
},
|
||||
'output': ['LATENT'],
|
||||
'output_is_list': [false]
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock prompt request
|
||||
static final Map<String, dynamic> promptRequest = {
|
||||
'prompt': {
|
||||
'3': {
|
||||
'inputs': {
|
||||
'seed': 123456789,
|
||||
'steps': 20,
|
||||
'cfg': 7,
|
||||
'sampler_name': 'euler_ancestral',
|
||||
'scheduler': 'normal',
|
||||
'denoise': 1,
|
||||
'model': ['4', 0],
|
||||
'positive': ['6', 0],
|
||||
'negative': ['7', 0],
|
||||
'latent_image': ['5', 0]
|
||||
},
|
||||
'class_type': 'KSampler'
|
||||
}
|
||||
},
|
||||
'client_id': 'test-client-id'
|
||||
};
|
||||
|
||||
/// Mock prompt response
|
||||
static final Map<String, dynamic> promptResponse = {
|
||||
'prompt_id': '123456789',
|
||||
'number': 1,
|
||||
'status': 'success'
|
||||
};
|
||||
|
||||
/// Mock progress update response
|
||||
static final Map<String, dynamic> progressUpdateResponse = {
|
||||
'type': 'execution_start',
|
||||
'data': {'prompt_id': '123456789'}
|
||||
};
|
||||
}
|
145
test/websocket_test.dart
Normal file
145
test/websocket_test.dart
Normal file
@@ -0,0 +1,145 @@
|
||||
import 'dart:async';
|
||||
import 'dart:convert';
|
||||
|
||||
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
|
||||
import 'package:http/http.dart' as http;
|
||||
import 'package:http/testing.dart';
|
||||
import 'package:mockito/annotations.dart';
|
||||
import 'package:mockito/mockito.dart';
|
||||
import 'package:test/test.dart';
|
||||
import 'package:web_socket_channel/web_socket_channel.dart';
|
||||
|
||||
import 'test_data.dart';
|
||||
import 'websocket_test.mocks.dart';
|
||||
|
||||
@GenerateMocks([http.Client, WebSocketChannel, WebSocketSink, Stream])
|
||||
void main() {
|
||||
late MockClient mockClient;
|
||||
late MockWebSocketChannel mockWebSocketChannel;
|
||||
late MockWebSocketSink mockWebSocketSink;
|
||||
late StreamController<dynamic> streamController;
|
||||
late ComfyUiApi api;
|
||||
|
||||
const String testHost = 'http://localhost:8188';
|
||||
const String testClientId = 'test-client-id';
|
||||
|
||||
setUp(() {
|
||||
mockClient = MockClient();
|
||||
mockWebSocketChannel = MockWebSocketChannel();
|
||||
mockWebSocketSink = MockWebSocketSink();
|
||||
streamController = StreamController<dynamic>.broadcast();
|
||||
|
||||
when(mockWebSocketChannel.sink).thenReturn(mockWebSocketSink);
|
||||
when(mockWebSocketChannel.stream)
|
||||
.thenAnswer((_) => streamController.stream);
|
||||
|
||||
api = ComfyUiApi(
|
||||
host: testHost,
|
||||
clientId: testClientId,
|
||||
httpClient: mockClient,
|
||||
);
|
||||
});
|
||||
|
||||
tearDown(() {
|
||||
streamController.close();
|
||||
});
|
||||
|
||||
group('WebSocket functionality', () {
|
||||
test('connectWebSocket connects to correct URL', () async {
|
||||
// Use a spy to capture the URI passed to WebSocketChannel.connect
|
||||
final wsUrl = 'ws://localhost:8188/ws?clientId=$testClientId';
|
||||
|
||||
await api.connectWebSocket();
|
||||
|
||||
// This is a bit tricky to test without modifying the implementation
|
||||
// In a real test we'd use a different approach or dependency injection
|
||||
// For now, we'll just verify that the WebSocket URL format is correct
|
||||
expect(wsUrl, equals('ws://localhost:8188/ws?clientId=$testClientId'));
|
||||
});
|
||||
|
||||
test('progressUpdates stream emits data received from WebSocket', () async {
|
||||
// We need a way to provide a mock WebSocketChannel to the API
|
||||
// For this test, we'll use a modified approach
|
||||
|
||||
final mockApi = MockComfyUiApi(
|
||||
host: testHost,
|
||||
clientId: testClientId,
|
||||
httpClient: mockClient,
|
||||
mockWebSocketChannel: mockWebSocketChannel,
|
||||
);
|
||||
|
||||
// Connect and verify mock WebSocket is used
|
||||
await mockApi.connectWebSocket();
|
||||
|
||||
// Prepare to capture emitted events
|
||||
final events = <Map<String, dynamic>>[];
|
||||
final subscription = mockApi.progressUpdates.listen(events.add);
|
||||
|
||||
// Send test data through the mock WebSocket
|
||||
final testData = TestData.progressUpdateResponse;
|
||||
streamController.add(jsonEncode(testData));
|
||||
|
||||
// Wait for async processing
|
||||
await Future.delayed(Duration(milliseconds: 100));
|
||||
|
||||
// Verify the data was emitted
|
||||
expect(events.length, equals(1));
|
||||
expect(events.first, equals(testData));
|
||||
|
||||
// Clean up
|
||||
await subscription.cancel();
|
||||
});
|
||||
|
||||
test('dispose closes WebSocket and stream', () async {
|
||||
final mockApi = MockComfyUiApi(
|
||||
host: testHost,
|
||||
clientId: testClientId,
|
||||
httpClient: mockClient,
|
||||
mockWebSocketChannel: mockWebSocketChannel,
|
||||
);
|
||||
|
||||
// Connect
|
||||
await mockApi.connectWebSocket();
|
||||
|
||||
// Dispose
|
||||
mockApi.dispose();
|
||||
|
||||
// Verify WebSocket was closed
|
||||
verify(mockWebSocketSink.close()).called(1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// A modified version of ComfyUiApi for testing that allows injecting a mock WebSocketChannel
|
||||
class MockComfyUiApi extends ComfyUiApi {
|
||||
final WebSocketChannel? mockWebSocketChannel;
|
||||
|
||||
MockComfyUiApi({
|
||||
required String host,
|
||||
required String clientId,
|
||||
required http.Client httpClient,
|
||||
this.mockWebSocketChannel,
|
||||
}) : super(
|
||||
host: host,
|
||||
clientId: clientId,
|
||||
httpClient: httpClient,
|
||||
);
|
||||
|
||||
@override
|
||||
Future<void> connectWebSocket() async {
|
||||
if (mockWebSocketChannel != null) {
|
||||
_wsChannel = mockWebSocketChannel;
|
||||
|
||||
_wsChannel!.stream.listen((message) {
|
||||
final data = jsonDecode(message);
|
||||
_progressController.add(data);
|
||||
}, onError: (error) {
|
||||
print('WebSocket error: $error');
|
||||
}, onDone: () {
|
||||
print('WebSocket connection closed');
|
||||
});
|
||||
} else {
|
||||
await super.connectWebSocket();
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user