Add execution interruption handling and executing node tracking in WebSocketManager
This commit is contained in:
parent
f5d5982606
commit
9b50cf1575
@ -41,6 +41,14 @@ class ComfyUiApi {
|
||||
Stream<ProgressEvent> get progressEvents => _webSocketManager.progressEvents;
|
||||
Stream<ExecutionEvent> get executionEvents =>
|
||||
_webSocketManager.executionEvents;
|
||||
Stream<int> get executingNodeStream => _webSocketManager.executingNodeStream;
|
||||
Stream<void> get executionInterruptedStream =>
|
||||
_webSocketManager.executionInterruptedStream;
|
||||
|
||||
/// Register a callback for when the executing node changes
|
||||
void onExecutingNodeChanged(void Function(int nodeId) callback) {
|
||||
_webSocketManager.executingNodeStream.listen(callback);
|
||||
}
|
||||
|
||||
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
|
||||
_webSocketManager.onEventType(type, callback);
|
||||
@ -58,6 +66,11 @@ class ComfyUiApi {
|
||||
_webSocketManager.onPromptFinished(callback);
|
||||
}
|
||||
|
||||
/// Register a callback for when execution is interrupted
|
||||
void onExecutionInterrupted(void Function() callback) {
|
||||
_webSocketManager.executionInterruptedStream.listen((_) => callback());
|
||||
}
|
||||
|
||||
Future<void> connectWebSocket() => _webSocketManager.connect();
|
||||
|
||||
void dispose() {
|
||||
@ -235,6 +248,21 @@ class ComfyUiApi {
|
||||
return SubmitPromptResponse.fromJson(responseData);
|
||||
}
|
||||
|
||||
Future<bool> interrupt({bool clearQueue = false}) async {
|
||||
if (clearQueue) {
|
||||
final clearQueueResponse = await _httpClient.post(
|
||||
Uri.parse('$host/api/queue'),
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: jsonEncode({'clear': true}),
|
||||
);
|
||||
_validateResponse(clearQueueResponse);
|
||||
}
|
||||
|
||||
final response = await _httpClient.post(Uri.parse('$host/api/interrupt'));
|
||||
_validateResponse(response);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Validates HTTP response and throws an exception if needed
|
||||
void _validateResponse(http.Response response) {
|
||||
if (response.statusCode < 200 || response.statusCode >= 300) {
|
||||
|
@ -7,6 +7,7 @@ enum WebSocketEventType {
|
||||
executionStart,
|
||||
executionCached,
|
||||
executionSuccess,
|
||||
executionInterrupted,
|
||||
executionError,
|
||||
dataOutput,
|
||||
unknown
|
||||
@ -83,6 +84,9 @@ class WebSocketEvent {
|
||||
case 'execution_error':
|
||||
type = WebSocketEventType.executionError;
|
||||
break;
|
||||
case 'execution_interrupted':
|
||||
type = WebSocketEventType.executionInterrupted;
|
||||
break;
|
||||
default:
|
||||
if (typeStr.startsWith('data_output')) {
|
||||
type = WebSocketEventType.dataOutput;
|
||||
|
@ -24,6 +24,10 @@ class WebSocketManager {
|
||||
StreamController.broadcast();
|
||||
final StreamController<ExecutionEvent> _executionEventController =
|
||||
StreamController.broadcast();
|
||||
final StreamController<int> _executingNodeController =
|
||||
StreamController.broadcast();
|
||||
final StreamController<void> _executionInterruptedController =
|
||||
StreamController.broadcast();
|
||||
|
||||
// Event callbacks
|
||||
final Map<WebSocketEventType, List<WebSocketEventCallback>>
|
||||
@ -52,6 +56,13 @@ class WebSocketManager {
|
||||
Stream<ExecutionEvent> get executionEvents =>
|
||||
_executionEventController.stream;
|
||||
|
||||
/// Stream of currently executing node
|
||||
Stream<int> get executingNodeStream => _executingNodeController.stream;
|
||||
|
||||
/// Stream of execution interrupted events
|
||||
Stream<void> get executionInterruptedStream =>
|
||||
_executionInterruptedController.stream;
|
||||
|
||||
/// Register a callback for specific WebSocket event types
|
||||
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
|
||||
_typedEventCallbacks[type]!.add(callback);
|
||||
@ -83,6 +94,8 @@ class WebSocketManager {
|
||||
_wsChannel!.stream.listen((message) {
|
||||
final jsonData = jsonDecode(message);
|
||||
|
||||
print('WebSocket message: $jsonData');
|
||||
|
||||
// Create a typed event
|
||||
final event = WebSocketEvent.fromJson(jsonData);
|
||||
|
||||
@ -101,10 +114,28 @@ class WebSocketManager {
|
||||
WebSocketEventType.executionCached,
|
||||
WebSocketEventType.executed,
|
||||
WebSocketEventType.executing,
|
||||
WebSocketEventType.executionInterrupted,
|
||||
WebSocketEventType.status, // Add status event type
|
||||
].contains(event.eventType)) {
|
||||
_tryCreateExecutionEvent(event);
|
||||
}
|
||||
|
||||
// Handle "execution_interrupted" event
|
||||
if (event.eventType == WebSocketEventType.executionInterrupted) {
|
||||
_executionInterruptedController.add(null);
|
||||
}
|
||||
|
||||
// Handle "executing" event
|
||||
if (event.eventType == WebSocketEventType.executing &&
|
||||
event.data['node'] != null) {
|
||||
final nodeId = int.tryParse(event.data['node'].toString());
|
||||
if (nodeId != null) {
|
||||
_executingNodeController.add(nodeId);
|
||||
} else {
|
||||
print('Invalid node ID: ${event.data['node']}');
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger event type specific callbacks
|
||||
for (final callback in _typedEventCallbacks[event.eventType]!) {
|
||||
callback(event);
|
||||
@ -193,5 +224,7 @@ class WebSocketManager {
|
||||
_eventController.close();
|
||||
_progressEventController.close();
|
||||
_executionEventController.close();
|
||||
_executingNodeController.close();
|
||||
_executionInterruptedController.close(); // Close the new controller
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user