Files
comfyui_api_sdk/lib/src/websocket_manager.dart
2025-08-11 15:43:17 +02:00

433 lines
14 KiB
Dart

import 'dart:async';
import 'dart:convert';
import 'dart:typed_data';
import 'package:web_socket_channel/web_socket_channel.dart';
// Use relative imports to avoid duplicate library instances when this package
// is consumed via a path or symlink (prevents distinct "same" types).
import 'models/websocket_event.dart';
import 'models/progress_event.dart';
import 'models/execution_event.dart';
import 'types/callback_types.dart';
import 'utils/websocket_event_handler.dart';
/// Lifecycle states of the ComfyUI WebSocket connection, as reported by
/// `WebSocketManager`.
enum ConnectionState {
  /// At least one frame has been received on the current socket.
  connected,

  /// A connection or reconnection attempt is in progress.
  connecting,

  /// No connection has been attempted yet (the manager's initial state).
  disconnected,

  /// Automatic reconnection gave up after exhausting its retry budget.
  failed,
}
/// Owns the WebSocket connection to a ComfyUI server and fans incoming
/// frames out to typed streams and registered callbacks.
///
/// Text frames (and binary frames that contain UTF-8 JSON) are parsed into
/// [WebSocketEvent]s. Other binary frames are counted in [stats] and dropped.
/// When the socket errors or closes, the manager retries up to
/// [maxRetryAttempts] times before reporting [ConnectionState.failed].
///
/// Call [dispose] when finished; a disposed manager cannot be reused.
class WebSocketManager {
  /// Server address, e.g. `127.0.0.1:8188` or `https://example.com`.
  ///
  /// A leading `http://`, `https://`, `ws://` or `wss://` is accepted; the
  /// secure variants select `wss://` for the socket.
  final String host;

  /// Identifier sent to the server as the `clientId` query parameter.
  final String clientId;

  WebSocketChannel? _wsChannel;

  /// Subscription to the current channel's stream. Kept so it can be
  /// cancelled before a new channel is opened and on [dispose], which stops
  /// stale `onDone`/`onError` callbacks from triggering reconnects.
  StreamSubscription<dynamic>? _wsSubscription;

  /// Set once by [dispose]; blocks reconnects and late stream emissions.
  bool _disposed = false;

  // Connection state tracking
  ConnectionState _connectionState = ConnectionState.disconnected;
  int _retryAttempt = 0;
  static const int _maxRetryAttempts = 3;
  static const Duration _reconnectDelay = Duration(seconds: 5);

  final StreamController<ConnectionState> _connectionStateController =
      StreamController.broadcast();
  final StreamController<int> _retryAttemptController =
      StreamController.broadcast();

  // Controllers for different event streams
  final StreamController<WebSocketEvent> _eventController =
      StreamController.broadcast();
  final StreamController<Map<String, dynamic>> _progressController =
      StreamController.broadcast();
  final StreamController<ProgressEvent> _progressEventController =
      StreamController.broadcast();
  final StreamController<ExecutionEvent> _executionEventController =
      StreamController.broadcast();
  final StreamController<int> _executingNodeController =
      StreamController.broadcast();
  final StreamController<void> _executionInterruptedController =
      StreamController.broadcast();

  // Optional future binary preview frames
  final StreamController<Uint8List> _previewFrameController =
      StreamController.broadcast();

  // Event callbacks
  final Map<WebSocketEventType, List<WebSocketEventCallback>>
      _typedEventCallbacks = {
    for (var type in WebSocketEventType.values)
      type: <WebSocketEventCallback>[],
  };
  final List<ProgressEventCallback> _progressEventCallbacks = [];
  final List<PromptEventCallback> _promptStartCallbacks = [];
  final List<PromptEventCallback> _promptFinishedCallbacks = [];

  /// Event types that are additionally converted to [ExecutionEvent]s.
  static final Set<WebSocketEventType> _executionEventTypes = {
    WebSocketEventType.executionStart,
    WebSocketEventType.executionSuccess,
    WebSocketEventType.executionCached,
    WebSocketEventType.executed,
    WebSocketEventType.executing,
    WebSocketEventType.executionInterrupted,
    WebSocketEventType.status,
  };

  // Leading bytes used to recognise JPEG / PNG payloads.
  static const List<int> _jpegMagic = [0xFF, 0xD8];
  static const List<int> _pngMagic = [
    0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, //
  ];

  // Frame / parse statistics
  int _ignoredBinaryFrames = 0;
  int _malformedFrames = 0;
  int _previewFrames = 0;
  int _textFrames = 0;

  WebSocketManager({required this.host, required this.clientId});

  /// Stream of typed WebSocket events
  Stream<WebSocketEvent> get events => _eventController.stream;

  /// Stream of connection state changes
  Stream<ConnectionState> get connectionState =>
      _connectionStateController.stream;

  /// Stream of retry attempt changes
  Stream<int> get retryAttemptChanges => _retryAttemptController.stream;

  /// Current connection state
  ConnectionState get currentConnectionState => _connectionState;

  /// Current retry attempt count
  int get retryAttempt => _retryAttempt;

  /// Maximum number of retry attempts
  static int get maxRetryAttempts => _maxRetryAttempts;

  /// Stream of progress updates (legacy format)
  ///
  /// Despite the name, every successfully parsed event's raw JSON map is
  /// emitted here, not only progress events.
  Stream<Map<String, dynamic>> get progressUpdates =>
      _progressController.stream;

  /// Stream of typed progress events
  Stream<ProgressEvent> get progressEvents => _progressEventController.stream;

  /// Stream of typed execution events
  Stream<ExecutionEvent> get executionEvents =>
      _executionEventController.stream;

  /// Stream of currently executing node
  Stream<int> get executingNodeStream => _executingNodeController.stream;

  /// Stream of execution interrupted events
  Stream<void> get executionInterruptedStream =>
      _executionInterruptedController.stream;

  /// Stream of (future) binary preview frames
  ///
  /// Nothing is emitted here yet: image frames are only counted.
  Stream<Uint8List> get previewFrames => _previewFrameController.stream;

  /// Stats about received frames
  Map<String, int> get stats => {
        'ignoredBinaryFrames': _ignoredBinaryFrames,
        'malformedFrames': _malformedFrames,
        'previewFrames': _previewFrames,
        'textFrames': _textFrames,
        'retryAttempts': _retryAttempt,
      };

  /// Register a callback for specific WebSocket event types
  void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
    _typedEventCallbacks
        .putIfAbsent(type, () => <WebSocketEventCallback>[])
        .add(callback);
  }

  /// Register a callback for progress updates
  void onProgressChanged(ProgressEventCallback callback) {
    _progressEventCallbacks.add(callback);
  }

  /// Register a callback for when a prompt starts processing
  void onPromptStart(PromptEventCallback callback) {
    _promptStartCallbacks.add(callback);
  }

  /// Register a callback for when a prompt finishes processing
  void onPromptFinished(PromptEventCallback callback) {
    _promptFinishedCallbacks.add(callback);
  }

  /// Connects to the WebSocket for progress updates
  ///
  /// Any previously opened channel is released first. The returned future
  /// completes once the channel object exists and a listener is attached,
  /// not when the handshake finishes. [ConnectionState.connected] is
  /// reported when the first frame arrives; connection failures surface
  /// through the stream and trigger the retry logic.
  ///
  /// Throws a [StateError] if called after [dispose].
  Future<void> connect() async {
    if (_disposed) {
      throw StateError('WebSocketManager has been disposed');
    }
    _updateConnectionState(ConnectionState.connecting);

    // Release the previous socket so repeated connects do not leak channels
    // or leave an old listener around that could schedule its own reconnect.
    _closeChannel();

    final wsUrl = _buildWebSocketUrl();
    final channel = WebSocketChannel.connect(Uri.parse(wsUrl));
    _wsChannel = channel;
    print('WebSocket connecting to $wsUrl');

    // A failing socket typically reports an error and then closes. Route
    // both through a one-shot handler so a single failure costs one retry.
    var closeHandled = false;
    void handleClose() {
      if (closeHandled) return;
      closeHandled = true;
      _reconnect();
    }

    _wsSubscription = channel.stream.listen(
      _handleFrame,
      onError: (Object error) {
        print('WebSocket error: $error');
        handleClose();
      },
      onDone: () {
        print('WebSocket connection closed');
        handleClose();
      },
    );
  }

  /// Builds the `ws://` / `wss://` endpoint URL from [host] and [clientId].
  String _buildWebSocketUrl() {
    final secure =
        RegExp(r'^(https|wss)://', caseSensitive: false).hasMatch(host);
    final authority = host
        .replaceFirst(RegExp(r'^(https?|wss?)://', caseSensitive: false), '')
        // Avoid "host//ws" when the configured host ends with a slash.
        .replaceFirst(RegExp(r'/+$'), '');
    final scheme = secure ? 'wss' : 'ws';
    final encodedClientId = Uri.encodeQueryComponent(clientId);
    return '$scheme://$authority/ws?clientId=$encodedClientId';
  }

  /// Cancels the current subscription and closes the current channel.
  void _closeChannel() {
    // Cancel first so closing the sink cannot fire onDone and start a
    // reconnect for a socket we are closing on purpose.
    _wsSubscription?.cancel();
    _wsSubscription = null;
    _wsChannel?.sink.close();
    _wsChannel = null;
  }

  /// Handles one raw frame: classify, parse JSON, build the event, dispatch.
  void _handleFrame(dynamic message) {
    if (_disposed) return;

    // Receiving any frame is treated as proof that the socket is up.
    if (_connectionState != ConnectionState.connected) {
      _updateConnectionState(ConnectionState.connected);
      _updateRetryAttempt(0); // Reset retry counter on successful connection
    }

    final textFrame = _extractTextFrame(message);
    if (textFrame == null) return; // Already counted and logged.

    Object? jsonData;
    try {
      jsonData = jsonDecode(textFrame);
    } catch (e) {
      _malformedFrames++;
      print('WebSocket malformed JSON frame ignored: $e');
      return;
    }
    if (jsonData is! Map<String, dynamic>) {
      _malformedFrames++;
      print('WebSocket JSON root not object '
          '(type: ${jsonData.runtimeType}) ignored.');
      return;
    }
    print('WebSocket message (event): $jsonData');

    final WebSocketEvent event;
    try {
      event = WebSocketEvent.fromJson(jsonData);
    } catch (e) {
      _malformedFrames++;
      print('WebSocket event parse failed: $e');
      return;
    }

    // Add to the typed event stream
    _eventController.add(event);
    // Also add to the legacy progress stream
    _progressController.add(jsonData);

    _dispatchEvent(event);
  }

  /// Returns the JSON text carried by [message], or `null` if the frame is
  /// not JSON text. Updates the frame counters and logs dropped frames.
  String? _extractTextFrame(dynamic message) {
    if (message is String) {
      _textFrames++;
      return message;
    }
    if (message is! List<int>) {
      _ignoredBinaryFrames++;
      print(
          'WebSocket unsupported frame type ignored (${message.runtimeType}).');
      return null;
    }

    // Attempt to interpret as UTF8 JSON text
    try {
      final decoded = utf8.decode(message, allowMalformed: true);
      final trimmed = decoded.trimLeft();
      // Heuristic: treat as JSON if it starts with '{' or '['
      if (trimmed.startsWith('{') || trimmed.startsWith('[')) {
        _textFrames++;
        return decoded;
      }
    } catch (_) {
      // Could not decode as UTF8
      _ignoredBinaryFrames++;
      print('WebSocket binary frame ignored '
          '(UTF8 decode failed, ${message.length} bytes).');
      return null;
    }

    // Detect common image headers for potential future preview streaming.
    // NOTE(review): the magic bytes are checked at offset 0; if the server
    // prefixes preview images with its own header this never matches —
    // confirm against the server's binary frame layout.
    if (_hasPrefix(message, _jpegMagic) || _hasPrefix(message, _pngMagic)) {
      // Future: push to preview consumers
      _previewFrames++;
      // _previewFrameController.add(Uint8List.fromList(message));
      print('WebSocket binary preview frame received '
          '(${message.length} bytes) ignored (preview streaming not enabled).');
    } else {
      _ignoredBinaryFrames++;
      print(
          'WebSocket non-JSON binary frame ignored (${message.length} bytes).');
    }
    return null; // Do not proceed to JSON decode
  }

  /// Whether [bytes] begins with exactly the bytes in [prefix].
  static bool _hasPrefix(List<int> bytes, List<int> prefix) {
    if (bytes.length < prefix.length) return false;
    for (var i = 0; i < prefix.length; i++) {
      if (bytes[i] != prefix[i]) return false;
    }
    return true;
  }

  /// Converts [event] to more specific event types and runs the callbacks.
  ///
  /// A failure in one callback is logged and does not stop the others.
  /// Callback lists are snapshotted so a callback may register further
  /// callbacks without breaking the iteration.
  void _dispatchEvent(WebSocketEvent event) {
    try {
      final type = event.eventType;

      if (type == WebSocketEventType.progress) {
        _tryCreateProgressEvent(event);
      } else if (_executionEventTypes.contains(type)) {
        _tryCreateExecutionEvent(event);
      }

      // Handle "execution_interrupted" event
      if (type == WebSocketEventType.executionInterrupted) {
        _executionInterruptedController.add(null);
      }

      // Handle "executing" event
      if (type == WebSocketEventType.executing &&
          event.data['node'] != null) {
        final nodeId = int.tryParse(event.data['node'].toString());
        if (nodeId != null) {
          _executingNodeController.add(nodeId);
        } else {
          print('Invalid node ID: ${event.data['node']}');
        }
      }

      // Trigger event type specific callbacks
      final typedCallbacks = List<WebSocketEventCallback>.of(
        _typedEventCallbacks[type] ?? const <WebSocketEventCallback>[],
      );
      for (final callback in typedCallbacks) {
        try {
          callback(event);
        } catch (e, st) {
          print('WebSocket event callback error: $e\n$st');
        }
      }

      // Handle execution_success event (prompt finished)
      final promptId = event.promptId;
      if (type == WebSocketEventType.executionSuccess && promptId != null) {
        for (final callback
            in List<PromptEventCallback>.of(_promptFinishedCallbacks)) {
          try {
            callback(promptId);
          } catch (e, st) {
            print('WebSocket onPromptFinished callback error: $e\n$st');
          }
        }
      }

      // Handle progress updates
      // NOTE(review): _tryCreateProgressEvent above is also handed
      // _progressEventCallbacks; if that helper invokes them, each progress
      // callback runs twice per event — confirm against
      // WebSocketEventHandler.tryCreateProgressEvent.
      if (type == WebSocketEventType.progress) {
        for (final callback
            in List<ProgressEventCallback>.of(_progressEventCallbacks)) {
          try {
            callback(ProgressEvent.fromJson(event.data));
          } catch (e, st) {
            print('WebSocket progress callback error: $e\n$st');
          }
        }
      }
    } catch (e, st) {
      print('WebSocket event dispatch error: $e\n$st');
    }
  }

  /// Reconnects the WebSocket with a delay
  ///
  /// Increments the retry counter; once it exceeds [_maxRetryAttempts] the
  /// state becomes [ConnectionState.failed] and no further attempt is made.
  /// Does nothing after [dispose].
  Future<void> _reconnect() async {
    if (_disposed) return;

    // Increment retry attempt and notify listeners
    _updateRetryAttempt(_retryAttempt + 1);
    if (_retryAttempt > _maxRetryAttempts) {
      // Max retries reached
      print('Max reconnection attempts reached');
      _updateConnectionState(ConnectionState.failed);
      return;
    }

    print('Attempting to reconnect WebSocket '
        '(attempt $_retryAttempt/$_maxRetryAttempts) '
        'in ${_reconnectDelay.inSeconds} seconds...');
    _updateConnectionState(ConnectionState.connecting);
    await Future<void>.delayed(_reconnectDelay);
    if (_disposed) return; // Disposed while waiting.

    try {
      await connect();
      // connect() only opens the channel; success is confirmed later when
      // the first frame arrives, so do not claim it here.
      print('WebSocket reconnection attempt started');
    } catch (e) {
      print('WebSocket reconnection failed: $e');
      if (_retryAttempt < _maxRetryAttempts) {
        // Fire-and-forget: the next attempt manages its own state.
        _reconnect(); // Retry again if under max attempts
      } else {
        // Max retries reached
        print('Max reconnection attempts reached');
        _updateConnectionState(ConnectionState.failed);
      }
    }
  }

  /// Updates the connection state and notifies listeners
  void _updateConnectionState(ConnectionState state) {
    _connectionState = state;
    if (!_connectionStateController.isClosed) {
      _connectionStateController.add(state);
    }
  }

  /// Updates the retry attempt count and notifies listeners
  void _updateRetryAttempt(int attempt) {
    _retryAttempt = attempt;
    if (!_retryAttemptController.isClosed) {
      _retryAttemptController.add(attempt);
    }
    // Also re-emit the current state to ensure UI updates
    if (!_connectionStateController.isClosed) {
      _connectionStateController.add(_connectionState);
    }
  }

  /// Manually retry connection after failure
  ///
  /// Has no effect unless the current state is [ConnectionState.failed].
  Future<void> retryConnection() async {
    if (_connectionState == ConnectionState.failed) {
      _updateRetryAttempt(0); // Reset retry counter
      await connect();
    }
  }

  /// Attempts to create a ProgressEvent from a WebSocketEvent
  void _tryCreateProgressEvent(WebSocketEvent event) {
    WebSocketEventHandler.tryCreateProgressEvent(
      event,
      _progressEventController,
      _progressEventCallbacks,
    );
  }

  /// Attempts to create an ExecutionEvent from a WebSocketEvent
  void _tryCreateExecutionEvent(WebSocketEvent event) {
    WebSocketEventHandler.tryCreateExecutionEvent(
      event,
      _executionEventController,
    );
  }

  /// Trigger the onPromptStart callbacks
  ///
  /// An exception thrown by a callback propagates to the caller.
  void triggerOnPromptStart(String promptId) {
    for (final callback in _promptStartCallbacks) {
      callback(promptId);
    }
  }

  /// Trigger the onPromptFinished callbacks
  ///
  /// An exception thrown by a callback propagates to the caller.
  void triggerOnPromptFinished(String promptId) {
    for (final callback in _promptFinishedCallbacks) {
      callback(promptId);
    }
  }

  /// Trigger the onProgressChanged callbacks
  ///
  /// [progressData] is parsed with [ProgressEvent.fromJson] once per
  /// callback; parse or callback exceptions propagate to the caller.
  void triggerOnProgressChanged(Map<String, dynamic> progressData) {
    for (final callback in _progressEventCallbacks) {
      callback(ProgressEvent.fromJson(progressData));
    }
  }

  /// Closes the WebSocket connection and cleans up resources
  ///
  /// Safe to call more than once. After disposal no reconnect is attempted
  /// and [connect] throws a [StateError].
  void dispose() {
    if (_disposed) return;
    // Set the flag before touching the socket so nothing triggered by the
    // close can write to the controllers closed below.
    _disposed = true;
    print('Disposing WebSocketManager...');
    _closeChannel();
    _progressController.close();
    _eventController.close();
    _progressEventController.close();
    _executionEventController.close();
    _executingNodeController.close();
    _executionInterruptedController.close();
    _previewFrameController.close();
    _connectionStateController.close();
    _retryAttemptController.close();
  }
}