Add WebSocketManager class for handling WebSocket connections and events

This commit is contained in:
2025-03-20 15:18:26 +00:00
parent 813e6f334e
commit 697d2f812d
2 changed files with 200 additions and 134 deletions

View File

@@ -3,172 +3,61 @@ import 'dart:convert';
import 'package:http/http.dart' as http;
import 'package:uuid/uuid.dart';
import 'package:web_socket_channel/web_socket_channel.dart';
import 'models/websocket_event.dart';
import 'models/progress_event.dart';
import 'models/execution_event.dart';
import 'exceptions/comfyui_api_exception.dart';
import 'types/callback_types.dart';
import 'utils/websocket_event_handler.dart';
import 'models/history_response.dart';
import 'models/checkpoint.dart';
import 'models/vae.dart';
import 'models/lora.dart';
import 'websocket_manager.dart';
/// A Dart SDK for interacting with the ComfyUI API
class ComfyUiApi {
final String host;
final String clientId;
final http.Client _httpClient;
WebSocketChannel? _wsChannel;
// Controllers for different event streams
final StreamController<WebSocketEvent> _eventController =
StreamController.broadcast();
final StreamController<Map<String, dynamic>> _progressController =
StreamController.broadcast();
// Add new controllers for specific event types
final StreamController<ProgressEvent> _progressEventController =
StreamController.broadcast();
final StreamController<ExecutionEvent> _executionEventController =
StreamController.broadcast();
/// Stream of typed progress events
Stream<ProgressEvent> get progressEvents => _progressEventController.stream;
/// Stream of typed execution events
Stream<ExecutionEvent> get executionEvents =>
_executionEventController.stream;
// Event callbacks
final Map<String, List<PromptEventCallback>> _eventCallbacks = {
'onPromptStart': [],
'onPromptFinished': [],
};
final Map<WebSocketEventType, List<WebSocketEventCallback>>
_typedEventCallbacks = {
for (var type in WebSocketEventType.values) type: [],
};
// Add a separate map for progress event callbacks
final List<ProgressEventCallback> _progressEventCallbacks = [];
/// Stream of typed WebSocket events
Stream<WebSocketEvent> get events => _eventController.stream;
/// Stream of progress updates from ComfyUI (legacy format)
Stream<Map<String, dynamic>> get progressUpdates =>
_progressController.stream;
final WebSocketManager _webSocketManager;
/// Creates a new ComfyUI API client
///
/// [host] The host of the ComfyUI server (e.g. 'http://localhost:7860')
/// [clientId] Optional client ID, will be automatically generated if not provided
ComfyUiApi({
required this.host,
String? clientId,
required this.clientId,
http.Client? httpClient,
}) : clientId = clientId ?? const Uuid().v4(),
_httpClient = httpClient ?? http.Client();
}) : _httpClient = httpClient ?? http.Client(),
_webSocketManager = WebSocketManager(host: host, clientId: clientId);
/// Register a callback for when a prompt starts processing
void onPromptStart(PromptEventCallback callback) {
_eventCallbacks['onPromptStart']!.add(callback);
}
/// Expose WebSocketManager streams and methods
Stream<WebSocketEvent> get events => _webSocketManager.events;
Stream<Map<String, dynamic>> get progressUpdates =>
_webSocketManager.progressUpdates;
Stream<ProgressEvent> get progressEvents => _webSocketManager.progressEvents;
Stream<ExecutionEvent> get executionEvents =>
_webSocketManager.executionEvents;
/// Register a callback for when a prompt finishes processing
void onPromptFinished(PromptEventCallback callback) {
_eventCallbacks['onPromptFinished']!.add(callback);
}
/// Register a callback for specific WebSocket event types
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
_typedEventCallbacks[type]!.add(callback);
_webSocketManager.onEventType(type, callback);
}
/// Register a callback for progress updates
void onProgressChanged(ProgressEventCallback callback) {
_progressEventCallbacks.add(callback);
_webSocketManager.onProgressChanged(callback);
}
/// Connects to the WebSocket for progress updates
Future<void> connectWebSocket() async {
final wsUrl =
'ws://${host.replaceFirst(RegExp(r'^https?://'), '')}/ws?clientId=$clientId';
_wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl));
_wsChannel!.stream.listen((message) {
final jsonData = jsonDecode(message);
// Create a typed event
final event = WebSocketEvent.fromJson(jsonData);
// Add to the typed event stream
_eventController.add(event);
// Also add to the legacy progress stream
_progressController.add(jsonData);
// Convert to more specific event types if possible
if (event.eventType == WebSocketEventType.progress) {
_tryCreateProgressEvent(event);
} else if ([
WebSocketEventType.executionStart,
WebSocketEventType.executionSuccess,
WebSocketEventType.executionCached,
WebSocketEventType.executed,
WebSocketEventType.executing,
].contains(event.eventType)) {
_tryCreateExecutionEvent(event);
}
// Trigger event type specific callbacks
for (final callback in _typedEventCallbacks[event.eventType]!) {
callback(event);
}
// Handle execution_success event (prompt finished)
if (event.eventType == WebSocketEventType.executionSuccess &&
event.promptId != null) {
final promptId = event.promptId!;
for (final callback in _eventCallbacks['onPromptFinished']!) {
callback(promptId);
}
}
}, onError: (error) {
print('WebSocket error: $error');
}, onDone: () {
print('WebSocket connection closed');
});
void onPromptStart(PromptEventCallback callback) {
_webSocketManager.onPromptStart(callback);
}
/// Attempts to create a ProgressEvent from a WebSocketEvent
void _tryCreateProgressEvent(WebSocketEvent event) {
WebSocketEventHandler.tryCreateProgressEvent(
event,
_progressEventController,
_progressEventCallbacks,
);
void onPromptFinished(PromptEventCallback callback) {
_webSocketManager.onPromptFinished(callback);
}
/// Attempts to create an ExecutionEvent from a WebSocketEvent
void _tryCreateExecutionEvent(WebSocketEvent event) {
WebSocketEventHandler.tryCreateExecutionEvent(
event,
_executionEventController,
);
}
Future<void> connectWebSocket() => _webSocketManager.connect();
/// Closes the WebSocket connection and cleans up resources
void dispose() {
_wsChannel?.sink.close();
_progressController.close();
_eventController.close();
_progressEventController.close();
_executionEventController.close();
_webSocketManager.dispose();
_httpClient.close();
}
@@ -309,9 +198,7 @@ class ComfyUiApi {
// Trigger onPromptStart event if prompt_id exists
if (responseData.containsKey('prompt_id')) {
final promptId = responseData['prompt_id'];
for (final callback in _eventCallbacks['onPromptStart']!) {
callback(promptId);
}
_webSocketManager.triggerOnPromptStart(promptId);
}
return responseData;

View File

@@ -0,0 +1,179 @@
import 'dart:async';
import 'dart:convert';
import 'package:web_socket_channel/web_socket_channel.dart';
import 'models/websocket_event.dart';
import 'models/progress_event.dart';
import 'models/execution_event.dart';
import 'types/callback_types.dart';
import 'utils/websocket_event_handler.dart';
class WebSocketManager {
final String host;
final String clientId;
WebSocketChannel? _wsChannel;
// Controllers for different event streams
final StreamController<WebSocketEvent> _eventController =
StreamController.broadcast();
final StreamController<Map<String, dynamic>> _progressController =
StreamController.broadcast();
final StreamController<ProgressEvent> _progressEventController =
StreamController.broadcast();
final StreamController<ExecutionEvent> _executionEventController =
StreamController.broadcast();
// Event callbacks
final Map<WebSocketEventType, List<WebSocketEventCallback>>
_typedEventCallbacks = {
for (var type in WebSocketEventType.values) type: [],
};
final List<ProgressEventCallback> _progressEventCallbacks = [];
final Map<String, List<PromptEventCallback>> _eventCallbacks = {
'onPromptStart': [],
'onPromptFinished': [],
};
WebSocketManager({required this.host, required this.clientId});
/// Stream of typed WebSocket events
Stream<WebSocketEvent> get events => _eventController.stream;
/// Stream of progress updates (legacy format)
Stream<Map<String, dynamic>> get progressUpdates =>
_progressController.stream;
/// Stream of typed progress events
Stream<ProgressEvent> get progressEvents => _progressEventController.stream;
/// Stream of typed execution events
Stream<ExecutionEvent> get executionEvents =>
_executionEventController.stream;
/// Register a callback for specific WebSocket event types
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
_typedEventCallbacks[type]!.add(callback);
}
/// Register a callback for progress updates
void onProgressChanged(ProgressEventCallback callback) {
_progressEventCallbacks.add(callback);
}
/// Register a callback for when a prompt starts processing
void onPromptStart(PromptEventCallback callback) {
_eventCallbacks['onPromptStart']!.add(callback);
}
/// Register a callback for when a prompt finishes processing
void onPromptFinished(PromptEventCallback callback) {
_eventCallbacks['onPromptFinished']!.add(callback);
}
/// Connects to the WebSocket for progress updates
Future<void> connect() async {
final wsUrl =
'ws://${host.replaceFirst(RegExp(r'^https?://'), '')}/ws?clientId=$clientId';
_wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl));
_wsChannel!.stream.listen((message) {
final jsonData = jsonDecode(message);
// Create a typed event
final event = WebSocketEvent.fromJson(jsonData);
// Add to the typed event stream
_eventController.add(event);
// Also add to the legacy progress stream
_progressController.add(jsonData);
// Convert to more specific event types if possible
if (event.eventType == WebSocketEventType.progress) {
_tryCreateProgressEvent(event);
} else if ([
WebSocketEventType.executionStart,
WebSocketEventType.executionSuccess,
WebSocketEventType.executionCached,
WebSocketEventType.executed,
WebSocketEventType.executing,
].contains(event.eventType)) {
_tryCreateExecutionEvent(event);
}
// Trigger event type specific callbacks
for (final callback in _typedEventCallbacks[event.eventType]!) {
callback(event);
}
// Handle execution_success event (prompt finished)
if (event.eventType == WebSocketEventType.executionSuccess &&
event.promptId != null) {
final promptId = event.promptId!;
for (final callback in _eventCallbacks['onPromptFinished']!) {
callback(promptId);
}
}
// Handle progress updates
if (event.eventType == WebSocketEventType.progress) {
for (final callback in _progressEventCallbacks) {
callback(ProgressEvent.fromJson(event.data));
}
}
}, onError: (error) {
print('WebSocket error: $error');
}, onDone: () {
print('WebSocket connection closed');
});
}
/// Attempts to create a ProgressEvent from a WebSocketEvent
void _tryCreateProgressEvent(WebSocketEvent event) {
WebSocketEventHandler.tryCreateProgressEvent(
event,
_progressEventController,
_progressEventCallbacks,
);
}
/// Attempts to create an ExecutionEvent from a WebSocketEvent
void _tryCreateExecutionEvent(WebSocketEvent event) {
WebSocketEventHandler.tryCreateExecutionEvent(
event,
_executionEventController,
);
}
/// Trigger the onPromptStart callbacks
void triggerOnPromptStart(String promptId) {
for (final callback in _eventCallbacks['onPromptStart']!) {
callback(promptId);
}
}
/// Trigger the onPromptFinished callbacks
void triggerOnPromptFinished(String promptId) {
for (final callback in _eventCallbacks['onPromptFinished']!) {
callback(promptId);
}
}
/// Trigger the onProgressChanged callbacks
void triggerOnProgressChanged(Map<String, dynamic> progressData) {
for (final callback in _progressEventCallbacks) {
callback(ProgressEvent.fromJson(progressData));
}
}
/// Closes the WebSocket connection and cleans up resources
void dispose() {
_wsChannel?.sink.close();
_progressController.close();
_eventController.close();
_progressEventController.close();
_executionEventController.close();
}
}