From 9b50cf15756e32152e3a4d47935a6721f349e8ea Mon Sep 17 00:00:00 2001 From: Menno van Leeuwen Date: Sat, 22 Mar 2025 00:00:50 +0100 Subject: [PATCH] Add execution interruption handling and executing node tracking in WebSocketManager --- lib/src/comfyui_api.dart | 28 ++++++++++++++++++++++++ lib/src/models/websocket_event.dart | 4 ++++ lib/src/websocket_manager.dart | 33 +++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/lib/src/comfyui_api.dart b/lib/src/comfyui_api.dart index 1b6d1a9..27def7d 100644 --- a/lib/src/comfyui_api.dart +++ b/lib/src/comfyui_api.dart @@ -41,6 +41,14 @@ class ComfyUiApi { Stream get progressEvents => _webSocketManager.progressEvents; Stream get executionEvents => _webSocketManager.executionEvents; + Stream get executingNodeStream => _webSocketManager.executingNodeStream; + Stream get executionInterruptedStream => + _webSocketManager.executionInterruptedStream; + + /// Register a callback for when the executing node changes + void onExecutingNodeChanged(void Function(int nodeId) callback) { + _webSocketManager.executingNodeStream.listen(callback); + } void onEventType(WebSocketEventType type, WebSocketEventCallback callback) { _webSocketManager.onEventType(type, callback); @@ -58,6 +66,11 @@ class ComfyUiApi { _webSocketManager.onPromptFinished(callback); } + /// Register a callback for when execution is interrupted + void onExecutionInterrupted(void Function() callback) { + _webSocketManager.executionInterruptedStream.listen((_) => callback()); + } + Future connectWebSocket() => _webSocketManager.connect(); void dispose() { @@ -235,6 +248,21 @@ class ComfyUiApi { return SubmitPromptResponse.fromJson(responseData); } + Future interrupt({bool clearQueue = false}) async { + if (clearQueue) { + final clearQueueResponse = await _httpClient.post( + Uri.parse('$host/api/queue'), + headers: {'Content-Type': 'application/json'}, + body: jsonEncode({'clear': true}), + ); + _validateResponse(clearQueueResponse); + } + + final response = await _httpClient.post(Uri.parse('$host/api/interrupt')); + _validateResponse(response); + return true; + } + /// Validates HTTP response and throws an exception if needed void _validateResponse(http.Response response) { if (response.statusCode < 200 || response.statusCode >= 300) { diff --git a/lib/src/models/websocket_event.dart b/lib/src/models/websocket_event.dart index 9751a4b..535553f 100644 --- a/lib/src/models/websocket_event.dart +++ b/lib/src/models/websocket_event.dart @@ -7,6 +7,7 @@ enum WebSocketEventType { executionStart, executionCached, executionSuccess, + executionInterrupted, executionError, dataOutput, unknown @@ -83,6 +84,9 @@ class WebSocketEvent { case 'execution_error': type = WebSocketEventType.executionError; break; + case 'execution_interrupted': + type = WebSocketEventType.executionInterrupted; + break; default: if (typeStr.startsWith('data_output')) { type = WebSocketEventType.dataOutput; diff --git a/lib/src/websocket_manager.dart b/lib/src/websocket_manager.dart index f5f85ae..75097a0 100644 --- a/lib/src/websocket_manager.dart +++ b/lib/src/websocket_manager.dart @@ -24,6 +24,10 @@ class WebSocketManager { StreamController.broadcast(); final StreamController _executionEventController = StreamController.broadcast(); + final StreamController _executingNodeController = + StreamController.broadcast(); + final StreamController _executionInterruptedController = + StreamController.broadcast(); // Event callbacks final Map> @@ -52,6 +56,13 @@ class WebSocketManager { Stream get executionEvents => _executionEventController.stream; + /// Stream of currently executing node + Stream get executingNodeStream => _executingNodeController.stream; + + /// Stream of execution interrupted events + Stream get executionInterruptedStream => + _executionInterruptedController.stream; + /// Register a callback for specific WebSocket event types void onEventType(WebSocketEventType type, WebSocketEventCallback callback) { _typedEventCallbacks[type]!.add(callback); @@ -83,6 +94,8 @@ class WebSocketManager { _wsChannel!.stream.listen((message) { final jsonData = jsonDecode(message); + print('WebSocket message: $jsonData'); + // Create a typed event final event = WebSocketEvent.fromJson(jsonData); @@ -101,10 +114,28 @@ class WebSocketManager { WebSocketEventType.executionCached, WebSocketEventType.executed, WebSocketEventType.executing, + WebSocketEventType.executionInterrupted, + WebSocketEventType.status, // Add status event type ].contains(event.eventType)) { _tryCreateExecutionEvent(event); } + // Handle "execution_interrupted" event + if (event.eventType == WebSocketEventType.executionInterrupted) { + _executionInterruptedController.add(null); + } + + // Handle "executing" event + if (event.eventType == WebSocketEventType.executing && + event.data['node'] != null) { + final nodeId = int.tryParse(event.data['node'].toString()); + if (nodeId != null) { + _executingNodeController.add(nodeId); + } else { + print('Invalid node ID: ${event.data['node']}'); + } + } + // Trigger event type specific callbacks for (final callback in _typedEventCallbacks[event.eventType]!) { callback(event); @@ -193,5 +224,7 @@ class WebSocketManager { _eventController.close(); _progressEventController.close(); _executionEventController.close(); + _executingNodeController.close(); + _executionInterruptedController.close(); // Close the new controller } }