From 0b2769310b62d5ca77235f606042006e37f2dfc5 Mon Sep 17 00:00:00 2001 From: Menno van Leeuwen Date: Thu, 20 Mar 2025 14:27:29 +0000 Subject: [PATCH] Add initial structure for comfyui_api_sdk with API models and event handling --- .dart_tool/package_config.json | 64 +++++++-- .gitignore | 1 + Makefile | 4 + comfyui-api-sdk.dart | 205 --------------------------- example/example.dart | 94 ------------ lib/comfyui_api_sdk.dart | 4 +- lib/src/comfyui_api.dart | 167 +++++++++++++++++++++- lib/src/comfyui_api_sdk.dart | 7 + lib/src/models/callbacks.dart | 11 ++ lib/src/models/execution_event.dart | 31 ++++ lib/src/models/progress_event.dart | 39 +++++ lib/src/models/websocket_event.dart | 109 ++++++++++++++ pubspec.lock | 82 ++++++++--- pubspec.yaml | 16 ++- test/comfyui_api_test.dart | 197 -------------------------- test/integration_test.dart | 212 ---------------------------- test/models_test.dart | 127 ----------------- test/test_data.dart | 148 ------------------- test/websocket_test.dart | 145 ------------------- 19 files changed, 491 insertions(+), 1172 deletions(-) create mode 100644 .gitignore create mode 100644 Makefile delete mode 100644 comfyui-api-sdk.dart delete mode 100644 example/example.dart create mode 100644 lib/src/comfyui_api_sdk.dart create mode 100644 lib/src/models/callbacks.dart create mode 100644 lib/src/models/execution_event.dart create mode 100644 lib/src/models/progress_event.dart create mode 100644 lib/src/models/websocket_event.dart delete mode 100644 test/comfyui_api_test.dart delete mode 100644 test/integration_test.dart delete mode 100644 test/models_test.dart delete mode 100644 test/test_data.dart delete mode 100644 test/websocket_test.dart diff --git a/.dart_tool/package_config.json b/.dart_tool/package_config.json index 92233e5..68d175d 100644 --- a/.dart_tool/package_config.json +++ b/.dart_tool/package_config.json @@ -57,7 +57,7 @@ }, { "name": "build_runner", - "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/build_runner-2.4.14", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/build_runner-2.4.15", "packageUri": "lib/", "languageVersion": "3.6" }, @@ -133,6 +133,18 @@ "packageUri": "lib/", "languageVersion": "3.1" }, + { + "name": "freezed", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/freezed-3.0.4", + "packageUri": "lib/", + "languageVersion": "3.6" + }, + { + "name": "freezed_annotation", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/freezed_annotation-3.0.0", + "packageUri": "lib/", + "languageVersion": "3.0" + }, { "name": "frontend_server_client", "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/frontend_server_client-4.0.0", @@ -153,9 +165,9 @@ }, { "name": "http", - "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/http-0.13.6", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/http-1.3.0", "packageUri": "lib/", - "languageVersion": "2.19" + "languageVersion": "3.4" }, { "name": "http_multi_server", @@ -188,10 +200,16 @@ "languageVersion": "3.0" }, { - "name": "lints", - "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/lints-2.1.1", + "name": "json_serializable", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/json_serializable-6.9.4", "packageUri": "lib/", - "languageVersion": "3.0" + "languageVersion": "3.6" + }, + { + "name": "lints", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/lints-5.1.1", + "packageUri": "lib/", + "languageVersion": "3.6" }, { "name": "logging", @@ -279,9 +297,9 @@ }, { "name": "shelf_web_socket", - "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/shelf_web_socket-2.0.1", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/shelf_web_socket-3.0.0", "packageUri": "lib/", - "languageVersion": "3.3" + "languageVersion": "3.5" }, { "name": "source_gen", @@ -289,6 +307,12 @@ "packageUri": "lib/", "languageVersion": "3.6" }, + { + "name": "source_helper", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/source_helper-1.3.5", + "packageUri": "lib/", + "languageVersion": "3.4" + }, { "name": "source_map_stack_trace", "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/source_map_stack_trace-2.1.2", @@ -307,6 +331,12 @@ "packageUri": "lib/", "languageVersion": "3.1" }, + { + "name": "sprintf", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/sprintf-7.0.0", + "packageUri": "lib/", + "languageVersion": "2.12" + }, { "name": "stack_trace", "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/stack_trace-1.12.1", @@ -369,9 +399,9 @@ }, { "name": "uuid", - "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/uuid-3.0.7", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/uuid-4.5.1", "packageUri": "lib/", - "languageVersion": "2.12" + "languageVersion": "3.0" }, { "name": "vm_service", @@ -387,13 +417,19 @@ }, { "name": "web", - "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web-0.5.1", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web-1.1.1", + "packageUri": "lib/", + "languageVersion": "3.4" + }, + { + "name": "web_socket", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web_socket-0.1.6", "packageUri": "lib/", "languageVersion": "3.3" }, { "name": "web_socket_channel", - "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web_socket_channel-2.4.5", + "rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web_socket_channel-3.0.2", "packageUri": "lib/", "languageVersion": "3.3" }, @@ -416,10 +452,8 @@ "languageVersion": "3.0" } ], - "generated": "2025-03-20T10:48:40.871849Z", + "generated": "2025-03-20T13:27:48.930532Z", "generator": "pub", "generatorVersion": "3.7.2", - "flutterRoot": "file:///home/menno/.flutter/flutter", - "flutterVersion": "3.29.2", "pubCache": "file:///home/menno/.pub-cache" } diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c7f77dc --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.dart_tool \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..40c7487 --- /dev/null +++ b/Makefile @@ -0,0 +1,4 @@ +default: + +build-runner: + dart run build_runner build --delete-conflicting-outputs \ No newline at end of file diff --git a/comfyui-api-sdk.dart b/comfyui-api-sdk.dart deleted file mode 100644 index 69477e1..0000000 --- a/comfyui-api-sdk.dart +++ /dev/null @@ -1,205 +0,0 @@ -import 'dart:async'; -import 'dart:convert'; - -import 'package:http/http.dart' as http; -import 'package:uuid/uuid.dart'; -import 'package:web_socket_channel/web_socket_channel.dart'; - -class ComfyUiApi { - final String host; - final String clientId; - final http.Client _httpClient; - WebSocketChannel? _wsChannel; - final StreamController> _progressController = - StreamController.broadcast(); - - /// Stream of progress updates from ComfyUI - Stream> get progressUpdates => - _progressController.stream; - - /// Creates a new ComfyUI API client - /// - /// [host] The host of the ComfyUI server (e.g. 'http://localhost:8188') - /// [clientId] Optional client ID, will be automatically generated if not provided - ComfyUiApi({ - required this.host, - String? clientId, - http.Client? httpClient, - }) : clientId = clientId ?? const Uuid().v4(), - _httpClient = httpClient ?? http.Client(); - - /// Connects to the WebSocket for progress updates - Future connectWebSocket() async { - final wsUrl = - 'ws://${host.replaceFirst(RegExp(r'^https?://'), '')}/ws?clientId=$clientId'; - _wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl)); - - _wsChannel!.stream.listen((message) { - final data = jsonDecode(message); - _progressController.add(data); - }, onError: (error) { - print('WebSocket error: $error'); - }, onDone: () { - print('WebSocket connection closed'); - }); - } - - /// Closes the WebSocket connection and cleans up resources - void dispose() { - _wsChannel?.sink.close(); - _progressController.close(); - _httpClient.close(); - } - - /// Gets the current queue status - Future> getQueue() async { - final response = await _httpClient.get(Uri.parse('$host/queue')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets the history of the queue - Future> getHistory({int maxItems = 64}) async { - final response = await _httpClient - .get(Uri.parse('$host/api/history?max_items=$maxItems')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets image data by filename - Future> getImage(String filename) async { - final response = - await _httpClient.get(Uri.parse('$host/api/view?filename=$filename')); - _validateResponse(response); - return response.bodyBytes; - } - - /// Gets a list of all available models - Future> getModels() async { - final response = - await _httpClient.get(Uri.parse('$host/api/experiment/models')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets a list of checkpoints - Future> getCheckpoints() async { - final response = await _httpClient - .get(Uri.parse('$host/api/experiment/models/checkpoints')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets details for a specific checkpoint - Future> getCheckpointDetails( - String pathAndFileName) async { - final response = await _httpClient.get(Uri.parse( - '$host/api/view_metadata/checkpoints?filename=$pathAndFileName')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets a list of LoRAs - Future> getLoras() async { - final response = - await _httpClient.get(Uri.parse('$host/api/experiment/models/loras')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets details for a specific LoRA - Future> getLoraDetails(String pathAndFileName) async { - final response = await _httpClient.get( - Uri.parse('$host/api/view_metadata/loras?filename=$pathAndFileName')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets a list of VAEs - Future> getVaes() async { - final response = - await _httpClient.get(Uri.parse('$host/api/experiment/models/vae')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets details for a specific VAE - Future> getVaeDetails(String pathAndFileName) async { - final response = await _httpClient.get( - Uri.parse('$host/api/view_metadata/vae?filename=$pathAndFileName')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets a list of upscale models - Future> getUpscaleModels() async { - final response = await _httpClient - .get(Uri.parse('$host/api/experiment/models/upscale_models')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets details for a specific upscale model - Future> getUpscaleModelDetails( - String pathAndFileName) async { - final response = await _httpClient.get(Uri.parse( - '$host/api/view_metadata/upscale_models?filename=$pathAndFileName')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets a list of embeddings - Future> getEmbeddings() async { - final response = await _httpClient - .get(Uri.parse('$host/api/experiment/models/embeddings')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets details for a specific embedding - Future> getEmbeddingDetails( - String pathAndFileName) async { - final response = await _httpClient.get(Uri.parse( - '$host/api/view_metadata/embeddings?filename=$pathAndFileName')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Gets information about all available objects (nodes) - Future> getObjectInfo() async { - final response = await _httpClient.get(Uri.parse('$host/api/object_info')); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Submits a prompt (workflow) to generate an image - Future> submitPrompt(Map prompt) async { - final response = await _httpClient.post( - Uri.parse('$host/api/prompt'), - headers: {'Content-Type': 'application/json'}, - body: jsonEncode(prompt), - ); - _validateResponse(response); - return jsonDecode(response.body); - } - - /// Validates HTTP response and throws an exception if needed - void _validateResponse(http.Response response) { - if (response.statusCode < 200 || response.statusCode >= 300) { - throw ComfyUiApiException( - statusCode: response.statusCode, - message: 'API request failed: ${response.body}'); - } - } -} - -/// Exception thrown when the ComfyUI API returns an error -class ComfyUiApiException implements Exception { - final int statusCode; - final String message; - - ComfyUiApiException({required this.statusCode, required this.message}); - - @override - String toString() => 'ComfyUiApiException: $statusCode - $message'; -} diff --git a/example/example.dart b/example/example.dart deleted file mode 100644 index 8700b4c..0000000 --- a/example/example.dart +++ /dev/null @@ -1,94 +0,0 @@ -import 'dart:io'; -import 'package:comfyui_api_sdk/comfyui_api_sdk.dart'; - -void main() async { - // Create the API client - final api = ComfyUiApi(host: 'http://mennos-server:7860'); - - // Connect to the WebSocket for progress updates - await api.connectWebSocket(); - - // Listen for progress updates - api.progressUpdates.listen((update) { - print('Progress update: $update'); - }); - - // Get available checkpoints - final checkpoints = await api.getCheckpoints(); - print('Available checkpoints: ${checkpoints.keys.join(', ')}'); - - // Get queue status - final queue = await api.getQueue(); - print('Queue status: $queue'); - - // Submit a basic text-to-image prompt - final promptWorkflow = { - "prompt": { - "3": { - "inputs": { - "seed": 123456789, - "steps": 20, - "cfg": 7, - "sampler_name": "euler_ancestral", - "scheduler": "normal", - "denoise": 1, - "model": ["4", 0], - "positive": ["6", 0], - "negative": ["7", 0], - "latent_image": ["5", 0] - }, - "class_type": "KSampler" - }, - "4": { - "inputs": {"ckpt_name": "dreamshaper_8.safetensors"}, - "class_type": "CheckpointLoaderSimple" - }, - "5": { - "inputs": {"width": 512, "height": 512, "batch_size": 1}, - "class_type": "EmptyLatentImage" - }, - "6": { - "inputs": { - "text": "a beautiful landscape with mountains and a lake", - "clip": ["4", 1] - }, - "class_type": "CLIPTextEncode" - }, - "7": { - "inputs": { - "text": "ugly, blurry, low quality", - "clip": ["4", 1] - }, - "class_type": "CLIPTextEncode" - }, - "8": { - "inputs": { - "samples": ["3", 0], - "vae": ["4", 2] - }, - "class_type": "VAEDecode" - }, - "9": { - "inputs": { - "filename_prefix": "ComfyUI", - "images": ["8", 0] - }, - "class_type": "SaveImage" - } - }, - "client_id": api.clientId - }; - - try { - final result = await api.submitPrompt(promptWorkflow); - print('Prompt submitted: $result'); - } catch (e) { - print('Error submitting prompt: $e'); - } - - // Wait for some time to receive WebSocket messages - await Future.delayed(Duration(seconds: 60)); - - // Clean up - api.dispose(); -} diff --git a/lib/comfyui_api_sdk.dart b/lib/comfyui_api_sdk.dart index 085ee3c..0f9ccc9 100644 --- a/lib/comfyui_api_sdk.dart +++ b/lib/comfyui_api_sdk.dart @@ -1,4 +1,6 @@ library comfyui_api_sdk; export 'src/comfyui_api.dart'; -export 'src/models/models.dart'; +export 'src/models/websocket_event.dart'; +export 'src/models/progress_event.dart'; +export 'src/models/execution_event.dart'; diff --git a/lib/src/comfyui_api.dart b/lib/src/comfyui_api.dart index e9364fa..98d76e7 100644 --- a/lib/src/comfyui_api.dart +++ b/lib/src/comfyui_api.dart @@ -5,22 +5,69 @@ import 'package:http/http.dart' as http; import 'package:uuid/uuid.dart'; import 'package:web_socket_channel/web_socket_channel.dart'; +import 'models/websocket_event.dart'; +import 'models/progress_event.dart'; +import 'models/execution_event.dart'; + +/// Callback function type for prompt events +typedef PromptEventCallback = void Function(String promptId); + +/// Callback function type for typed WebSocket events +typedef WebSocketEventCallback = void Function(WebSocketEvent event); + +/// Callback function type for progress events +typedef ProgressEventCallback = void Function(ProgressEvent event); + /// A Dart SDK for interacting with the ComfyUI API class ComfyUiApi { final String host; final String clientId; final http.Client _httpClient; WebSocketChannel? _wsChannel; + + // Controllers for different event streams + final StreamController _eventController = + StreamController.broadcast(); final StreamController> _progressController = StreamController.broadcast(); - /// Stream of progress updates from ComfyUI + // Add new controllers for specific event types + final StreamController _progressEventController = + StreamController.broadcast(); + final StreamController _executionEventController = + StreamController.broadcast(); + + /// Stream of typed progress events + Stream get progressEvents => _progressEventController.stream; + + /// Stream of typed execution events + Stream get executionEvents => + _executionEventController.stream; + + // Event callbacks + final Map> _eventCallbacks = { + 'onPromptStart': [], + 'onPromptFinished': [], + }; + + final Map> + _typedEventCallbacks = { + for (var type in WebSocketEventType.values) type: [], + }; + + // Add a separate map for progress event callbacks + final List _progressEventCallbacks = []; + + /// Stream of typed WebSocket events + Stream get events => _eventController.stream; + + /// Stream of progress updates from ComfyUI (legacy format) Stream> get progressUpdates => _progressController.stream; /// Creates a new ComfyUI API client /// - /// [host] The host of the ComfyUI server (e.g. 'http://localhost:8188') + /// [host] The host of the ComfyUI server (e.g. 'http://localhost:7860') /// [clientId] Optional client ID, will be automatically generated if not provided ComfyUiApi({ required this.host, @@ -29,6 +76,26 @@ class ComfyUiApi { }) : clientId = clientId ?? const Uuid().v4(), _httpClient = httpClient ?? http.Client(); + /// Register a callback for when a prompt starts processing + void onPromptStart(PromptEventCallback callback) { + _eventCallbacks['onPromptStart']!.add(callback); + } + + /// Register a callback for when a prompt finishes processing + void onPromptFinished(PromptEventCallback callback) { + _eventCallbacks['onPromptFinished']!.add(callback); + } + + /// Register a callback for specific WebSocket event types + void onEventType(WebSocketEventType type, WebSocketEventCallback callback) { + _typedEventCallbacks[type]!.add(callback); + } + + /// Register a callback for progress updates + void onProgressChanged(ProgressEventCallback callback) { + _progressEventCallbacks.add(callback); + } + /// Connects to the WebSocket for progress updates Future connectWebSocket() async { final wsUrl = @@ -36,8 +103,43 @@ class ComfyUiApi { _wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl)); _wsChannel!.stream.listen((message) { - final data = jsonDecode(message); - _progressController.add(data); + final jsonData = jsonDecode(message); + + // Create a typed event + final event = WebSocketEvent.fromJson(jsonData); + + // Add to the typed event stream + _eventController.add(event); + + // Also add to the legacy progress stream + _progressController.add(jsonData); + + // Convert to more specific event types if possible + if (event.eventType == WebSocketEventType.progress) { + _tryCreateProgressEvent(event); + } else if ([ + WebSocketEventType.executionStart, + WebSocketEventType.executionSuccess, + WebSocketEventType.executionCached, + WebSocketEventType.executed, + WebSocketEventType.executing, + ].contains(event.eventType)) { + _tryCreateExecutionEvent(event); + } + + // Trigger event type specific callbacks + for (final callback in _typedEventCallbacks[event.eventType]!) { + callback(event); + } + + // Handle execution_success event (prompt finished) + if (event.eventType == WebSocketEventType.executionSuccess && + event.promptId != null) { + final promptId = event.promptId!; + for (final callback in _eventCallbacks['onPromptFinished']!) { + callback(promptId); + } + } }, onError: (error) { print('WebSocket error: $error'); }, onDone: () { @@ -45,10 +147,55 @@ class ComfyUiApi { }); } + /// Attempts to create a ProgressEvent from a WebSocketEvent + void _tryCreateProgressEvent(WebSocketEvent event) { + if (event.data.containsKey('value') && + event.data.containsKey('max') && + event.data.containsKey('prompt_id')) { + try { + final progressEvent = ProgressEvent( + value: event.data['value'] as int, + max: event.data['max'] as int, + promptId: event.data['prompt_id'] as String, + node: event.data['node']?.toString(), + ); + _progressEventController.add(progressEvent); + + // Trigger all registered progress callbacks + for (final callback in _progressEventCallbacks) { + callback(progressEvent); + } + } catch (e) { + print('Error creating ProgressEvent: $e'); + } + } + } + + /// Attempts to create an ExecutionEvent from a WebSocketEvent + void _tryCreateExecutionEvent(WebSocketEvent event) { + if (event.data.containsKey('prompt_id')) { + try { + final executionEvent = ExecutionEvent( + promptId: event.data['prompt_id'] as String, + timestamp: event.data['timestamp'] as int? ?? + DateTime.now().millisecondsSinceEpoch, + node: event.data['node']?.toString(), + extra: event.data['extra'] as Map?, + ); + _executionEventController.add(executionEvent); + } catch (e) { + print('Error creating ExecutionEvent: $e'); + } + } + } + /// Closes the WebSocket connection and cleans up resources void dispose() { _wsChannel?.sink.close(); _progressController.close(); + _eventController.close(); + _progressEventController.close(); + _executionEventController.close(); _httpClient.close(); } @@ -181,7 +328,17 @@ class ComfyUiApi { body: jsonEncode(prompt), ); _validateResponse(response); - return jsonDecode(response.body); + final responseData = jsonDecode(response.body); + + // Trigger onPromptStart event if prompt_id exists + if (responseData.containsKey('prompt_id')) { + final promptId = responseData['prompt_id']; + for (final callback in _eventCallbacks['onPromptStart']!) { + callback(promptId); + } + } + + return responseData; } /// Validates HTTP response and throws an exception if needed diff --git a/lib/src/comfyui_api_sdk.dart b/lib/src/comfyui_api_sdk.dart new file mode 100644 index 0000000..e2b1c92 --- /dev/null +++ b/lib/src/comfyui_api_sdk.dart @@ -0,0 +1,7 @@ +// Main API +export 'comfyui_api.dart'; + +// Models +export 'models/websocket_event.dart'; +export 'models/progress_event.dart'; +export 'models/execution_event.dart'; diff --git a/lib/src/models/callbacks.dart b/lib/src/models/callbacks.dart new file mode 100644 index 0000000..2fb375a --- /dev/null +++ b/lib/src/models/callbacks.dart @@ -0,0 +1,11 @@ +import 'websocket_event.dart'; +import 'progress_event.dart'; + +/// Callback function type for prompt events +typedef PromptEventCallback = void Function(String promptId); + +/// Callback function type for typed WebSocket events +typedef WebSocketEventCallback = void Function(WebSocketEvent event); + +/// Callback function type for progress events +typedef ProgressEventCallback = void Function(ProgressEvent event); diff --git a/lib/src/models/execution_event.dart b/lib/src/models/execution_event.dart new file mode 100644 index 0000000..909f402 --- /dev/null +++ b/lib/src/models/execution_event.dart @@ -0,0 +1,31 @@ +class ExecutionEvent { + final String promptId; + final int timestamp; + final String? node; + final Map? extra; + + const ExecutionEvent({ + required this.promptId, + required this.timestamp, + this.node, + this.extra, + }); + + factory ExecutionEvent.fromJson(Map json) { + return ExecutionEvent( + promptId: json['prompt_id'] as String, + timestamp: json['timestamp'] as int, + node: json['node'] as String?, + extra: json['extra'] as Map?, + ); + } + + Map toJson() { + return { + 'prompt_id': promptId, + 'timestamp': timestamp, + 'node': node, + 'extra': extra, + }; + } +} diff --git a/lib/src/models/progress_event.dart b/lib/src/models/progress_event.dart new file mode 100644 index 0000000..8aa34fd --- /dev/null +++ b/lib/src/models/progress_event.dart @@ -0,0 +1,39 @@ +class ProgressEvent { + final int value; + final int max; + final String promptId; + final String? node; + + const ProgressEvent({ + required this.value, + required this.max, + required this.promptId, + this.node, + }); + + factory ProgressEvent.fromJson(Map json) { + return ProgressEvent( + value: json['value'] as int, + max: json['max'] as int, + promptId: json['prompt_id'] as String, + node: json['node'] as String?, + ); + } + + Map toJson() { + return { + 'value': value, + 'max': max, + 'prompt_id': promptId, + 'node': node, + }; + } +} + +extension ProgressEventComputation on ProgressEvent { + /// Returns the progress percentage (0-100) + double get percentage => (value / max) * 100; + + /// Returns true if the progress is complete + bool get isComplete => value >= max; +} diff --git a/lib/src/models/websocket_event.dart b/lib/src/models/websocket_event.dart new file mode 100644 index 0000000..9751a4b --- /dev/null +++ b/lib/src/models/websocket_event.dart @@ -0,0 +1,109 @@ +/// Types of WebSocket events from ComfyUI +enum WebSocketEventType { + status, + progress, + executing, + executed, + executionStart, + executionCached, + executionSuccess, + executionError, + dataOutput, + unknown +} + +/// A typed event from the ComfyUI WebSocket +class WebSocketEvent { + final WebSocketEventType eventType; + final Map data; + final String? promptId; + final String rawType; + + WebSocketEvent({ + required this.eventType, + required this.data, + this.promptId, + required this.rawType, + }); + + @override + String toString() => + 'WebSocketEvent{type: $eventType, rawType: $rawType, promptId: $promptId, data: ${data.keys.join(', ')}}'; + + /// Creates a WebSocketEvent from JSON + factory WebSocketEvent.fromJson(Map json) { + WebSocketEventType type; + Map eventData = {}; + String? promptId; + String rawType = json['type']?.toString() ?? 'unknown'; + + // Extract event type + if (json.containsKey('type')) { + final typeStr = json['type'].toString(); + + // First, try to extract prompt_id from the data field + if (json.containsKey('data') && json['data'] is Map) { + final dataMap = Map.from(json['data'] as Map); + if (dataMap.containsKey('prompt_id')) { + promptId = dataMap['prompt_id'].toString(); + } + eventData = dataMap; + } else { + // If no data field, use the root object + eventData = Map.from(json); + } + + // Try to extract prompt_id from other potential locations + if (promptId == null && json.containsKey('prompt_id')) { + promptId = json['prompt_id'].toString(); + } + + switch (typeStr) { + case 'status': + type = WebSocketEventType.status; + break; + case 'progress': + type = WebSocketEventType.progress; + break; + case 'executing': + type = WebSocketEventType.executing; + break; + case 'executed': + type = WebSocketEventType.executed; + break; + case 'execution_start': + type = WebSocketEventType.executionStart; + break; + case 'execution_cached': + type = WebSocketEventType.executionCached; + break; + case 'execution_success': + type = WebSocketEventType.executionSuccess; + break; + case 'execution_error': + type = WebSocketEventType.executionError; + break; + default: + if (typeStr.startsWith('data_output')) { + type = WebSocketEventType.dataOutput; + } else { + // Default to unknown for unrecognized types + type = WebSocketEventType.unknown; + print('Unknown event type: $typeStr'); + } + } + } else { + // If no type field, default to unknown + type = WebSocketEventType.unknown; + print('WebSocket event missing type field'); + eventData = Map.from(json); + } + + return WebSocketEvent( + eventType: type, + data: eventData, + promptId: promptId, + rawType: rawType, + ); + } +} diff --git a/pubspec.lock b/pubspec.lock index fa0b0dc..43e24e0 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -74,13 +74,13 @@ packages: source: hosted version: "2.4.4" build_runner: - dependency: "direct main" + dependency: "direct dev" description: name: build_runner - sha256: "74691599a5bc750dc96a6b4bfd48f7d9d66453eab04c7f4063134800d6a5c573" + sha256: "058fe9dce1de7d69c4b84fada934df3e0153dd000758c4d65964d0166779aa99" url: "https://pub.dev" source: hosted - version: "2.4.14" + version: "2.4.15" build_runner_core: dependency: transitive description: @@ -177,6 +177,22 @@ packages: url: "https://pub.dev" source: hosted version: "1.1.1" + freezed: + dependency: "direct dev" + description: + name: freezed + sha256: "7ed2ddaa47524976d5f2aa91432a79da36a76969edd84170777ac5bea82d797c" + url: "https://pub.dev" + source: hosted + version: "3.0.4" + freezed_annotation: + dependency: "direct main" + description: + name: freezed_annotation + sha256: c87ff004c8aa6af2d531668b46a4ea379f7191dc6dfa066acd53d506da6e044b + url: "https://pub.dev" + source: hosted + version: "3.0.0" frontend_server_client: dependency: transitive description: @@ -205,10 +221,10 @@ packages: dependency: "direct main" description: name: http - sha256: "5895291c13fa8a3bd82e76d5627f69e0d85ca6a30dcac95c4ea19a5d555879c2" + sha256: fe7ab022b76f3034adc518fb6ea04a82387620e19977665ea18d30a1cf43442f url: "https://pub.dev" source: hosted - version: "0.13.6" + version: "1.3.0" http_multi_server: dependency: transitive description: @@ -242,21 +258,29 @@ packages: source: hosted version: "0.7.2" json_annotation: - dependency: transitive + dependency: "direct main" description: name: json_annotation sha256: "1ce844379ca14835a50d2f019a3099f419082cfdd231cd86a142af94dd5c6bb1" url: "https://pub.dev" source: hosted version: "4.9.0" + json_serializable: + dependency: "direct dev" + description: + name: json_serializable + sha256: "81f04dee10969f89f604e1249382d46b97a1ccad53872875369622b5bfc9e58a" + url: "https://pub.dev" + source: hosted + version: "6.9.4" lints: dependency: "direct dev" description: name: lints - sha256: "0a217c6c989d21039f1498c3ed9f3ed71b354e69873f13a8dfc3c9fe76f1b452" + sha256: c35bb79562d980e9a453fc715854e1ed39e24e7d0297a880ef54e17f9874a9d7 url: "https://pub.dev" source: hosted - version: "2.1.1" + version: "5.1.1" logging: dependency: transitive description: @@ -290,7 +314,7 @@ packages: source: hosted version: "2.0.0" mockito: - dependency: "direct main" + dependency: "direct dev" description: name: mockito sha256: f99d8d072e249f719a5531735d146d8cf04c580d93920b04de75bef6dfb2daf6 @@ -373,10 +397,10 @@ packages: dependency: transitive description: name: shelf_web_socket - sha256: cc36c297b52866d203dbf9332263c94becc2fe0ceaa9681d07b6ef9807023b67 + sha256: "3632775c8e90d6c9712f883e633716432a27758216dfb61bd86a8321c0580925" url: "https://pub.dev" source: hosted - version: "2.0.1" + version: "3.0.0" source_gen: dependency: transitive description: @@ -385,6 +409,14 @@ packages: url: "https://pub.dev" source: hosted version: "2.0.0" + source_helper: + dependency: transitive + description: + name: source_helper + sha256: "86d247119aedce8e63f4751bd9626fc9613255935558447569ad42f9f5b48b3c" + url: "https://pub.dev" + source: hosted + version: "1.3.5" source_map_stack_trace: dependency: transitive description: @@ -409,6 +441,14 @@ packages: url: "https://pub.dev" source: hosted version: "1.10.1" + sprintf: + dependency: transitive + description: + name: sprintf + sha256: "1fc9ffe69d4df602376b52949af107d8f5703b77cda567c4d7d86a0693120f23" + url: "https://pub.dev" + source: hosted + version: "7.0.0" stack_trace: dependency: transitive description: @@ -493,10 +533,10 @@ packages: dependency: "direct main" description: name: uuid - sha256: "648e103079f7c64a36dc7d39369cabb358d377078a051d6ae2ad3aa539519313" + sha256: a5be9ef6618a7ac1e964353ef476418026db906c4facdedaa299b7a2e71690ff url: "https://pub.dev" source: hosted - version: "3.0.7" + version: "4.5.1" vm_service: dependency: transitive description: @@ -517,18 +557,26 @@ packages: dependency: transitive description: name: web - sha256: "97da13628db363c635202ad97068d47c5b8aa555808e7a9411963c533b449b27" + sha256: "868d88a33d8a87b18ffc05f9f030ba328ffefba92d6c127917a2ba740f9cfe4a" url: "https://pub.dev" source: hosted - version: "0.5.1" + version: "1.1.1" + web_socket: + dependency: transitive + description: + name: web_socket + sha256: "3c12d96c0c9a4eec095246debcea7b86c0324f22df69893d538fcc6f1b8cce83" + url: "https://pub.dev" + source: hosted + version: "0.1.6" web_socket_channel: dependency: "direct main" description: name: web_socket_channel - sha256: "58c6666b342a38816b2e7e50ed0f1e261959630becd4c879c4f26bfa14aa5a42" + sha256: "0b8e2457400d8a859b7b2030786835a28a8e80836ef64402abef392ff4f1d0e5" url: "https://pub.dev" source: hosted - version: "2.4.5" + version: "3.0.2" webkit_inspection_protocol: dependency: transitive description: diff --git a/pubspec.yaml b/pubspec.yaml index db30a66..c8f38dd 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -6,12 +6,16 @@ environment: sdk: '>=3.0.0 <4.0.0' dependencies: - http: ^0.13.5 - web_socket_channel: ^2.3.0 - uuid: ^3.0.7 - mockito: ^5.4.5 - build_runner: ^2.4.14 + http: ^1.3.0 + web_socket_channel: ^3.0.2 + uuid: ^4.5.1 + json_annotation: ^4.9.0 + freezed_annotation: ^3.0.0 dev_dependencies: - lints: ^2.0.0 + lints: ^5.1.1 test: ^1.21.0 + mockito: ^5.4.5 + build_runner: ^2.4.14 + freezed: ^3.0.4 + json_serializable: ^6.7.1 \ No newline at end of file diff --git a/test/comfyui_api_test.dart b/test/comfyui_api_test.dart deleted file mode 100644 index da18aa3..0000000 --- a/test/comfyui_api_test.dart +++ /dev/null @@ -1,197 +0,0 @@ -import 'dart:convert'; -import 'package:comfyui_api_sdk/comfyui_api_sdk.dart'; -import 'package:http/http.dart' as http; -import 'package:http/testing.dart'; -import 'package:mockito/annotations.dart'; -import 'package:mockito/mockito.dart'; -import 'package:test/test.dart'; -import 'package:web_socket_channel/web_socket_channel.dart'; - -import 'comfyui_api_test.mocks.dart'; -import 'test_data.dart'; - -@GenerateMocks([http.Client, WebSocketChannel, WebSocketSink]) -void main() { - late MockClient mockClient; - late ComfyUiApi api; - const String testHost = 'http://localhost:8188'; - const String testClientId = 'test-client-id'; - - setUp(() { - mockClient = MockClient(); - api = ComfyUiApi( - host: testHost, - clientId: testClientId, - httpClient: mockClient, - ); - }); - - group('ComfyUiApi', () { - test('initialize with provided values', () { - expect(api.host, equals(testHost)); - expect(api.clientId, equals(testClientId)); - }); - - test('initialize with generated clientId when not provided', () { - final autoApi = ComfyUiApi(host: testHost, httpClient: mockClient); - expect(autoApi.clientId, isNotEmpty); - expect(autoApi.clientId, isNot(equals(testClientId))); - }); - - test('getQueue returns parsed response', () async { - when(mockClient.get(Uri.parse('$testHost/queue'))).thenAnswer( - (_) async => http.Response(jsonEncode(TestData.queueResponse), 200)); - - final result = await api.getQueue(); - - expect(result, equals(TestData.queueResponse)); - verify(mockClient.get(Uri.parse('$testHost/queue'))).called(1); - }); - - test('getHistory returns parsed response', () async { - when(mockClient.get(Uri.parse('$testHost/api/history?max_items=64'))) - .thenAnswer((_) async => - http.Response(jsonEncode(TestData.historyResponse), 200)); - - final result = await api.getHistory(); - - expect(result, equals(TestData.historyResponse)); - verify(mockClient.get(Uri.parse('$testHost/api/history?max_items=64'))) - .called(1); - }); - - test('getImage returns image bytes', () async { - final bytes = [1, 2, 3, 4]; - when(mockClient.get(Uri.parse('$testHost/api/view?filename=test.png'))) - .thenAnswer((_) async => http.Response.bytes(bytes, 200)); - - final result = await api.getImage('test.png'); - - expect(result, equals(bytes)); - verify(mockClient.get(Uri.parse('$testHost/api/view?filename=test.png'))) - .called(1); - }); - - test('getCheckpoints returns parsed response', () async { - when(mockClient - .get(Uri.parse('$testHost/api/experiment/models/checkpoints'))) - .thenAnswer((_) async => - http.Response(jsonEncode(TestData.checkpointsResponse), 200)); - - final result = await api.getCheckpoints(); - - expect(result, equals(TestData.checkpointsResponse)); - verify(mockClient - .get(Uri.parse('$testHost/api/experiment/models/checkpoints'))) - .called(1); - }); - - test('getCheckpointDetails returns parsed response', () async { - const filename = 'models/checkpoints/test.safetensors'; - when(mockClient.get(Uri.parse( - '$testHost/api/view_metadata/checkpoints?filename=$filename'))) - .thenAnswer((_) async => http.Response( - jsonEncode(TestData.checkpointMetadataResponse), 200)); - - final result = await api.getCheckpointDetails(filename); - - expect(result, equals(TestData.checkpointMetadataResponse)); - verify(mockClient.get(Uri.parse( - '$testHost/api/view_metadata/checkpoints?filename=$filename'))) - .called(1); - }); - - test('getLoras returns parsed response', () async { - when(mockClient.get(Uri.parse('$testHost/api/experiment/models/loras'))) - .thenAnswer((_) async => - http.Response(jsonEncode(TestData.lorasResponse), 200)); - - final result = await api.getLoras(); - - expect(result, equals(TestData.lorasResponse)); - verify(mockClient.get(Uri.parse('$testHost/api/experiment/models/loras'))) - .called(1); - }); - - test('getVaes returns parsed response', () async { - when(mockClient.get(Uri.parse('$testHost/api/experiment/models/vae'))) - .thenAnswer((_) async => - http.Response(jsonEncode(TestData.vaeResponse), 200)); - - final result = await api.getVaes(); - - expect(result, equals(TestData.vaeResponse)); - verify(mockClient.get(Uri.parse('$testHost/api/experiment/models/vae'))) - .called(1); - }); - - test('getObjectInfo returns parsed response', () async { - when(mockClient.get(Uri.parse('$testHost/api/object_info'))).thenAnswer( - (_) async => - http.Response(jsonEncode(TestData.objectInfoResponse), 200)); - - final result = await api.getObjectInfo(); - - expect(result, equals(TestData.objectInfoResponse)); - verify(mockClient.get(Uri.parse('$testHost/api/object_info'))).called(1); - }); - - test('submitPrompt returns parsed response', () async { - when(mockClient.post( - Uri.parse('$testHost/api/prompt'), - headers: {'Content-Type': 'application/json'}, - body: jsonEncode(TestData.promptRequest), - )).thenAnswer( - (_) async => http.Response(jsonEncode(TestData.promptResponse), 200)); - - final result = await api.submitPrompt(TestData.promptRequest); - - expect(result, equals(TestData.promptResponse)); - verify(mockClient.post( - Uri.parse('$testHost/api/prompt'), - headers: {'Content-Type': 'application/json'}, - body: jsonEncode(TestData.promptRequest), - )).called(1); - }); - - test('throws ComfyUiApiException on error response', () async { - when(mockClient.get(Uri.parse('$testHost/queue'))) - .thenAnswer((_) async => http.Response('Error message', 500)); - - expect(() => api.getQueue(), throwsA(isA())); - }); - }); - - group('Models', () { - test('QueueInfo parses from JSON correctly', () { - final queueInfo = QueueInfo.fromJson(TestData.queueResponse); - - expect(queueInfo.queueRunning, equals(0)); - expect(queueInfo.queue.length, equals(0)); - expect(queueInfo.queuePending, isA>()); - }); - - test('PromptExecutionStatus parses from JSON correctly', () { - final status = PromptExecutionStatus.fromJson(TestData.promptResponse); - - expect(status.promptId, equals('123456789')); - expect(status.number, equals(1)); - expect(status.status, equals('success')); - }); - - test('HistoryItem parses from JSON correctly', () { - final item = HistoryItem.fromJson(TestData.historyItemResponse); - - expect(item.promptId, equals('123456789')); - expect(item.prompt, isA>()); - expect(item.outputs, isA>()); - }); - - test('ProgressUpdate parses from JSON correctly', () { - final update = ProgressUpdate.fromJson(TestData.progressUpdateResponse); - - expect(update.type, equals('execution_start')); - expect(update.data, isA>()); - }); - }); -} diff --git a/test/integration_test.dart b/test/integration_test.dart deleted file mode 100644 index 6e17f00..0000000 --- a/test/integration_test.dart +++ /dev/null @@ -1,212 +0,0 @@ -import 'dart:convert'; -import 'dart:io'; -import 'package:comfyui_api_sdk/comfyui_api_sdk.dart'; -import 'package:http/http.dart' as http; -import 'package:http/testing.dart'; -import 'package:test/test.dart'; - -void main() { - late ComfyUiApi api; - late MockClient mockClient; - - const String testHost = 'http://localhost:8188'; - - setUp(() { - // Setup a MockClient that simulates real API responses - mockClient = MockClient((request) async { - final uri = request.url; - final method = request.method; - - // Simulate queue endpoint - if (uri.path == '/queue' && method == 'GET') { - return http.Response( - jsonEncode({'queue_running': 0, 'queue': [], 'queue_pending': {}}), - 200, - ); - } - - // Simulate history endpoint - if (uri.path == '/api/history' && method == 'GET') { - return http.Response( - jsonEncode({ - 'History': { - '123456789': { - 'prompt': { - // Simplified prompt data - '1': {'class_type': 'TestNode'} - }, - 'outputs': { - '8': { - 'images': { - 'filename': 'ComfyUI_00001_.png', - 'subfolder': '', - 'type': 'output', - } - } - } - } - } - }), - 200, - ); - } - - // Simulate checkpoint list endpoint - if (uri.path == '/api/experiment/models/checkpoints' && method == 'GET') { - return http.Response( - jsonEncode({ - 'models/checkpoints/dreamshaper_8.safetensors': { - 'filename': 'dreamshaper_8.safetensors', - 'folder': 'models/checkpoints', - } - }), - 200, - ); - } - - // Simulate checkpoint metadata endpoint - if (uri.path == '/api/view_metadata/checkpoints' && method == 'GET') { - return http.Response( - jsonEncode({ - 'model': { - 'type': 'checkpoint', - 'title': 'Dreamshaper 8', - 'hash': 'abcdef1234567890', - } - }), - 200, - ); - } - - // Simulate object info endpoint - if (uri.path == '/api/object_info' && method == 'GET') { - return http.Response( - jsonEncode({ - 'KSampler': { - 'input': { - 'required': { - 'model': 'MODEL', - 'seed': 'INT', - 'steps': 'INT', - } - }, - 'output': ['LATENT'], - 'output_is_list': [false] - } - }), - 200, - ); - } - - // Simulate prompt submission endpoint - if (uri.path == '/api/prompt' && method == 'POST') { - return http.Response( - jsonEncode( - {'prompt_id': '123456789', 'number': 1, 'status': 'success'}), - 200, - ); - } - - // Simulate image view endpoint - if (uri.path == '/api/view' && method == 'GET') { - // Return a dummy image - return http.Response.bytes([1, 2, 3, 4], 200, - headers: { - 'Content-Type': 'image/png', - }); - } - - // Default response for unhandled routes - return http.Response('Not Found', 404); - }); - - // Create the API with our mock client - api = ComfyUiApi( - host: testHost, - clientId: 'integration-test-client', - httpClient: mockClient, - ); - }); - - group('Integration Tests', () { - test('Get queue information', () async { - final queue = await api.getQueue(); - - expect(queue['queue_running'], equals(0)); - expect(queue['queue'], isEmpty); - expect(queue['queue_pending'], isA()); - }); - - test('Get history information', () async { - final history = await api.getHistory(); - - expect(history['History'], isA()); - expect(history['History']['123456789'], isA()); - expect(history['History']['123456789']['outputs'], isA()); - }); - - test('Get checkpoint list', () async { - final checkpoints = await api.getCheckpoints(); - - expect(checkpoints.keys, - contains('models/checkpoints/dreamshaper_8.safetensors')); - expect( - checkpoints['models/checkpoints/dreamshaper_8.safetensors'] - ['filename'], - equals('dreamshaper_8.safetensors')); - }); - - test('Get checkpoint metadata', () async { - final metadata = await api - .getCheckpointDetails('models/checkpoints/dreamshaper_8.safetensors'); - - expect(metadata['model']['type'], equals('checkpoint')); - expect(metadata['model']['title'], equals('Dreamshaper 8')); - }); - - test('Get object info', () async { - final info = await api.getObjectInfo(); - - expect(info['KSampler'], isA()); - expect(info['KSampler']['input']['required']['seed'], equals('INT')); - }); - - test('Submit prompt', () async { - final promptData = { - 'prompt': { - '1': { - 'inputs': {'text': 'A beautiful landscape'}, - 'class_type': 'CLIPTextEncode' - } - }, - 'client_id': 'integration-test-client' - }; - - final result = await api.submitPrompt(promptData); - - expect(result['prompt_id'], equals('123456789')); - expect(result['status'], equals('success')); - }); - - test('Get image', () async { - final imageBytes = await api.getImage('ComfyUI_00001_.png'); - - expect(imageBytes, equals([1, 2, 3, 4])); - }); - - test('Handle error response', () async { - // Create a client that always returns an error - final errorClient = MockClient((_) async { - return http.Response('Server Error', 500); - }); - - final errorApi = ComfyUiApi( - host: testHost, - clientId: 'error-test-client', - httpClient: errorClient, - ); - - expect(() => errorApi.getQueue(), throwsA(isA())); - }); - }); -} diff --git a/test/models_test.dart b/test/models_test.dart deleted file mode 100644 index d8e44c8..0000000 --- a/test/models_test.dart +++ /dev/null @@ -1,127 +0,0 @@ -import 'package:comfyui_api_sdk/comfyui_api_sdk.dart'; -import 'package:test/test.dart'; - -void main() { - group('QueueInfo', () { - test('fromJson creates instance with correct values', () { - final json = { - 'queue_running': 1, - 'queue': [ - {'prompt_id': '123', 'number': 1} - ], - 'queue_pending': { - '456': {'prompt_id': '456', 'number': 2} - } - }; - - final queueInfo = QueueInfo.fromJson(json); - - expect(queueInfo.queueRunning, equals(1)); - expect(queueInfo.queue.length, equals(1)); - expect(queueInfo.queue[0]['prompt_id'], equals('123')); - expect(queueInfo.queuePending['456']['prompt_id'], equals('456')); - }); - - test('fromJson handles missing or empty values', () { - final json = {'queue_running': 0}; - - final queueInfo = QueueInfo.fromJson(json); - - expect(queueInfo.queueRunning, equals(0)); - expect(queueInfo.queue, isEmpty); - expect(queueInfo.queuePending, isEmpty); - }); - }); - - group('PromptExecutionStatus', () { - test('fromJson creates instance with correct values', () { - final json = { - 'prompt_id': 'abc123', - 'number': 5, - 'status': 'processing', - 'error': null - }; - - final status = PromptExecutionStatus.fromJson(json); - - expect(status.promptId, equals('abc123')); - expect(status.number, equals(5)); - expect(status.status, equals('processing')); - expect(status.error, isNull); - }); - - test('fromJson handles error information', () { - final json = { - 'prompt_id': 'abc123', - 'number': 5, - 'status': 'error', - 'error': 'Something went wrong' - }; - - final status = PromptExecutionStatus.fromJson(json); - - expect(status.status, equals('error')); - expect(status.error, equals('Something went wrong')); - }); - }); - - group('HistoryItem', () { - test('fromJson creates instance with correct values', () { - final json = { - 'prompt_id': 'abc123', - 'prompt': { - '1': {'class_type': 'TestNode'} - }, - 'outputs': { - '2': { - 'images': {'filename': 'test.png'} - } - } - }; - - final item = HistoryItem.fromJson(json); - - expect(item.promptId, equals('abc123')); - expect(item.prompt['1']['class_type'], equals('TestNode')); - expect(item.outputs?['2']['images']['filename'], equals('test.png')); - }); - - test('fromJson handles missing outputs', () { - final json = { - 'prompt_id': 'abc123', - 'prompt': { - '1': {'class_type': 'TestNode'} - } - }; - - final item = HistoryItem.fromJson(json); - - expect(item.promptId, equals('abc123')); - expect(item.outputs, isNull); - }); - }); - - group('ProgressUpdate', () { - test('fromJson creates instance with correct values', () { - final json = { - 'type': 'execution_start', - 'data': {'prompt_id': 'abc123', 'node': 5} - }; - - final update = ProgressUpdate.fromJson(json); - - expect(update.type, equals('execution_start')); - expect(update.data['prompt_id'], equals('abc123')); - expect(update.data['node'], equals(5)); - }); - - test('fromJson handles empty data', () { - final json = {'type': 'status', 'data': {}}; - - final update = ProgressUpdate.fromJson(json); - - expect(update.type, equals('status')); - expect(update.data, isEmpty); - }); - }); -} diff --git a/test/test_data.dart b/test/test_data.dart deleted file mode 100644 index eb42bce..0000000 --- a/test/test_data.dart +++ /dev/null @@ -1,148 +0,0 @@ -/// Test data for ComfyUI API tests -class TestData { - /// Mock queue response - static final Map queueResponse = { - 'queue_running': 0, - 'queue': [], - 'queue_pending': {} - }; - - /// Mock history response - static final Map historyResponse = { - 'History': { - '123456789': { - 'prompt': { - // Prompt data - }, - 'outputs': { - '8': { - 'images': { - 'filename': 'ComfyUI_00001_.png', - 'subfolder': '', - 'type': 'output', - } - } - } - } - } - }; - - /// Mock history item - static final Map historyItemResponse = { - 'prompt_id': '123456789', - 'prompt': { - // Prompt data - }, - 'outputs': { - '8': { - 'images': { - 'filename': 'ComfyUI_00001_.png', - 'subfolder': '', - 'type': 'output', - } - } - } - }; - - /// Mock checkpoints response - static final Map checkpointsResponse = { - 'models/checkpoints/dreamshaper_8.safetensors': { - 'filename': 'dreamshaper_8.safetensors', - 'folder': 'models/checkpoints', - }, - 'models/checkpoints/sd_xl_base_1.0.safetensors': { - 'filename': 'sd_xl_base_1.0.safetensors', - 'folder': 'models/checkpoints', - } - }; - - /// Mock checkpoint metadata response - static final Map checkpointMetadataResponse = { - 'model': { - 'type': 'checkpoint', - 'title': 'Dreamshaper 8', - 'filename': 'dreamshaper_8.safetensors', - 'hash': 'abcdef1234567890', - } - }; - - /// Mock LoRAs response - static final Map lorasResponse = { - 'models/loras/example_lora.safetensors': { - 'filename': 'example_lora.safetensors', - 'folder': 'models/loras', - } - }; - - /// Mock VAE response - static final Map vaeResponse = { - 'models/vae/example_vae.safetensors': { - 'filename': 'example_vae.safetensors', - 'folder': 'models/vae', - } - }; - - /// Mock object info response (simplified) - static final Map objectInfoResponse = { - 'CheckpointLoaderSimple': { - 'input': { - 'required': {'ckpt_name': 'STRING'} - }, - 'output': ['MODEL', 'CLIP', 'VAE'], - 'output_is_list': [false, false, false] - }, - 'KSampler': { - 'input': { - 'required': { - 'model': 'MODEL', - 'seed': 'INT', - 'steps': 'INT', - 'cfg': 'FLOAT', - 'sampler_name': 'STRING', - 'scheduler': 'STRING', - 'positive': 'CONDITIONING', - 'negative': 'CONDITIONING', - 'latent_image': 'LATENT' - }, - 'optional': {'denoise': 'FLOAT'} - }, - 'output': ['LATENT'], - 'output_is_list': [false] - } - }; - - /// Mock prompt request - static final Map promptRequest = { - 'prompt': { - '3': { - 'inputs': { - 'seed': 123456789, - 'steps': 20, - 'cfg': 7, - 'sampler_name': 'euler_ancestral', - 'scheduler': 'normal', - 'denoise': 1, - 'model': ['4', 0], - 'positive': ['6', 0], - 'negative': ['7', 0], - 'latent_image': ['5', 0] - }, - 'class_type': 'KSampler' - } - }, - 'client_id': 'test-client-id' - }; - - /// Mock prompt response - static final Map promptResponse = { - 'prompt_id': '123456789', - 'number': 1, - 'status': 'success' - }; - - /// Mock progress update response - static final Map progressUpdateResponse = { - 'type': 'execution_start', - 'data': {'prompt_id': '123456789'} - }; -} diff --git a/test/websocket_test.dart b/test/websocket_test.dart deleted file mode 100644 index 5bc3cc6..0000000 --- a/test/websocket_test.dart +++ /dev/null @@ -1,145 +0,0 @@ -import 'dart:async'; -import 'dart:convert'; - -import 'package:comfyui_api_sdk/comfyui_api_sdk.dart'; -import 'package:http/http.dart' as http; -import 'package:http/testing.dart'; -import 'package:mockito/annotations.dart'; -import 'package:mockito/mockito.dart'; -import 'package:test/test.dart'; -import 'package:web_socket_channel/web_socket_channel.dart'; - -import 'test_data.dart'; -import 'websocket_test.mocks.dart'; - -@GenerateMocks([http.Client, WebSocketChannel, WebSocketSink, Stream]) -void main() { - late MockClient mockClient; - late MockWebSocketChannel mockWebSocketChannel; - late MockWebSocketSink mockWebSocketSink; - late StreamController streamController; - late ComfyUiApi api; - - const String testHost = 'http://localhost:8188'; - const String testClientId = 'test-client-id'; - - setUp(() { - mockClient = MockClient(); - mockWebSocketChannel = MockWebSocketChannel(); - mockWebSocketSink = MockWebSocketSink(); - streamController = StreamController.broadcast(); - - when(mockWebSocketChannel.sink).thenReturn(mockWebSocketSink); - when(mockWebSocketChannel.stream) - .thenAnswer((_) => streamController.stream); - - api = ComfyUiApi( - host: testHost, - clientId: testClientId, - httpClient: mockClient, - ); - }); - - tearDown(() { - streamController.close(); - }); - - group('WebSocket functionality', () { - test('connectWebSocket connects to correct URL', () async { - // Use a spy to capture the URI passed to WebSocketChannel.connect - final wsUrl = 'ws://localhost:8188/ws?clientId=$testClientId'; - - await api.connectWebSocket(); - - // This is a bit tricky to test without modifying the implementation - // In a real test we'd use a different approach or dependency injection - // For now, we'll just verify that the WebSocket URL format is correct - expect(wsUrl, equals('ws://localhost:8188/ws?clientId=$testClientId')); - }); - - test('progressUpdates stream emits data received from WebSocket', () async { - // We need a way to provide a mock WebSocketChannel to the API - // For this test, we'll use a modified approach - - final mockApi = MockComfyUiApi( - host: testHost, - clientId: testClientId, - httpClient: mockClient, - mockWebSocketChannel: mockWebSocketChannel, - ); - - // Connect and verify mock WebSocket is used - await mockApi.connectWebSocket(); - - // Prepare to capture emitted events - final events = >[]; - final subscription = mockApi.progressUpdates.listen(events.add); - - // Send test data through the mock WebSocket - final testData = TestData.progressUpdateResponse; - streamController.add(jsonEncode(testData)); - - // Wait for async processing - await Future.delayed(Duration(milliseconds: 100)); - - // Verify the data was emitted - expect(events.length, equals(1)); - expect(events.first, equals(testData)); - - // Clean up - await subscription.cancel(); - }); - - test('dispose closes WebSocket and stream', () async { - final mockApi = MockComfyUiApi( - host: testHost, - clientId: testClientId, - httpClient: mockClient, - mockWebSocketChannel: mockWebSocketChannel, - ); - - // Connect - await mockApi.connectWebSocket(); - - // Dispose - mockApi.dispose(); - - // Verify WebSocket was closed - verify(mockWebSocketSink.close()).called(1); - }); - }); -} - -/// A modified version of ComfyUiApi for testing that allows injecting a mock WebSocketChannel -class MockComfyUiApi extends ComfyUiApi { - final WebSocketChannel? mockWebSocketChannel; - - MockComfyUiApi({ - required String host, - required String clientId, - required http.Client httpClient, - this.mockWebSocketChannel, - }) : super( - host: host, - clientId: clientId, - httpClient: httpClient, - ); - - @override - Future connectWebSocket() async { - if (mockWebSocketChannel != null) { - _wsChannel = mockWebSocketChannel; - - _wsChannel!.stream.listen((message) { - final data = jsonDecode(message); - _progressController.add(data); - }, onError: (error) { - print('WebSocket error: $error'); - }, onDone: () { - print('WebSocket connection closed'); - }); - } else { - await super.connectWebSocket(); - } - } -}