198 lines
7.0 KiB
Dart
198 lines
7.0 KiB
Dart
import 'dart:convert';

import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
import 'package:http/http.dart' as http;
// `hide MockClient`: package:http ships its own MockClient, which would clash
// with the mockito-generated MockClient (a mock of http.Client) imported from
// comfyui_api_test.mocks.dart and make every use of the name ambiguous.
import 'package:http/testing.dart' hide MockClient;
import 'package:mockito/annotations.dart';
import 'package:mockito/mockito.dart';
import 'package:test/test.dart';
import 'package:web_socket_channel/web_socket_channel.dart';

import 'comfyui_api_test.mocks.dart';
import 'test_data.dart';

/// Unit tests for [ComfyUiApi] and the SDK's JSON model classes.
///
/// All HTTP traffic goes through the mockito-generated `MockClient` (a mock
/// of `http.Client`), so no ComfyUI server is needed. Mocks for
/// `WebSocketChannel` and `WebSocketSink` are generated into
/// `comfyui_api_test.mocks.dart` too, although no test in this file uses them.
@GenerateMocks([http.Client, WebSocketChannel, WebSocketSink])
void main() {
  late MockClient mockClient;
  late ComfyUiApi api;

  const testHost = 'http://localhost:8188';
  const testClientId = 'test-client-id';

  // Resolves [pathAndQuery] against [testHost]. Each test builds its URI once
  // and reuses it for both `when` and `verify`, so the stub and the
  // verification cannot silently drift apart.
  Uri endpoint(String pathAndQuery) => Uri.parse('$testHost$pathAndQuery');

  // Fresh mock and client for every test, so stubs and recorded calls never
  // leak from one test into the next.
  setUp(() {
    mockClient = MockClient();
    api = ComfyUiApi(
      host: testHost,
      clientId: testClientId,
      httpClient: mockClient,
    );
  });

  group('ComfyUiApi', () {
    test('initialize with provided values', () {
      expect(api.host, equals(testHost));
      expect(api.clientId, equals(testClientId));
    });

    test('initialize with generated clientId when not provided', () {
      final autoApi = ComfyUiApi(host: testHost, httpClient: mockClient);

      expect(autoApi.clientId, isNotEmpty);
      expect(autoApi.clientId, isNot(equals(testClientId)));
    });

    test('getQueue returns parsed response', () async {
      // NOTE(review): unlike every other endpoint here, this path has no
      // `/api` prefix; presumably it mirrors ComfyUiApi.getQueue — confirm.
      final uri = endpoint('/queue');
      when(mockClient.get(uri)).thenAnswer(
          (_) async => http.Response(jsonEncode(TestData.queueResponse), 200));

      final result = await api.getQueue();

      expect(result, equals(TestData.queueResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('getHistory returns parsed response', () async {
      final uri = endpoint('/api/history?max_items=64');
      when(mockClient.get(uri)).thenAnswer((_) async =>
          http.Response(jsonEncode(TestData.historyResponse), 200));

      final result = await api.getHistory();

      expect(result, equals(TestData.historyResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('getImage returns image bytes', () async {
      final bytes = [1, 2, 3, 4];
      final uri = endpoint('/api/view?filename=test.png');
      when(mockClient.get(uri))
          .thenAnswer((_) async => http.Response.bytes(bytes, 200));

      final result = await api.getImage('test.png');

      expect(result, equals(bytes));
      verify(mockClient.get(uri)).called(1);
    });

    test('getCheckpoints returns parsed response', () async {
      final uri = endpoint('/api/experiment/models/checkpoints');
      when(mockClient.get(uri)).thenAnswer((_) async =>
          http.Response(jsonEncode(TestData.checkpointsResponse), 200));

      final result = await api.getCheckpoints();

      expect(result, equals(TestData.checkpointsResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('getCheckpointDetails returns parsed response', () async {
      const filename = 'models/checkpoints/test.safetensors';
      // NOTE(review): the filename goes into the query unencoded; this only
      // matches if the client builds its URI the same way — confirm.
      final uri =
          endpoint('/api/view_metadata/checkpoints?filename=$filename');
      when(mockClient.get(uri)).thenAnswer((_) async => http.Response(
          jsonEncode(TestData.checkpointMetadataResponse), 200));

      final result = await api.getCheckpointDetails(filename);

      expect(result, equals(TestData.checkpointMetadataResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('getLoras returns parsed response', () async {
      final uri = endpoint('/api/experiment/models/loras');
      when(mockClient.get(uri)).thenAnswer(
          (_) async => http.Response(jsonEncode(TestData.lorasResponse), 200));

      final result = await api.getLoras();

      expect(result, equals(TestData.lorasResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('getVaes returns parsed response', () async {
      final uri = endpoint('/api/experiment/models/vae');
      when(mockClient.get(uri)).thenAnswer(
          (_) async => http.Response(jsonEncode(TestData.vaeResponse), 200));

      final result = await api.getVaes();

      expect(result, equals(TestData.vaeResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('getObjectInfo returns parsed response', () async {
      final uri = endpoint('/api/object_info');
      when(mockClient.get(uri)).thenAnswer((_) async =>
          http.Response(jsonEncode(TestData.objectInfoResponse), 200));

      final result = await api.getObjectInfo();

      expect(result, equals(TestData.objectInfoResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('submitPrompt returns parsed response', () async {
      final uri = endpoint('/api/prompt');
      const headers = {'Content-Type': 'application/json'};
      final body = jsonEncode(TestData.promptRequest);
      when(mockClient.post(uri, headers: headers, body: body)).thenAnswer(
          (_) async => http.Response(jsonEncode(TestData.promptResponse), 200));

      final result = await api.submitPrompt(TestData.promptRequest);

      expect(result, equals(TestData.promptResponse));
      verify(mockClient.post(uri, headers: headers, body: body)).called(1);
    });

    test('throws ComfyUiApiException on error response', () async {
      final uri = endpoint('/queue');
      when(mockClient.get(uri))
          .thenAnswer((_) async => http.Response('Error message', 500));

      // Awaiting the async matcher makes a failure surface at this line,
      // before the verification below runs.
      await expectLater(api.getQueue, throwsA(isA<ComfyUiApiException>()));
      // Guards against a vacuous pass: the request must have gone to the
      // stubbed URI (which answers 500), not somewhere unexpected.
      verify(mockClient.get(uri)).called(1);
    });
  });

  group('Models', () {
    test('QueueInfo parses from JSON correctly', () {
      final queueInfo = QueueInfo.fromJson(TestData.queueResponse);

      expect(queueInfo.queueRunning, equals(0));
      expect(queueInfo.queue, hasLength(0));
      expect(queueInfo.queuePending, isA<Map<String, dynamic>>());
    });

    test('PromptExecutionStatus parses from JSON correctly', () {
      final status = PromptExecutionStatus.fromJson(TestData.promptResponse);

      expect(status.promptId, equals('123456789'));
      expect(status.number, equals(1));
      expect(status.status, equals('success'));
    });

    test('HistoryItem parses from JSON correctly', () {
      final item = HistoryItem.fromJson(TestData.historyItemResponse);

      expect(item.promptId, equals('123456789'));
      expect(item.prompt, isA<Map<String, dynamic>>());
      expect(item.outputs, isA<Map<String, dynamic>>());
    });

    test('ProgressUpdate parses from JSON correctly', () {
      final update = ProgressUpdate.fromJson(TestData.progressUpdateResponse);

      expect(update.type, equals('execution_start'));
      expect(update.data, isA<Map<String, dynamic>>());
    });
  });
}