Add execution interruption handling and executing node tracking in WebSocketManager
This commit is contained in:
parent
f5d5982606
commit
9b50cf1575
@ -41,6 +41,14 @@ class ComfyUiApi {
|
|||||||
Stream<ProgressEvent> get progressEvents => _webSocketManager.progressEvents;
|
Stream<ProgressEvent> get progressEvents => _webSocketManager.progressEvents;
|
||||||
Stream<ExecutionEvent> get executionEvents =>
|
Stream<ExecutionEvent> get executionEvents =>
|
||||||
_webSocketManager.executionEvents;
|
_webSocketManager.executionEvents;
|
||||||
|
Stream<int> get executingNodeStream => _webSocketManager.executingNodeStream;
|
||||||
|
Stream<void> get executionInterruptedStream =>
|
||||||
|
_webSocketManager.executionInterruptedStream;
|
||||||
|
|
||||||
|
/// Register a callback for when the executing node changes
|
||||||
|
void onExecutingNodeChanged(void Function(int nodeId) callback) {
|
||||||
|
_webSocketManager.executingNodeStream.listen(callback);
|
||||||
|
}
|
||||||
|
|
||||||
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
|
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
|
||||||
_webSocketManager.onEventType(type, callback);
|
_webSocketManager.onEventType(type, callback);
|
||||||
@ -58,6 +66,11 @@ class ComfyUiApi {
|
|||||||
_webSocketManager.onPromptFinished(callback);
|
_webSocketManager.onPromptFinished(callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Register a callback for when execution is interrupted
|
||||||
|
void onExecutionInterrupted(void Function() callback) {
|
||||||
|
_webSocketManager.executionInterruptedStream.listen((_) => callback());
|
||||||
|
}
|
||||||
|
|
||||||
Future<void> connectWebSocket() => _webSocketManager.connect();
|
Future<void> connectWebSocket() => _webSocketManager.connect();
|
||||||
|
|
||||||
void dispose() {
|
void dispose() {
|
||||||
@ -235,6 +248,21 @@ class ComfyUiApi {
|
|||||||
return SubmitPromptResponse.fromJson(responseData);
|
return SubmitPromptResponse.fromJson(responseData);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Future<bool> interrupt({bool clearQueue = false}) async {
|
||||||
|
if (clearQueue) {
|
||||||
|
final clearQueueResponse = await _httpClient.post(
|
||||||
|
Uri.parse('$host/api/queue'),
|
||||||
|
headers: {'Content-Type': 'application/json'},
|
||||||
|
body: jsonEncode({'clear': true}),
|
||||||
|
);
|
||||||
|
_validateResponse(clearQueueResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
final response = await _httpClient.post(Uri.parse('$host/api/interrupt'));
|
||||||
|
_validateResponse(response);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
/// Validates HTTP response and throws an exception if needed
|
/// Validates HTTP response and throws an exception if needed
|
||||||
void _validateResponse(http.Response response) {
|
void _validateResponse(http.Response response) {
|
||||||
if (response.statusCode < 200 || response.statusCode >= 300) {
|
if (response.statusCode < 200 || response.statusCode >= 300) {
|
||||||
|
@ -7,6 +7,7 @@ enum WebSocketEventType {
|
|||||||
executionStart,
|
executionStart,
|
||||||
executionCached,
|
executionCached,
|
||||||
executionSuccess,
|
executionSuccess,
|
||||||
|
executionInterrupted,
|
||||||
executionError,
|
executionError,
|
||||||
dataOutput,
|
dataOutput,
|
||||||
unknown
|
unknown
|
||||||
@ -83,6 +84,9 @@ class WebSocketEvent {
|
|||||||
case 'execution_error':
|
case 'execution_error':
|
||||||
type = WebSocketEventType.executionError;
|
type = WebSocketEventType.executionError;
|
||||||
break;
|
break;
|
||||||
|
case 'execution_interrupted':
|
||||||
|
type = WebSocketEventType.executionInterrupted;
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
if (typeStr.startsWith('data_output')) {
|
if (typeStr.startsWith('data_output')) {
|
||||||
type = WebSocketEventType.dataOutput;
|
type = WebSocketEventType.dataOutput;
|
||||||
|
@ -24,6 +24,10 @@ class WebSocketManager {
|
|||||||
StreamController.broadcast();
|
StreamController.broadcast();
|
||||||
final StreamController<ExecutionEvent> _executionEventController =
|
final StreamController<ExecutionEvent> _executionEventController =
|
||||||
StreamController.broadcast();
|
StreamController.broadcast();
|
||||||
|
final StreamController<int> _executingNodeController =
|
||||||
|
StreamController.broadcast();
|
||||||
|
final StreamController<void> _executionInterruptedController =
|
||||||
|
StreamController.broadcast();
|
||||||
|
|
||||||
// Event callbacks
|
// Event callbacks
|
||||||
final Map<WebSocketEventType, List<WebSocketEventCallback>>
|
final Map<WebSocketEventType, List<WebSocketEventCallback>>
|
||||||
@ -52,6 +56,13 @@ class WebSocketManager {
|
|||||||
Stream<ExecutionEvent> get executionEvents =>
|
Stream<ExecutionEvent> get executionEvents =>
|
||||||
_executionEventController.stream;
|
_executionEventController.stream;
|
||||||
|
|
||||||
|
/// Stream of currently executing node
|
||||||
|
Stream<int> get executingNodeStream => _executingNodeController.stream;
|
||||||
|
|
||||||
|
/// Stream of execution interrupted events
|
||||||
|
Stream<void> get executionInterruptedStream =>
|
||||||
|
_executionInterruptedController.stream;
|
||||||
|
|
||||||
/// Register a callback for specific WebSocket event types
|
/// Register a callback for specific WebSocket event types
|
||||||
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
|
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
|
||||||
_typedEventCallbacks[type]!.add(callback);
|
_typedEventCallbacks[type]!.add(callback);
|
||||||
@ -83,6 +94,8 @@ class WebSocketManager {
|
|||||||
_wsChannel!.stream.listen((message) {
|
_wsChannel!.stream.listen((message) {
|
||||||
final jsonData = jsonDecode(message);
|
final jsonData = jsonDecode(message);
|
||||||
|
|
||||||
|
print('WebSocket message: $jsonData');
|
||||||
|
|
||||||
// Create a typed event
|
// Create a typed event
|
||||||
final event = WebSocketEvent.fromJson(jsonData);
|
final event = WebSocketEvent.fromJson(jsonData);
|
||||||
|
|
||||||
@ -101,10 +114,28 @@ class WebSocketManager {
|
|||||||
WebSocketEventType.executionCached,
|
WebSocketEventType.executionCached,
|
||||||
WebSocketEventType.executed,
|
WebSocketEventType.executed,
|
||||||
WebSocketEventType.executing,
|
WebSocketEventType.executing,
|
||||||
|
WebSocketEventType.executionInterrupted,
|
||||||
|
WebSocketEventType.status, // Add status event type
|
||||||
].contains(event.eventType)) {
|
].contains(event.eventType)) {
|
||||||
_tryCreateExecutionEvent(event);
|
_tryCreateExecutionEvent(event);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle "execution_interrupted" event
|
||||||
|
if (event.eventType == WebSocketEventType.executionInterrupted) {
|
||||||
|
_executionInterruptedController.add(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle "executing" event
|
||||||
|
if (event.eventType == WebSocketEventType.executing &&
|
||||||
|
event.data['node'] != null) {
|
||||||
|
final nodeId = int.tryParse(event.data['node'].toString());
|
||||||
|
if (nodeId != null) {
|
||||||
|
_executingNodeController.add(nodeId);
|
||||||
|
} else {
|
||||||
|
print('Invalid node ID: ${event.data['node']}');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Trigger event type specific callbacks
|
// Trigger event type specific callbacks
|
||||||
for (final callback in _typedEventCallbacks[event.eventType]!) {
|
for (final callback in _typedEventCallbacks[event.eventType]!) {
|
||||||
callback(event);
|
callback(event);
|
||||||
@ -193,5 +224,7 @@ class WebSocketManager {
|
|||||||
_eventController.close();
|
_eventController.close();
|
||||||
_progressEventController.close();
|
_progressEventController.close();
|
||||||
_executionEventController.close();
|
_executionEventController.close();
|
||||||
|
_executingNodeController.close();
|
||||||
|
_executionInterruptedController.close(); // Close the new controller
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user