Add WebSocketManager class for handling WebSocket connections and events
This commit is contained in:
@@ -3,172 +3,61 @@ import 'dart:convert';
|
||||
|
||||
import 'package:http/http.dart' as http;
|
||||
import 'package:uuid/uuid.dart';
|
||||
import 'package:web_socket_channel/web_socket_channel.dart';
|
||||
|
||||
import 'models/websocket_event.dart';
|
||||
import 'models/progress_event.dart';
|
||||
import 'models/execution_event.dart';
|
||||
import 'exceptions/comfyui_api_exception.dart';
|
||||
import 'types/callback_types.dart';
|
||||
import 'utils/websocket_event_handler.dart';
|
||||
import 'models/history_response.dart';
|
||||
import 'models/checkpoint.dart';
|
||||
import 'models/vae.dart';
|
||||
import 'models/lora.dart';
|
||||
import 'websocket_manager.dart';
|
||||
|
||||
/// A Dart SDK for interacting with the ComfyUI API
|
||||
class ComfyUiApi {
|
||||
final String host;
|
||||
final String clientId;
|
||||
final http.Client _httpClient;
|
||||
WebSocketChannel? _wsChannel;
|
||||
|
||||
// Controllers for different event streams
|
||||
final StreamController<WebSocketEvent> _eventController =
|
||||
StreamController.broadcast();
|
||||
final StreamController<Map<String, dynamic>> _progressController =
|
||||
StreamController.broadcast();
|
||||
|
||||
// Add new controllers for specific event types
|
||||
final StreamController<ProgressEvent> _progressEventController =
|
||||
StreamController.broadcast();
|
||||
final StreamController<ExecutionEvent> _executionEventController =
|
||||
StreamController.broadcast();
|
||||
|
||||
/// Stream of typed progress events
|
||||
Stream<ProgressEvent> get progressEvents => _progressEventController.stream;
|
||||
|
||||
/// Stream of typed execution events
|
||||
Stream<ExecutionEvent> get executionEvents =>
|
||||
_executionEventController.stream;
|
||||
|
||||
// Event callbacks
|
||||
final Map<String, List<PromptEventCallback>> _eventCallbacks = {
|
||||
'onPromptStart': [],
|
||||
'onPromptFinished': [],
|
||||
};
|
||||
|
||||
final Map<WebSocketEventType, List<WebSocketEventCallback>>
|
||||
_typedEventCallbacks = {
|
||||
for (var type in WebSocketEventType.values) type: [],
|
||||
};
|
||||
|
||||
// Add a separate map for progress event callbacks
|
||||
final List<ProgressEventCallback> _progressEventCallbacks = [];
|
||||
|
||||
/// Stream of typed WebSocket events
|
||||
Stream<WebSocketEvent> get events => _eventController.stream;
|
||||
|
||||
/// Stream of progress updates from ComfyUI (legacy format)
|
||||
Stream<Map<String, dynamic>> get progressUpdates =>
|
||||
_progressController.stream;
|
||||
final WebSocketManager _webSocketManager;
|
||||
|
||||
/// Creates a new ComfyUI API client
|
||||
///
|
||||
/// [host] The host of the ComfyUI server (e.g. 'http://localhost:7860')
|
||||
/// [clientId] Optional client ID, will be automatically generated if not provided
|
||||
ComfyUiApi({
|
||||
required this.host,
|
||||
String? clientId,
|
||||
required this.clientId,
|
||||
http.Client? httpClient,
|
||||
}) : clientId = clientId ?? const Uuid().v4(),
|
||||
_httpClient = httpClient ?? http.Client();
|
||||
}) : _httpClient = httpClient ?? http.Client(),
|
||||
_webSocketManager = WebSocketManager(host: host, clientId: clientId);
|
||||
|
||||
/// Register a callback for when a prompt starts processing
|
||||
void onPromptStart(PromptEventCallback callback) {
|
||||
_eventCallbacks['onPromptStart']!.add(callback);
|
||||
}
|
||||
/// Expose WebSocketManager streams and methods
|
||||
Stream<WebSocketEvent> get events => _webSocketManager.events;
|
||||
Stream<Map<String, dynamic>> get progressUpdates =>
|
||||
_webSocketManager.progressUpdates;
|
||||
Stream<ProgressEvent> get progressEvents => _webSocketManager.progressEvents;
|
||||
Stream<ExecutionEvent> get executionEvents =>
|
||||
_webSocketManager.executionEvents;
|
||||
|
||||
/// Register a callback for when a prompt finishes processing
|
||||
void onPromptFinished(PromptEventCallback callback) {
|
||||
_eventCallbacks['onPromptFinished']!.add(callback);
|
||||
}
|
||||
|
||||
/// Register a callback for specific WebSocket event types
|
||||
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
|
||||
_typedEventCallbacks[type]!.add(callback);
|
||||
_webSocketManager.onEventType(type, callback);
|
||||
}
|
||||
|
||||
/// Register a callback for progress updates
|
||||
void onProgressChanged(ProgressEventCallback callback) {
|
||||
_progressEventCallbacks.add(callback);
|
||||
_webSocketManager.onProgressChanged(callback);
|
||||
}
|
||||
|
||||
/// Connects to the WebSocket for progress updates
|
||||
Future<void> connectWebSocket() async {
|
||||
final wsUrl =
|
||||
'ws://${host.replaceFirst(RegExp(r'^https?://'), '')}/ws?clientId=$clientId';
|
||||
_wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl));
|
||||
|
||||
_wsChannel!.stream.listen((message) {
|
||||
final jsonData = jsonDecode(message);
|
||||
|
||||
// Create a typed event
|
||||
final event = WebSocketEvent.fromJson(jsonData);
|
||||
|
||||
// Add to the typed event stream
|
||||
_eventController.add(event);
|
||||
|
||||
// Also add to the legacy progress stream
|
||||
_progressController.add(jsonData);
|
||||
|
||||
// Convert to more specific event types if possible
|
||||
if (event.eventType == WebSocketEventType.progress) {
|
||||
_tryCreateProgressEvent(event);
|
||||
} else if ([
|
||||
WebSocketEventType.executionStart,
|
||||
WebSocketEventType.executionSuccess,
|
||||
WebSocketEventType.executionCached,
|
||||
WebSocketEventType.executed,
|
||||
WebSocketEventType.executing,
|
||||
].contains(event.eventType)) {
|
||||
_tryCreateExecutionEvent(event);
|
||||
}
|
||||
|
||||
// Trigger event type specific callbacks
|
||||
for (final callback in _typedEventCallbacks[event.eventType]!) {
|
||||
callback(event);
|
||||
}
|
||||
|
||||
// Handle execution_success event (prompt finished)
|
||||
if (event.eventType == WebSocketEventType.executionSuccess &&
|
||||
event.promptId != null) {
|
||||
final promptId = event.promptId!;
|
||||
for (final callback in _eventCallbacks['onPromptFinished']!) {
|
||||
callback(promptId);
|
||||
}
|
||||
}
|
||||
}, onError: (error) {
|
||||
print('WebSocket error: $error');
|
||||
}, onDone: () {
|
||||
print('WebSocket connection closed');
|
||||
});
|
||||
void onPromptStart(PromptEventCallback callback) {
|
||||
_webSocketManager.onPromptStart(callback);
|
||||
}
|
||||
|
||||
/// Attempts to create a ProgressEvent from a WebSocketEvent
|
||||
void _tryCreateProgressEvent(WebSocketEvent event) {
|
||||
WebSocketEventHandler.tryCreateProgressEvent(
|
||||
event,
|
||||
_progressEventController,
|
||||
_progressEventCallbacks,
|
||||
);
|
||||
void onPromptFinished(PromptEventCallback callback) {
|
||||
_webSocketManager.onPromptFinished(callback);
|
||||
}
|
||||
|
||||
/// Attempts to create an ExecutionEvent from a WebSocketEvent
|
||||
void _tryCreateExecutionEvent(WebSocketEvent event) {
|
||||
WebSocketEventHandler.tryCreateExecutionEvent(
|
||||
event,
|
||||
_executionEventController,
|
||||
);
|
||||
}
|
||||
Future<void> connectWebSocket() => _webSocketManager.connect();
|
||||
|
||||
/// Closes the WebSocket connection and cleans up resources
|
||||
void dispose() {
|
||||
_wsChannel?.sink.close();
|
||||
_progressController.close();
|
||||
_eventController.close();
|
||||
_progressEventController.close();
|
||||
_executionEventController.close();
|
||||
_webSocketManager.dispose();
|
||||
_httpClient.close();
|
||||
}
|
||||
|
||||
@@ -309,9 +198,7 @@ class ComfyUiApi {
|
||||
// Trigger onPromptStart event if prompt_id exists
|
||||
if (responseData.containsKey('prompt_id')) {
|
||||
final promptId = responseData['prompt_id'];
|
||||
for (final callback in _eventCallbacks['onPromptStart']!) {
|
||||
callback(promptId);
|
||||
}
|
||||
_webSocketManager.triggerOnPromptStart(promptId);
|
||||
}
|
||||
|
||||
return responseData;
|
||||
|
179
lib/src/websocket_manager.dart
Normal file
179
lib/src/websocket_manager.dart
Normal file
@@ -0,0 +1,179 @@
|
||||
import 'dart:async';
|
||||
import 'dart:convert';
|
||||
|
||||
import 'package:web_socket_channel/web_socket_channel.dart';
|
||||
|
||||
import 'models/websocket_event.dart';
|
||||
import 'models/progress_event.dart';
|
||||
import 'models/execution_event.dart';
|
||||
import 'types/callback_types.dart';
|
||||
import 'utils/websocket_event_handler.dart';
|
||||
|
||||
class WebSocketManager {
|
||||
final String host;
|
||||
final String clientId;
|
||||
|
||||
WebSocketChannel? _wsChannel;
|
||||
|
||||
// Controllers for different event streams
|
||||
final StreamController<WebSocketEvent> _eventController =
|
||||
StreamController.broadcast();
|
||||
final StreamController<Map<String, dynamic>> _progressController =
|
||||
StreamController.broadcast();
|
||||
final StreamController<ProgressEvent> _progressEventController =
|
||||
StreamController.broadcast();
|
||||
final StreamController<ExecutionEvent> _executionEventController =
|
||||
StreamController.broadcast();
|
||||
|
||||
// Event callbacks
|
||||
final Map<WebSocketEventType, List<WebSocketEventCallback>>
|
||||
_typedEventCallbacks = {
|
||||
for (var type in WebSocketEventType.values) type: [],
|
||||
};
|
||||
final List<ProgressEventCallback> _progressEventCallbacks = [];
|
||||
final Map<String, List<PromptEventCallback>> _eventCallbacks = {
|
||||
'onPromptStart': [],
|
||||
'onPromptFinished': [],
|
||||
};
|
||||
|
||||
WebSocketManager({required this.host, required this.clientId});
|
||||
|
||||
/// Stream of typed WebSocket events
|
||||
Stream<WebSocketEvent> get events => _eventController.stream;
|
||||
|
||||
/// Stream of progress updates (legacy format)
|
||||
Stream<Map<String, dynamic>> get progressUpdates =>
|
||||
_progressController.stream;
|
||||
|
||||
/// Stream of typed progress events
|
||||
Stream<ProgressEvent> get progressEvents => _progressEventController.stream;
|
||||
|
||||
/// Stream of typed execution events
|
||||
Stream<ExecutionEvent> get executionEvents =>
|
||||
_executionEventController.stream;
|
||||
|
||||
/// Register a callback for specific WebSocket event types
|
||||
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
|
||||
_typedEventCallbacks[type]!.add(callback);
|
||||
}
|
||||
|
||||
/// Register a callback for progress updates
|
||||
void onProgressChanged(ProgressEventCallback callback) {
|
||||
_progressEventCallbacks.add(callback);
|
||||
}
|
||||
|
||||
/// Register a callback for when a prompt starts processing
|
||||
void onPromptStart(PromptEventCallback callback) {
|
||||
_eventCallbacks['onPromptStart']!.add(callback);
|
||||
}
|
||||
|
||||
/// Register a callback for when a prompt finishes processing
|
||||
void onPromptFinished(PromptEventCallback callback) {
|
||||
_eventCallbacks['onPromptFinished']!.add(callback);
|
||||
}
|
||||
|
||||
/// Connects to the WebSocket for progress updates
|
||||
Future<void> connect() async {
|
||||
final wsUrl =
|
||||
'ws://${host.replaceFirst(RegExp(r'^https?://'), '')}/ws?clientId=$clientId';
|
||||
_wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl));
|
||||
|
||||
_wsChannel!.stream.listen((message) {
|
||||
final jsonData = jsonDecode(message);
|
||||
|
||||
// Create a typed event
|
||||
final event = WebSocketEvent.fromJson(jsonData);
|
||||
|
||||
// Add to the typed event stream
|
||||
_eventController.add(event);
|
||||
|
||||
// Also add to the legacy progress stream
|
||||
_progressController.add(jsonData);
|
||||
|
||||
// Convert to more specific event types if possible
|
||||
if (event.eventType == WebSocketEventType.progress) {
|
||||
_tryCreateProgressEvent(event);
|
||||
} else if ([
|
||||
WebSocketEventType.executionStart,
|
||||
WebSocketEventType.executionSuccess,
|
||||
WebSocketEventType.executionCached,
|
||||
WebSocketEventType.executed,
|
||||
WebSocketEventType.executing,
|
||||
].contains(event.eventType)) {
|
||||
_tryCreateExecutionEvent(event);
|
||||
}
|
||||
|
||||
// Trigger event type specific callbacks
|
||||
for (final callback in _typedEventCallbacks[event.eventType]!) {
|
||||
callback(event);
|
||||
}
|
||||
|
||||
// Handle execution_success event (prompt finished)
|
||||
if (event.eventType == WebSocketEventType.executionSuccess &&
|
||||
event.promptId != null) {
|
||||
final promptId = event.promptId!;
|
||||
for (final callback in _eventCallbacks['onPromptFinished']!) {
|
||||
callback(promptId);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle progress updates
|
||||
if (event.eventType == WebSocketEventType.progress) {
|
||||
for (final callback in _progressEventCallbacks) {
|
||||
callback(ProgressEvent.fromJson(event.data));
|
||||
}
|
||||
}
|
||||
}, onError: (error) {
|
||||
print('WebSocket error: $error');
|
||||
}, onDone: () {
|
||||
print('WebSocket connection closed');
|
||||
});
|
||||
}
|
||||
|
||||
/// Attempts to create a ProgressEvent from a WebSocketEvent
|
||||
void _tryCreateProgressEvent(WebSocketEvent event) {
|
||||
WebSocketEventHandler.tryCreateProgressEvent(
|
||||
event,
|
||||
_progressEventController,
|
||||
_progressEventCallbacks,
|
||||
);
|
||||
}
|
||||
|
||||
/// Attempts to create an ExecutionEvent from a WebSocketEvent
|
||||
void _tryCreateExecutionEvent(WebSocketEvent event) {
|
||||
WebSocketEventHandler.tryCreateExecutionEvent(
|
||||
event,
|
||||
_executionEventController,
|
||||
);
|
||||
}
|
||||
|
||||
/// Trigger the onPromptStart callbacks
|
||||
void triggerOnPromptStart(String promptId) {
|
||||
for (final callback in _eventCallbacks['onPromptStart']!) {
|
||||
callback(promptId);
|
||||
}
|
||||
}
|
||||
|
||||
/// Trigger the onPromptFinished callbacks
|
||||
void triggerOnPromptFinished(String promptId) {
|
||||
for (final callback in _eventCallbacks['onPromptFinished']!) {
|
||||
callback(promptId);
|
||||
}
|
||||
}
|
||||
|
||||
/// Trigger the onProgressChanged callbacks
|
||||
void triggerOnProgressChanged(Map<String, dynamic> progressData) {
|
||||
for (final callback in _progressEventCallbacks) {
|
||||
callback(ProgressEvent.fromJson(progressData));
|
||||
}
|
||||
}
|
||||
|
||||
/// Closes the WebSocket connection and cleans up resources
|
||||
void dispose() {
|
||||
_wsChannel?.sink.close();
|
||||
_progressController.close();
|
||||
_eventController.close();
|
||||
_progressEventController.close();
|
||||
_executionEventController.close();
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user