// comfyui_api_sdk/test/comfyui_api_test.dart

import 'dart:convert';
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
import 'package:http/http.dart' as http;
import 'package:http/testing.dart';
import 'package:mockito/annotations.dart';
import 'package:mockito/mockito.dart';
import 'package:test/test.dart';
import 'package:web_socket_channel/web_socket_channel.dart';
import 'comfyui_api_test.mocks.dart';
import 'test_data.dart';
/// Unit tests for [ComfyUiApi] and the JSON model classes it returns.
///
/// All HTTP traffic goes through a mockito mock of `http.Client`, so no
/// ComfyUI server is needed. Response fixtures come from `TestData`.
@GenerateMocks(
  [WebSocketChannel, WebSocketSink],
  // `package:http/testing.dart` also exports a class named `MockClient`, so
  // the generated mock of `http.Client` is renamed to avoid an ambiguous
  // import. Regenerate `comfyui_api_test.mocks.dart` with
  // `dart run build_runner build` after editing this annotation.
  customMocks: [MockSpec<http.Client>(as: #MockHttpClient)],
)
void main() {
  late MockHttpClient mockClient;
  late ComfyUiApi api;
  const testHost = 'http://localhost:8188';
  const testClientId = 'test-client-id';

  // A fresh mock and API instance per test keeps stubs and recorded calls
  // from leaking between tests.
  setUp(() {
    mockClient = MockHttpClient();
    api = ComfyUiApi(
      host: testHost,
      clientId: testClientId,
      httpClient: mockClient,
    );
  });

  // Stubs a GET of [uri] to answer with [json], JSON-encoded, and status 200.
  void stubJsonGet(Uri uri, Object? json) {
    when(mockClient.get(uri))
        .thenAnswer((_) async => http.Response(jsonEncode(json), 200));
  }

  group('ComfyUiApi', () {
    test('initialize with provided values', () {
      expect(api.host, equals(testHost));
      expect(api.clientId, equals(testClientId));
    });

    test('initialize with generated clientId when not provided', () {
      final autoApi = ComfyUiApi(host: testHost, httpClient: mockClient);
      expect(autoApi.clientId, isNotEmpty);
      expect(autoApi.clientId, isNot(equals(testClientId)));
    });

    test('getQueue returns parsed response', () async {
      final uri = Uri.parse('$testHost/queue');
      stubJsonGet(uri, TestData.queueResponse);

      final result = await api.getQueue();

      expect(result, equals(TestData.queueResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('getHistory returns parsed response', () async {
      final uri = Uri.parse('$testHost/api/history?max_items=64');
      stubJsonGet(uri, TestData.historyResponse);

      final result = await api.getHistory();

      expect(result, equals(TestData.historyResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('getImage returns image bytes', () async {
      final uri = Uri.parse('$testHost/api/view?filename=test.png');
      final bytes = [1, 2, 3, 4];
      when(mockClient.get(uri))
          .thenAnswer((_) async => http.Response.bytes(bytes, 200));

      final result = await api.getImage('test.png');

      expect(result, equals(bytes));
      verify(mockClient.get(uri)).called(1);
    });

    test('getCheckpoints returns parsed response', () async {
      final uri = Uri.parse('$testHost/api/experiment/models/checkpoints');
      stubJsonGet(uri, TestData.checkpointsResponse);

      final result = await api.getCheckpoints();

      expect(result, equals(TestData.checkpointsResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('getCheckpointDetails returns parsed response', () async {
      const filename = 'models/checkpoints/test.safetensors';
      final uri = Uri.parse(
          '$testHost/api/view_metadata/checkpoints?filename=$filename');
      stubJsonGet(uri, TestData.checkpointMetadataResponse);

      final result = await api.getCheckpointDetails(filename);

      expect(result, equals(TestData.checkpointMetadataResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('getLoras returns parsed response', () async {
      final uri = Uri.parse('$testHost/api/experiment/models/loras');
      stubJsonGet(uri, TestData.lorasResponse);

      final result = await api.getLoras();

      expect(result, equals(TestData.lorasResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('getVaes returns parsed response', () async {
      final uri = Uri.parse('$testHost/api/experiment/models/vae');
      stubJsonGet(uri, TestData.vaeResponse);

      final result = await api.getVaes();

      expect(result, equals(TestData.vaeResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('getObjectInfo returns parsed response', () async {
      final uri = Uri.parse('$testHost/api/object_info');
      stubJsonGet(uri, TestData.objectInfoResponse);

      final result = await api.getObjectInfo();

      expect(result, equals(TestData.objectInfoResponse));
      verify(mockClient.get(uri)).called(1);
    });

    test('submitPrompt returns parsed response', () async {
      final uri = Uri.parse('$testHost/api/prompt');
      const headers = {'Content-Type': 'application/json'};
      final body = jsonEncode(TestData.promptRequest);
      when(mockClient.post(uri, headers: headers, body: body)).thenAnswer(
          (_) async => http.Response(jsonEncode(TestData.promptResponse), 200));

      final result = await api.submitPrompt(TestData.promptRequest);

      expect(result, equals(TestData.promptResponse));
      verify(mockClient.post(uri, headers: headers, body: body)).called(1);
    });

    test('throws ComfyUiApiException on error response', () async {
      when(mockClient.get(Uri.parse('$testHost/queue')))
          .thenAnswer((_) async => http.Response('Error message', 500));

      // `expectLater` returns a Future; awaiting it makes this async
      // assertion explicit instead of leaving it un-awaited in the body.
      await expectLater(api.getQueue, throwsA(isA<ComfyUiApiException>()));
    });
  });

  group('Models', () {
    test('QueueInfo parses from JSON correctly', () {
      final queueInfo = QueueInfo.fromJson(TestData.queueResponse);
      expect(queueInfo.queueRunning, equals(0));
      expect(queueInfo.queue, isEmpty);
      expect(queueInfo.queuePending, isA<Map<String, dynamic>>());
    });

    test('PromptExecutionStatus parses from JSON correctly', () {
      final status = PromptExecutionStatus.fromJson(TestData.promptResponse);
      expect(status.promptId, equals('123456789'));
      expect(status.number, equals(1));
      expect(status.status, equals('success'));
    });

    test('HistoryItem parses from JSON correctly', () {
      final item = HistoryItem.fromJson(TestData.historyItemResponse);
      expect(item.promptId, equals('123456789'));
      expect(item.prompt, isA<Map<String, dynamic>>());
      expect(item.outputs, isA<Map<String, dynamic>>());
    });

    test('ProgressUpdate parses from JSON correctly', () {
      final update = ProgressUpdate.fromJson(TestData.progressUpdateResponse);
      expect(update.type, equals('execution_start'));
      expect(update.data, isA<Map<String, dynamic>>());
    });
  });
}