import 'dart:async';
import 'dart:convert';
import 'dart:typed_data';

import 'package:web_socket_channel/web_socket_channel.dart';

// Use relative imports to avoid duplicate library instances when this package
// is consumed via a path or symlink (prevents distinct "same" types).
import 'models/websocket_event.dart';
import 'models/progress_event.dart';
import 'models/execution_event.dart';
import 'types/callback_types.dart';
import 'utils/websocket_event_handler.dart';

/// Connection states for the ComfyUI WebSocket.
enum ConnectionState { connected, connecting, disconnected, failed }

/// Manages one ComfyUI `/ws` WebSocket connection.
///
/// Incoming frames are decoded into [WebSocketEvent]s and fanned out to
/// several broadcast streams and registered callbacks. When the socket errors
/// or closes, the manager retries up to [maxRetryAttempts] times, waiting
/// 5 seconds before each attempt. Call [dispose] when finished; the manager
/// cannot be reused afterwards.
class WebSocketManager {
  /// Server address, optionally prefixed with `http://`, `https://`,
  /// `ws://` or `wss://`.
  final String host;

  /// Sent to the server as the `clientId` query parameter.
  final String clientId;

  WebSocketChannel? _wsChannel;

  /// Subscription to [_wsChannel]'s stream. It is cancelled before a new
  /// connection is opened and on [dispose], so a stale socket can never
  /// trigger a reconnect or deliver duplicate events.
  StreamSubscription<dynamic>? _subscription;

  // Connection state tracking
  ConnectionState _connectionState = ConnectionState.disconnected;
  int _retryAttempt = 0;
  static const int _maxRetryAttempts = 3;

  /// Delay before each automatic reconnect attempt.
  static const Duration _reconnectDelay = Duration(seconds: 5);

  /// True while a reconnect is pending. A failing socket usually reports
  /// both `onError` and `onDone`, and only one reconnect must be scheduled.
  bool _isReconnecting = false;

  /// Set by [dispose]; suppresses further emissions and reconnects.
  bool _isDisposed = false;

  final StreamController<ConnectionState> _connectionStateController =
      StreamController<ConnectionState>.broadcast();
  final StreamController<int> _retryAttemptController =
      StreamController<int>.broadcast();

  // Controllers for different event streams
  final StreamController<WebSocketEvent> _eventController =
      StreamController<WebSocketEvent>.broadcast();
  final StreamController<Map<String, dynamic>> _progressController =
      StreamController<Map<String, dynamic>>.broadcast();
  final StreamController<ProgressEvent> _progressEventController =
      StreamController<ProgressEvent>.broadcast();
  final StreamController<ExecutionEvent> _executionEventController =
      StreamController<ExecutionEvent>.broadcast();
  final StreamController<int> _executingNodeController =
      StreamController<int>.broadcast();
  final StreamController<void> _executionInterruptedController =
      StreamController<void>.broadcast();

  // Optional future binary preview frames
  final StreamController<Uint8List> _previewFrameController =
      StreamController<Uint8List>.broadcast();

  // Event callbacks
  final Map<WebSocketEventType, List<WebSocketEventCallback>>
      _typedEventCallbacks = {
    for (final type in WebSocketEventType.values) type: [],
  };
  final List<ProgressEventCallback> _progressEventCallbacks = [];
  final Map<String, List<PromptEventCallback>> _eventCallbacks = {
    'onPromptStart': [],
    'onPromptFinished': [],
  };

  /// Event types handed to [WebSocketEventHandler.tryCreateExecutionEvent].
  ///
  /// Built once instead of allocating a list on every frame.
  static final Set<WebSocketEventType> _executionEventTypes = {
    WebSocketEventType.executionStart,
    WebSocketEventType.executionSuccess,
    WebSocketEventType.executionCached,
    WebSocketEventType.executed,
    WebSocketEventType.executing,
    WebSocketEventType.executionInterrupted,
    WebSocketEventType.status,
  };

  // Frame / parse statistics
  int _ignoredBinaryFrames = 0;
  int _malformedFrames = 0;
  int _previewFrames = 0;
  int _textFrames = 0;

  WebSocketManager({required this.host, required this.clientId});

  /// Stream of typed WebSocket events.
  Stream<WebSocketEvent> get events => _eventController.stream;

  /// Stream of connection state changes.
  Stream<ConnectionState> get connectionState =>
      _connectionStateController.stream;

  /// Stream of retry attempt changes.
  Stream<int> get retryAttemptChanges => _retryAttemptController.stream;

  /// Current connection state.
  ConnectionState get currentConnectionState => _connectionState;

  /// Current retry attempt count.
  int get retryAttempt => _retryAttempt;

  /// Maximum number of automatic retry attempts.
  static int get maxRetryAttempts => _maxRetryAttempts;

  /// Stream of progress updates (legacy format: the raw decoded JSON object
  /// of every parsed event).
  Stream<Map<String, dynamic>> get progressUpdates =>
      _progressController.stream;

  /// Stream of typed progress events.
  Stream<ProgressEvent> get progressEvents => _progressEventController.stream;

  /// Stream of typed execution events.
  Stream<ExecutionEvent> get executionEvents =>
      _executionEventController.stream;

  /// Stream of the currently executing node id.
  Stream<int> get executingNodeStream => _executingNodeController.stream;

  /// Stream that emits once per `execution_interrupted` event.
  Stream<void> get executionInterruptedStream =>
      _executionInterruptedController.stream;

  /// Stream of (future) binary preview frames; nothing is emitted yet.
  Stream<Uint8List> get previewFrames => _previewFrameController.stream;

  /// Counters for received frames, plus the current retry attempt.
  Map<String, int> get stats => {
        'ignoredBinaryFrames': _ignoredBinaryFrames,
        'malformedFrames': _malformedFrames,
        'previewFrames': _previewFrames,
        'textFrames': _textFrames,
        'retryAttempts': _retryAttempt,
      };

  /// Registers [callback] for events of the given [type].
  void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
    _typedEventCallbacks[type]!.add(callback);
  }

  /// Registers [callback] for progress updates.
  void onProgressChanged(ProgressEventCallback callback) {
    _progressEventCallbacks.add(callback);
  }

  /// Registers [callback] for when a prompt starts processing.
  void onPromptStart(PromptEventCallback callback) {
    _eventCallbacks['onPromptStart']!.add(callback);
  }

  /// Registers [callback] for when a prompt finishes processing
  /// (`execution_success` with a prompt id).
  void onPromptFinished(PromptEventCallback callback) {
    _eventCallbacks['onPromptFinished']!.add(callback);
  }

  /// Opens the WebSocket and starts listening for events.
  ///
  /// Any previous connection is torn down first, so calling this while
  /// connected replaces the socket instead of leaking it. The returned future
  /// completes once the channel is created, not when the handshake finishes.
  /// The state becomes [ConnectionState.connected] on the first received
  /// frame.
  ///
  /// Throws a [StateError] if the manager has been disposed.
  Future<void> connect() async {
    if (_isDisposed) {
      throw StateError('WebSocketManager has been disposed');
    }
    _updateConnectionState(ConnectionState.connecting);

    // Cancel first so the old socket's onError/onDone cannot start a second
    // reconnect loop or keep dispatching events.
    await _closeChannel();
    if (_isDisposed) return; // dispose() ran while we were awaiting.

    final wsUrl = _buildWsUrl();
    final channel = WebSocketChannel.connect(Uri.parse(wsUrl));
    _wsChannel = channel;
    print('WebSocket connecting to $wsUrl');

    _subscription = channel.stream.listen(
      _handleMessage,
      onError: (Object error) {
        print('WebSocket error: $error');
        _reconnect();
      },
      onDone: () {
        print('WebSocket connection closed');
        _reconnect();
      },
    );
  }

  /// Builds the `/ws` URL from [host] and [clientId].
  ///
  /// `https://` and `wss://` hosts map to `wss://`; everything else to
  /// `ws://`. Trailing slashes are dropped to avoid a `//ws` path, and the
  /// client id is URL-encoded.
  String _buildWsUrl() {
    final isSecure =
        RegExp(r'^(?:https|wss)://', caseSensitive: false).hasMatch(host);
    final authority = host
        .replaceFirst(RegExp(r'^(?:https?|wss?)://', caseSensitive: false), '')
        .replaceFirst(RegExp(r'/+$'), '');
    final scheme = isSecure ? 'wss' : 'ws';
    return '$scheme://$authority/ws'
        '?clientId=${Uri.encodeQueryComponent(clientId)}';
  }

  /// Cancels the current subscription and closes the current channel, if any.
  Future<void> _closeChannel() async {
    final subscription = _subscription;
    final channel = _wsChannel;
    _subscription = null;
    _wsChannel = null;
    await subscription?.cancel();
    channel?.sink.close();
  }

  /// Handles one raw frame from the socket.
  void _handleMessage(dynamic message) {
    // Any frame proves the socket is open.
    if (_connectionState != ConnectionState.connected) {
      _updateConnectionState(ConnectionState.connected);
      _updateRetryAttempt(0); // Reset retry counter on successful connection
    }

    final textFrame = _extractTextFrame(message);
    if (textFrame == null) return;

    final jsonData = _decodeJsonObject(textFrame);
    if (jsonData == null) return;

    print('WebSocket message (event): $jsonData');

    WebSocketEvent event;
    try {
      event = WebSocketEvent.fromJson(jsonData);
    } catch (e) {
      _malformedFrames++;
      print('WebSocket event parse failed: $e');
      return;
    }

    // Add to the typed event stream
    _eventController.add(event);
    // Also add to the legacy progress stream
    _progressController.add(jsonData);

    _dispatchEvent(event);
  }

  /// Returns the frame's text if it should be parsed as JSON, else `null`.
  ///
  /// String frames are always text. Binary frames are UTF-8 decoded and
  /// treated as JSON only if they start with `{` or `[` after leading
  /// whitespace. Every rejected frame is counted and logged here.
  String? _extractTextFrame(dynamic message) {
    if (message is String) {
      _textFrames++;
      return message;
    }
    if (message is! List<int>) {
      _ignoredBinaryFrames++;
      print(
          'WebSocket unsupported frame type ignored (${message.runtimeType}).');
      return null;
    }

    String decoded;
    try {
      decoded = utf8.decode(message, allowMalformed: true);
    } catch (_) {
      // Could not decode as UTF8
      _ignoredBinaryFrames++;
      print(
          'WebSocket binary frame ignored (UTF8 decode failed, ${message.length} bytes).');
      return null;
    }

    final trimmed = decoded.trimLeft();
    if (trimmed.startsWith('{') || trimmed.startsWith('[')) {
      _textFrames++;
      return decoded;
    }

    // NOTE(review): ComfyUI is believed to prefix binary preview frames with
    // an 8-byte header (event type + image format). If so, these offset-0
    // magic-byte checks never match. Confirm before enabling previews.
    if (_looksLikeImage(message)) {
      // Future: push to preview consumers
      _previewFrames++;
      // _previewFrameController.add(Uint8List.fromList(message));
      print(
          'WebSocket binary preview frame received (${message.length} bytes) ignored (preview streaming not enabled).');
    } else {
      _ignoredBinaryFrames++;
      print(
          'WebSocket non-JSON binary frame ignored (${message.length} bytes).');
    }
    return null;
  }

  /// Whether [bytes] start with a JPEG (`FF D8`) or PNG signature.
  static bool _looksLikeImage(List<int> bytes) {
    const pngSignature = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
    if (bytes.length >= 2 && bytes[0] == 0xFF && bytes[1] == 0xD8) {
      return true;
    }
    if (bytes.length < pngSignature.length) return false;
    for (var i = 0; i < pngSignature.length; i++) {
      if (bytes[i] != pngSignature[i]) return false;
    }
    return true;
  }

  /// Decodes [textFrame] as a JSON object.
  ///
  /// Returns `null` (after counting it as malformed) when the text is not
  /// valid JSON or its root is not an object.
  Map<String, dynamic>? _decodeJsonObject(String textFrame) {
    dynamic jsonData;
    try {
      jsonData = jsonDecode(textFrame);
    } catch (e) {
      _malformedFrames++;
      print('WebSocket malformed JSON frame ignored: $e');
      return null;
    }
    if (jsonData is! Map<String, dynamic>) {
      _malformedFrames++;
      print(
          'WebSocket JSON root not object (type: ${jsonData.runtimeType}) ignored.');
      return null;
    }
    return jsonData;
  }

  /// Routes [event] to the specialised streams and registered callbacks.
  ///
  /// Each callback is isolated with its own try/catch. Any other failure is
  /// logged, so one bad event never tears down the socket subscription.
  void _dispatchEvent(WebSocketEvent event) {
    try {
      final type = event.eventType;

      // Convert to more specific event types if possible
      if (type == WebSocketEventType.progress) {
        _tryCreateProgressEvent(event);
      } else if (_executionEventTypes.contains(type)) {
        _tryCreateExecutionEvent(event);
      }

      if (type == WebSocketEventType.executionInterrupted) {
        _executionInterruptedController.add(null);
      }

      if (type == WebSocketEventType.executing &&
          event.data['node'] != null) {
        final nodeId = int.tryParse(event.data['node'].toString());
        if (nodeId != null) {
          _executingNodeController.add(nodeId);
        } else {
          print('Invalid node ID: ${event.data['node']}');
        }
      }

      // Iterate over a snapshot. A callback that registers another callback
      // would otherwise abort dispatch with a ConcurrentModificationError.
      for (final callback in List.of(_typedEventCallbacks[type]!)) {
        try {
          callback(event);
        } catch (e, st) {
          print('WebSocket event callback error: $e\n$st');
        }
      }

      // execution_success marks the prompt as finished.
      final promptId = event.promptId;
      if (type == WebSocketEventType.executionSuccess && promptId != null) {
        for (final callback in List.of(_eventCallbacks['onPromptFinished']!)) {
          try {
            callback(promptId);
          } catch (e, st) {
            print('WebSocket onPromptFinished callback error: $e\n$st');
          }
        }
      }

      // NOTE(review): _progressEventCallbacks is also passed to
      // WebSocketEventHandler.tryCreateProgressEvent above. If that helper
      // invokes the callbacks too, each progress frame notifies them twice.
      // Confirm.
      if (type == WebSocketEventType.progress) {
        for (final callback in List.of(_progressEventCallbacks)) {
          try {
            callback(ProgressEvent.fromJson(event.data));
          } catch (e, st) {
            print('WebSocket progress callback error: $e\n$st');
          }
        }
      }
    } catch (e, st) {
      print('WebSocket event dispatch error: $e\n$st');
    }
  }

  /// Schedules a reconnect after [_reconnectDelay], up to [_maxRetryAttempts].
  ///
  /// Calls are ignored while a reconnect is already pending or after
  /// [dispose]. Once the attempts are exhausted, the state becomes
  /// [ConnectionState.failed].
  Future<void> _reconnect() async {
    if (_isDisposed || _isReconnecting) return;

    // Increment retry attempt and notify listeners
    _updateRetryAttempt(_retryAttempt + 1);
    if (_retryAttempt > _maxRetryAttempts) {
      print('Max reconnection attempts reached');
      _updateConnectionState(ConnectionState.failed);
      return;
    }

    print(
        'Attempting to reconnect WebSocket (attempt $_retryAttempt/$_maxRetryAttempts) in ${_reconnectDelay.inSeconds} seconds...');
    _updateConnectionState(ConnectionState.connecting);

    _isReconnecting = true;
    Object? failure;
    try {
      await Future<void>.delayed(_reconnectDelay);
      if (_isDisposed) return;
      await connect();
      // connect() only opens the channel. The handshake outcome arrives
      // later through onError/onDone or the first frame.
      print('WebSocket reconnection initiated');
    } catch (e) {
      failure = e;
    } finally {
      _isReconnecting = false;
    }

    if (failure == null) return;
    print('WebSocket reconnection failed: $failure');
    if (_retryAttempt < _maxRetryAttempts) {
      _reconnect(); // Retry again if under max attempts
    } else {
      // Max retries reached
      print('Max reconnection attempts reached');
      _updateConnectionState(ConnectionState.failed);
    }
  }

  /// Updates the connection state and notifies listeners (unless disposed).
  void _updateConnectionState(ConnectionState state) {
    _connectionState = state;
    if (_isDisposed) return;
    _connectionStateController.add(state);
  }

  /// Updates the retry attempt count and notifies listeners (unless disposed).
  void _updateRetryAttempt(int attempt) {
    _retryAttempt = attempt;
    if (_isDisposed) return;
    _retryAttemptController.add(attempt);
    // Also re-emit the current state to ensure UI updates
    _connectionStateController.add(_connectionState);
  }

  /// Manually retries the connection.
  ///
  /// Does nothing unless the state is [ConnectionState.failed]. Resets the
  /// retry counter before connecting.
  Future<void> retryConnection() async {
    if (_connectionState == ConnectionState.failed) {
      _updateRetryAttempt(0); // Reset retry counter
      await connect();
    }
  }

  /// Attempts to create a ProgressEvent from a WebSocketEvent.
  void _tryCreateProgressEvent(WebSocketEvent event) {
    WebSocketEventHandler.tryCreateProgressEvent(
      event,
      _progressEventController,
      _progressEventCallbacks,
    );
  }

  /// Attempts to create an ExecutionEvent from a WebSocketEvent.
  void _tryCreateExecutionEvent(WebSocketEvent event) {
    WebSocketEventHandler.tryCreateExecutionEvent(
      event,
      _executionEventController,
    );
  }

  /// Invokes every onPromptStart callback with [promptId].
  ///
  /// Callback exceptions propagate to the caller.
  void triggerOnPromptStart(String promptId) {
    for (final callback in _eventCallbacks['onPromptStart']!) {
      callback(promptId);
    }
  }

  /// Invokes every onPromptFinished callback with [promptId].
  ///
  /// Callback exceptions propagate to the caller.
  void triggerOnPromptFinished(String promptId) {
    for (final callback in _eventCallbacks['onPromptFinished']!) {
      callback(promptId);
    }
  }

  /// Invokes every onProgressChanged callback with a [ProgressEvent] parsed
  /// from [progressData].
  void triggerOnProgressChanged(Map<String, dynamic> progressData) {
    for (final callback in _progressEventCallbacks) {
      callback(ProgressEvent.fromJson(progressData));
    }
  }

  /// Closes the WebSocket connection and releases all stream controllers.
  ///
  /// The subscription is cancelled before the sink is closed, so the close
  /// does not trigger an automatic reconnect. Calling this more than once
  /// is a no-op.
  void dispose() {
    if (_isDisposed) return;
    _isDisposed = true;
    print('Disposing WebSocketManager...');
    _subscription?.cancel();
    _subscription = null;
    _wsChannel?.sink.close();
    _wsChannel = null;
    _progressController.close();
    _eventController.close();
    _progressEventController.close();
    _executionEventController.close();
    _executingNodeController.close();
    _executionInterruptedController.close();
    _previewFrameController.close();
    _connectionStateController.close();
    _retryAttemptController.close();
  }
}