Add execution interruption handling and executing node tracking in WebSocketManager

This commit is contained in:
Menno van Leeuwen 2025-03-22 00:00:50 +01:00
parent f5d5982606
commit 9b50cf1575
Signed by: vleeuwenmenno
SSH Key Fingerprint: SHA256:OJFmjANpakwD3F2Rsws4GLtbdz1TJ5tkQF0RZmF0TRE
3 changed files with 65 additions and 0 deletions

View File

@ -41,6 +41,14 @@ class ComfyUiApi {
Stream<ProgressEvent> get progressEvents => _webSocketManager.progressEvents;
Stream<ExecutionEvent> get executionEvents =>
_webSocketManager.executionEvents;
Stream<int> get executingNodeStream => _webSocketManager.executingNodeStream;
Stream<void> get executionInterruptedStream =>
_webSocketManager.executionInterruptedStream;
/// Register a callback for when the executing node changes
void onExecutingNodeChanged(void Function(int nodeId) callback) {
_webSocketManager.executingNodeStream.listen(callback);
}
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
_webSocketManager.onEventType(type, callback);
@ -58,6 +66,11 @@ class ComfyUiApi {
_webSocketManager.onPromptFinished(callback);
}
/// Register a callback for when execution is interrupted
void onExecutionInterrupted(void Function() callback) {
_webSocketManager.executionInterruptedStream.listen((_) => callback());
}
Future<void> connectWebSocket() => _webSocketManager.connect();
void dispose() {
@ -235,6 +248,21 @@ class ComfyUiApi {
return SubmitPromptResponse.fromJson(responseData);
}
Future<bool> interrupt({bool clearQueue = false}) async {
if (clearQueue) {
final clearQueueResponse = await _httpClient.post(
Uri.parse('$host/api/queue'),
headers: {'Content-Type': 'application/json'},
body: jsonEncode({'clear': true}),
);
_validateResponse(clearQueueResponse);
}
final response = await _httpClient.post(Uri.parse('$host/api/interrupt'));
_validateResponse(response);
return true;
}
/// Validates HTTP response and throws an exception if needed
void _validateResponse(http.Response response) {
if (response.statusCode < 200 || response.statusCode >= 300) {

View File

@ -7,6 +7,7 @@ enum WebSocketEventType {
executionStart,
executionCached,
executionSuccess,
executionInterrupted,
executionError,
dataOutput,
unknown
@ -83,6 +84,9 @@ class WebSocketEvent {
case 'execution_error':
type = WebSocketEventType.executionError;
break;
case 'execution_interrupted':
type = WebSocketEventType.executionInterrupted;
break;
default:
if (typeStr.startsWith('data_output')) {
type = WebSocketEventType.dataOutput;

View File

@ -24,6 +24,10 @@ class WebSocketManager {
StreamController.broadcast();
final StreamController<ExecutionEvent> _executionEventController =
StreamController.broadcast();
final StreamController<int> _executingNodeController =
StreamController.broadcast();
final StreamController<void> _executionInterruptedController =
StreamController.broadcast();
// Event callbacks
final Map<WebSocketEventType, List<WebSocketEventCallback>>
@ -52,6 +56,13 @@ class WebSocketManager {
Stream<ExecutionEvent> get executionEvents =>
_executionEventController.stream;
/// Stream of currently executing node
Stream<int> get executingNodeStream => _executingNodeController.stream;
/// Stream of execution interrupted events
Stream<void> get executionInterruptedStream =>
_executionInterruptedController.stream;
/// Register a callback for specific WebSocket event types
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
_typedEventCallbacks[type]!.add(callback);
@ -83,6 +94,8 @@ class WebSocketManager {
_wsChannel!.stream.listen((message) {
final jsonData = jsonDecode(message);
print('WebSocket message: $jsonData');
// Create a typed event
final event = WebSocketEvent.fromJson(jsonData);
@ -101,10 +114,28 @@ class WebSocketManager {
WebSocketEventType.executionCached,
WebSocketEventType.executed,
WebSocketEventType.executing,
WebSocketEventType.executionInterrupted,
WebSocketEventType.status, // Add status event type
].contains(event.eventType)) {
_tryCreateExecutionEvent(event);
}
// Handle "execution_interrupted" event
if (event.eventType == WebSocketEventType.executionInterrupted) {
_executionInterruptedController.add(null);
}
// Handle "executing" event
if (event.eventType == WebSocketEventType.executing &&
event.data['node'] != null) {
final nodeId = int.tryParse(event.data['node'].toString());
if (nodeId != null) {
_executingNodeController.add(nodeId);
} else {
print('Invalid node ID: ${event.data['node']}');
}
}
// Trigger event type specific callbacks
for (final callback in _typedEventCallbacks[event.eventType]!) {
callback(event);
@ -193,5 +224,7 @@ class WebSocketManager {
_eventController.close();
_progressEventController.close();
_executionEventController.close();
_executingNodeController.close();
_executionInterruptedController.close(); // Close the new controller
}
}