From 697d2f812d8c69db16dae48a762c5dfd6264104a Mon Sep 17 00:00:00 2001 From: Menno van Leeuwen Date: Thu, 20 Mar 2025 15:18:26 +0000 Subject: [PATCH] Add WebSocketManager class for handling WebSocket connections and events --- lib/src/comfyui_api.dart | 155 ++++------------------------ lib/src/websocket_manager.dart | 179 +++++++++++++++++++++++++++++++++ 2 files changed, 200 insertions(+), 134 deletions(-) create mode 100644 lib/src/websocket_manager.dart diff --git a/lib/src/comfyui_api.dart b/lib/src/comfyui_api.dart index dc0bf1f..d49c53b 100644 --- a/lib/src/comfyui_api.dart +++ b/lib/src/comfyui_api.dart @@ -3,172 +3,61 @@ import 'dart:convert'; import 'package:http/http.dart' as http; import 'package:uuid/uuid.dart'; -import 'package:web_socket_channel/web_socket_channel.dart'; import 'models/websocket_event.dart'; import 'models/progress_event.dart'; import 'models/execution_event.dart'; import 'exceptions/comfyui_api_exception.dart'; import 'types/callback_types.dart'; -import 'utils/websocket_event_handler.dart'; import 'models/history_response.dart'; import 'models/checkpoint.dart'; import 'models/vae.dart'; import 'models/lora.dart'; +import 'websocket_manager.dart'; /// A Dart SDK for interacting with the ComfyUI API class ComfyUiApi { final String host; final String clientId; final http.Client _httpClient; - WebSocketChannel? _wsChannel; - - // Controllers for different event streams - final StreamController _eventController = - StreamController.broadcast(); - final StreamController> _progressController = - StreamController.broadcast(); - - // Add new controllers for specific event types - final StreamController _progressEventController = - StreamController.broadcast(); - final StreamController _executionEventController = - StreamController.broadcast(); - - /// Stream of typed progress events - Stream get progressEvents => _progressEventController.stream; - - /// Stream of typed execution events - Stream get executionEvents => - _executionEventController.stream; - - // Event callbacks - final Map> _eventCallbacks = { - 'onPromptStart': [], - 'onPromptFinished': [], - }; - - final Map> - _typedEventCallbacks = { - for (var type in WebSocketEventType.values) type: [], - }; - - // Add a separate map for progress event callbacks - final List _progressEventCallbacks = []; - - /// Stream of typed WebSocket events - Stream get events => _eventController.stream; - - /// Stream of progress updates from ComfyUI (legacy format) - Stream> get progressUpdates => - _progressController.stream; + final WebSocketManager _webSocketManager; /// Creates a new ComfyUI API client - /// - /// [host] The host of the ComfyUI server (e.g. 'http://localhost:7860') - /// [clientId] Optional client ID, will be automatically generated if not provided ComfyUiApi({ required this.host, - String? clientId, + required this.clientId, http.Client? httpClient, - }) : clientId = clientId ?? const Uuid().v4(), - _httpClient = httpClient ?? http.Client(); + }) : _httpClient = httpClient ?? http.Client(), + _webSocketManager = WebSocketManager(host: host, clientId: clientId); - /// Register a callback for when a prompt starts processing - void onPromptStart(PromptEventCallback callback) { - _eventCallbacks['onPromptStart']!.add(callback); - } + /// Expose WebSocketManager streams and methods + Stream get events => _webSocketManager.events; + Stream> get progressUpdates => + _webSocketManager.progressUpdates; + Stream get progressEvents => _webSocketManager.progressEvents; + Stream get executionEvents => + _webSocketManager.executionEvents; - /// Register a callback for when a prompt finishes processing - void onPromptFinished(PromptEventCallback callback) { - _eventCallbacks['onPromptFinished']!.add(callback); - } - - /// Register a callback for specific WebSocket event types void onEventType(WebSocketEventType type, WebSocketEventCallback callback) { - _typedEventCallbacks[type]!.add(callback); + _webSocketManager.onEventType(type, callback); } - /// Register a callback for progress updates void onProgressChanged(ProgressEventCallback callback) { - _progressEventCallbacks.add(callback); + _webSocketManager.onProgressChanged(callback); } - /// Connects to the WebSocket for progress updates - Future connectWebSocket() async { - final wsUrl = - 'ws://${host.replaceFirst(RegExp(r'^https?://'), '')}/ws?clientId=$clientId'; - _wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl)); - - _wsChannel!.stream.listen((message) { - final jsonData = jsonDecode(message); - - // Create a typed event - final event = WebSocketEvent.fromJson(jsonData); - - // Add to the typed event stream - _eventController.add(event); - - // Also add to the legacy progress stream - _progressController.add(jsonData); - - // Convert to more specific event types if possible - if (event.eventType == WebSocketEventType.progress) { - _tryCreateProgressEvent(event); - } else if ([ - WebSocketEventType.executionStart, - WebSocketEventType.executionSuccess, - WebSocketEventType.executionCached, - WebSocketEventType.executed, - WebSocketEventType.executing, - ].contains(event.eventType)) { - _tryCreateExecutionEvent(event); - } - - // Trigger event type specific callbacks - for (final callback in _typedEventCallbacks[event.eventType]!) { - callback(event); - } - - // Handle execution_success event (prompt finished) - if (event.eventType == WebSocketEventType.executionSuccess && - event.promptId != null) { - final promptId = event.promptId!; - for (final callback in _eventCallbacks['onPromptFinished']!) { - callback(promptId); - } - } - }, onError: (error) { - print('WebSocket error: $error'); - }, onDone: () { - print('WebSocket connection closed'); - }); + void onPromptStart(PromptEventCallback callback) { + _webSocketManager.onPromptStart(callback); } - /// Attempts to create a ProgressEvent from a WebSocketEvent - void _tryCreateProgressEvent(WebSocketEvent event) { - WebSocketEventHandler.tryCreateProgressEvent( - event, - _progressEventController, - _progressEventCallbacks, - ); + void onPromptFinished(PromptEventCallback callback) { + _webSocketManager.onPromptFinished(callback); } - /// Attempts to create an ExecutionEvent from a WebSocketEvent - void _tryCreateExecutionEvent(WebSocketEvent event) { - WebSocketEventHandler.tryCreateExecutionEvent( - event, - _executionEventController, - ); - } + Future connectWebSocket() => _webSocketManager.connect(); - /// Closes the WebSocket connection and cleans up resources void dispose() { - _wsChannel?.sink.close(); - _progressController.close(); - _eventController.close(); - _progressEventController.close(); - _executionEventController.close(); + _webSocketManager.dispose(); _httpClient.close(); } @@ -309,9 +198,7 @@ class ComfyUiApi { // Trigger onPromptStart event if prompt_id exists if (responseData.containsKey('prompt_id')) { final promptId = responseData['prompt_id']; - for (final callback in _eventCallbacks['onPromptStart']!) { - callback(promptId); - } + _webSocketManager.triggerOnPromptStart(promptId); } return responseData; diff --git a/lib/src/websocket_manager.dart b/lib/src/websocket_manager.dart new file mode 100644 index 0000000..b0d3643 --- /dev/null +++ b/lib/src/websocket_manager.dart @@ -0,0 +1,179 @@ +import 'dart:async'; +import 'dart:convert'; + +import 'package:web_socket_channel/web_socket_channel.dart'; + +import 'models/websocket_event.dart'; +import 'models/progress_event.dart'; +import 'models/execution_event.dart'; +import 'types/callback_types.dart'; +import 'utils/websocket_event_handler.dart'; + +class WebSocketManager { + final String host; + final String clientId; + + WebSocketChannel? _wsChannel; + + // Controllers for different event streams + final StreamController _eventController = + StreamController.broadcast(); + final StreamController> _progressController = + StreamController.broadcast(); + final StreamController _progressEventController = + StreamController.broadcast(); + final StreamController _executionEventController = + StreamController.broadcast(); + + // Event callbacks + final Map> + _typedEventCallbacks = { + for (var type in WebSocketEventType.values) type: [], + }; + final List _progressEventCallbacks = []; + final Map> _eventCallbacks = { + 'onPromptStart': [], + 'onPromptFinished': [], + }; + + WebSocketManager({required this.host, required this.clientId}); + + /// Stream of typed WebSocket events + Stream get events => _eventController.stream; + + /// Stream of progress updates (legacy format) + Stream> get progressUpdates => + _progressController.stream; + + /// Stream of typed progress events + Stream get progressEvents => _progressEventController.stream; + + /// Stream of typed execution events + Stream get executionEvents => + _executionEventController.stream; + + /// Register a callback for specific WebSocket event types + void onEventType(WebSocketEventType type, WebSocketEventCallback callback) { + _typedEventCallbacks[type]!.add(callback); + } + + /// Register a callback for progress updates + void onProgressChanged(ProgressEventCallback callback) { + _progressEventCallbacks.add(callback); + } + + /// Register a callback for when a prompt starts processing + void onPromptStart(PromptEventCallback callback) { + _eventCallbacks['onPromptStart']!.add(callback); + } + + /// Register a callback for when a prompt finishes processing + void onPromptFinished(PromptEventCallback callback) { + _eventCallbacks['onPromptFinished']!.add(callback); + } + + /// Connects to the WebSocket for progress updates + Future connect() async { + final wsUrl = + 'ws://${host.replaceFirst(RegExp(r'^https?://'), '')}/ws?clientId=$clientId'; + _wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl)); + + _wsChannel!.stream.listen((message) { + final jsonData = jsonDecode(message); + + // Create a typed event + final event = WebSocketEvent.fromJson(jsonData); + + // Add to the typed event stream + _eventController.add(event); + + // Also add to the legacy progress stream + _progressController.add(jsonData); + + // Convert to more specific event types if possible + if (event.eventType == WebSocketEventType.progress) { + _tryCreateProgressEvent(event); + } else if ([ + WebSocketEventType.executionStart, + WebSocketEventType.executionSuccess, + WebSocketEventType.executionCached, + WebSocketEventType.executed, + WebSocketEventType.executing, + ].contains(event.eventType)) { + _tryCreateExecutionEvent(event); + } + + // Trigger event type specific callbacks + for (final callback in _typedEventCallbacks[event.eventType]!) { + callback(event); + } + + // Handle execution_success event (prompt finished) + if (event.eventType == WebSocketEventType.executionSuccess && + event.promptId != null) { + final promptId = event.promptId!; + for (final callback in _eventCallbacks['onPromptFinished']!) { + callback(promptId); + } + } + + // Handle progress updates + if (event.eventType == WebSocketEventType.progress) { + for (final callback in _progressEventCallbacks) { + callback(ProgressEvent.fromJson(event.data)); + } + } + }, onError: (error) { + print('WebSocket error: $error'); + }, onDone: () { + print('WebSocket connection closed'); + }); + } + + /// Attempts to create a ProgressEvent from a WebSocketEvent + void _tryCreateProgressEvent(WebSocketEvent event) { + WebSocketEventHandler.tryCreateProgressEvent( + event, + _progressEventController, + _progressEventCallbacks, + ); + } + + /// Attempts to create an ExecutionEvent from a WebSocketEvent + void _tryCreateExecutionEvent(WebSocketEvent event) { + WebSocketEventHandler.tryCreateExecutionEvent( + event, + _executionEventController, + ); + } + + /// Trigger the onPromptStart callbacks + void triggerOnPromptStart(String promptId) { + for (final callback in _eventCallbacks['onPromptStart']!) { + callback(promptId); + } + } + + /// Trigger the onPromptFinished callbacks + void triggerOnPromptFinished(String promptId) { + for (final callback in _eventCallbacks['onPromptFinished']!) { + callback(promptId); + } + } + + /// Trigger the onProgressChanged callbacks + void triggerOnProgressChanged(Map progressData) { + for (final callback in _progressEventCallbacks) { + callback(ProgressEvent.fromJson(progressData)); + } + } + + /// Closes the WebSocket connection and cleans up resources + void dispose() { + _wsChannel?.sink.close(); + _progressController.close(); + _eventController.close(); + _progressEventController.close(); + _executionEventController.close(); + } +}