433 lines
14 KiB
Dart
433 lines
14 KiB
Dart
import 'dart:async';
|
|
import 'dart:convert';
|
|
import 'dart:typed_data';
|
|
|
|
import 'package:web_socket_channel/web_socket_channel.dart';
|
|
|
|
// Use relative imports to avoid duplicate library instances when this package
|
|
// is consumed via a path or symlink; otherwise one library could load twice, making identical classes distinct, incompatible types.
|
|
import 'models/websocket_event.dart';
|
|
import 'models/progress_event.dart';
|
|
import 'models/execution_event.dart';
|
|
import 'types/callback_types.dart';
|
|
import 'utils/websocket_event_handler.dart';
|
|
|
|
/// Lifecycle states of the ComfyUI WebSocket connection.
///
/// NOTE(review): Flutter's widgets library also exports a `ConnectionState`
/// enum; a file importing both needs an import prefix or a `hide` clause.
enum ConnectionState {
  /// The socket is open and has delivered at least one frame.
  connected,

  /// A connection (or reconnection) attempt is in progress.
  connecting,

  /// Not connected (the initial state).
  disconnected,

  /// Automatic reconnection gave up after the maximum number of attempts.
  failed,
}
|
|
|
|
/// Manages the WebSocket connection to a ComfyUI server and fans incoming
/// messages out to typed streams and registered callbacks.
///
/// Call [connect] to open the socket and [dispose] to release it. Dropped
/// connections are retried up to [maxRetryAttempts] times before the state
/// becomes [ConnectionState.failed]; [retryConnection] starts over from
/// there. A disposed manager cannot be reconnected.
class WebSocketManager {
  /// Server address, optionally prefixed with `http://`, `https://`,
  /// `ws://` or `wss://`. An `https`/`wss` prefix selects a secure `wss://`
  /// socket; every other form connects over plain `ws://`.
  final String host;

  /// Client identifier sent as the `clientId` query parameter of `/ws`.
  final String clientId;

  /// Delay before each automatic reconnect attempt.
  static const Duration _reconnectDelay = Duration(seconds: 5);

  static const int _maxRetryAttempts = 3;

  /// Event types that are converted and forwarded to [executionEvents].
  static final Set<WebSocketEventType> _executionEventTypes = {
    WebSocketEventType.executionStart,
    WebSocketEventType.executionSuccess,
    WebSocketEventType.executionCached,
    WebSocketEventType.executed,
    WebSocketEventType.executing,
    WebSocketEventType.executionInterrupted,
    WebSocketEventType.status,
  };

  WebSocketChannel? _wsChannel;

  /// Listener on [_wsChannel]'s stream. It is cancelled before the channel is
  /// replaced or disposed so stale `onError`/`onDone` handlers never fire.
  StreamSubscription<dynamic>? _subscription;

  /// Set by [dispose]; blocks any further connect or reconnect attempts.
  bool _disposed = false;

  /// Whether a reconnect is scheduled but has not yet called [connect].
  /// Prevents `onError` followed by `onDone` from scheduling two reconnects.
  bool _reconnectPending = false;

  // Connection state tracking
  ConnectionState _connectionState = ConnectionState.disconnected;
  int _retryAttempt = 0;
  final StreamController<ConnectionState> _connectionStateController =
      StreamController.broadcast();
  final StreamController<int> _retryAttemptController =
      StreamController.broadcast();

  // Controllers for different event streams
  final StreamController<WebSocketEvent> _eventController =
      StreamController.broadcast();
  final StreamController<Map<String, dynamic>> _progressController =
      StreamController.broadcast();
  final StreamController<ProgressEvent> _progressEventController =
      StreamController.broadcast();
  final StreamController<ExecutionEvent> _executionEventController =
      StreamController.broadcast();
  final StreamController<int> _executingNodeController =
      StreamController.broadcast();
  final StreamController<void> _executionInterruptedController =
      StreamController.broadcast();

  // Optional future binary preview frames (not fed yet; see _decodeTextFrame).
  final StreamController<Uint8List> _previewFrameController =
      StreamController.broadcast();

  // Event callbacks
  final Map<WebSocketEventType, List<WebSocketEventCallback>>
      _typedEventCallbacks = {
    for (var type in WebSocketEventType.values) type: <WebSocketEventCallback>[],
  };
  final List<ProgressEventCallback> _progressEventCallbacks = [];
  final Map<String, List<PromptEventCallback>> _eventCallbacks = {
    'onPromptStart': <PromptEventCallback>[],
    'onPromptFinished': <PromptEventCallback>[],
  };

  // Frame / parse statistics
  int _ignoredBinaryFrames = 0;
  int _malformedFrames = 0;
  int _previewFrames = 0;
  int _textFrames = 0;

  WebSocketManager({required this.host, required this.clientId});

  /// Stream of typed WebSocket events.
  Stream<WebSocketEvent> get events => _eventController.stream;

  /// Stream of connection state changes.
  Stream<ConnectionState> get connectionState =>
      _connectionStateController.stream;

  /// Stream of retry attempt changes.
  Stream<int> get retryAttemptChanges => _retryAttemptController.stream;

  /// Current connection state.
  ConnectionState get currentConnectionState => _connectionState;

  /// Current retry attempt count.
  int get retryAttempt => _retryAttempt;

  /// Maximum number of automatic reconnect attempts.
  static int get maxRetryAttempts => _maxRetryAttempts;

  /// Stream of raw decoded JSON messages (legacy format).
  Stream<Map<String, dynamic>> get progressUpdates =>
      _progressController.stream;

  /// Stream of typed progress events.
  Stream<ProgressEvent> get progressEvents => _progressEventController.stream;

  /// Stream of typed execution events.
  Stream<ExecutionEvent> get executionEvents =>
      _executionEventController.stream;

  /// Stream of the currently executing node id (numeric ids only).
  Stream<int> get executingNodeStream => _executingNodeController.stream;

  /// Stream that emits whenever execution is interrupted.
  Stream<void> get executionInterruptedStream =>
      _executionInterruptedController.stream;

  /// Stream of (future) binary preview frames. Nothing is emitted yet.
  Stream<Uint8List> get previewFrames => _previewFrameController.stream;

  /// Counters describing the frames received so far.
  Map<String, int> get stats => {
        'ignoredBinaryFrames': _ignoredBinaryFrames,
        'malformedFrames': _malformedFrames,
        'previewFrames': _previewFrames,
        'textFrames': _textFrames,
        'retryAttempts': _retryAttempt,
      };

  /// Registers [callback] for events of the given [type].
  void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
    _typedEventCallbacks[type]!.add(callback);
  }

  /// Registers [callback] for progress updates.
  void onProgressChanged(ProgressEventCallback callback) {
    _progressEventCallbacks.add(callback);
  }

  /// Registers [callback] for when a prompt starts processing.
  void onPromptStart(PromptEventCallback callback) {
    _eventCallbacks['onPromptStart']!.add(callback);
  }

  /// Registers [callback] for when a prompt finishes processing.
  void onPromptFinished(PromptEventCallback callback) {
    _eventCallbacks['onPromptFinished']!.add(callback);
  }

  /// Opens the WebSocket used for progress and execution updates.
  ///
  /// Any previously opened channel is closed first. The returned future
  /// completes once the channel object is created. The state only becomes
  /// [ConnectionState.connected] when the first frame arrives. Connection
  /// failures surface through the stream's error/done handlers, which
  /// schedule an automatic reconnect.
  ///
  /// Throws a [StateError] if the manager has been disposed.
  Future<void> connect() async {
    if (_disposed) {
      throw StateError('WebSocketManager has been disposed');
    }

    // A fresh connection supersedes any reconnect still waiting on its delay.
    _reconnectPending = false;
    _updateConnectionState(ConnectionState.connecting);

    // Release the previous channel so repeated connect() calls neither leak
    // sockets nor leave stale listeners that could trigger reconnects.
    _closeCurrentChannel();

    final wsUri = _buildWebSocketUri();
    final channel = WebSocketChannel.connect(wsUri);
    _wsChannel = channel;

    print('WebSocket connecting to $wsUri');

    _subscription = channel.stream.listen(
      _handleMessage,
      onError: (Object error) {
        print('WebSocket error: $error');
        _scheduleReconnect();
      },
      onDone: () {
        print('WebSocket connection closed');
        _scheduleReconnect();
      },
    );
  }

  /// Builds the `/ws` endpoint URI from [host] and [clientId].
  ///
  /// `https://` and `wss://` prefixes map to `wss://`. `http://`, `ws://` or
  /// no prefix map to `ws://`. Trailing slashes on [host] are dropped so the
  /// path never contains `//ws`, and [clientId] is query-encoded.
  Uri _buildWebSocketUri() {
    var authorityAndPath = host.trim();
    var scheme = 'ws';
    final prefix = RegExp(r'^(https?|wss?)://', caseSensitive: false)
        .firstMatch(authorityAndPath);
    if (prefix != null) {
      final given = (prefix.group(1) ?? '').toLowerCase();
      if (given == 'https' || given == 'wss') scheme = 'wss';
      authorityAndPath = authorityAndPath.substring(prefix.end);
    }
    authorityAndPath = authorityAndPath.replaceFirst(RegExp(r'/+$'), '');
    final encodedClientId = Uri.encodeQueryComponent(clientId);
    return Uri.parse(
        '$scheme://$authorityAndPath/ws?clientId=$encodedClientId');
  }

  /// Cancels the current listener and closes the current channel, if any.
  ///
  /// The listener is cancelled first so the close does not fire `onDone`
  /// (which would otherwise schedule a reconnect).
  void _closeCurrentChannel() {
    final subscription = _subscription;
    final channel = _wsChannel;
    _subscription = null;
    _wsChannel = null;
    if (subscription != null) {
      unawaited(subscription.cancel());
    }
    if (channel != null) {
      // Best effort: a failing close of an already-dead socket is irrelevant.
      unawaited(channel.sink.close().catchError((Object _) => null));
    }
  }

  /// Handles one incoming frame from the channel.
  void _handleMessage(dynamic message) {
    if (_disposed) return;

    // The first frame on a connection proves the handshake succeeded.
    if (_connectionState != ConnectionState.connected) {
      _updateConnectionState(ConnectionState.connected);
      _updateRetryAttempt(0); // Reset retry counter on successful connection
    }

    final textFrame = _decodeTextFrame(message);
    if (textFrame == null) return;

    final jsonData = _parseJsonObject(textFrame);
    if (jsonData == null) return;

    print('WebSocket message (event): $jsonData');

    final WebSocketEvent event;
    try {
      event = WebSocketEvent.fromJson(jsonData);
    } catch (e) {
      _malformedFrames++;
      print('WebSocket event parse failed: $e');
      return;
    }

    // Typed stream first, then the legacy raw-JSON stream.
    _eventController.add(event);
    _progressController.add(jsonData);

    try {
      _dispatchEvent(event);
    } catch (e, st) {
      print('WebSocket event dispatch error: $e\n$st');
    }
  }

  /// Returns the JSON text carried by [message], or `null` if the frame is
  /// not JSON. Ignored frames are counted in [stats] and logged.
  ///
  /// Binary frames are accepted when their UTF-8 text starts with `{` or `[`.
  String? _decodeTextFrame(dynamic message) {
    if (message is String) {
      _textFrames++;
      return message;
    }
    if (message is! List<int>) {
      _ignoredBinaryFrames++;
      print(
          'WebSocket unsupported frame type ignored (${message.runtimeType}).');
      return null;
    }

    final String decoded;
    try {
      decoded = utf8.decode(message, allowMalformed: true);
    } catch (_) {
      _ignoredBinaryFrames++;
      print(
          'WebSocket binary frame ignored (UTF8 decode failed, ${message.length} bytes).');
      return null;
    }

    // Heuristic: treat as JSON if it starts with '{' or '['
    final trimmed = decoded.trimLeft();
    if (trimmed.startsWith('{') || trimmed.startsWith('[')) {
      _textFrames++;
      return decoded;
    }

    if (_looksLikeImage(message)) {
      // Future: push to preview consumers via _previewFrameController.
      _previewFrames++;
      print(
          'WebSocket binary preview frame received (${message.length} bytes) ignored (preview streaming not enabled).');
    } else {
      _ignoredBinaryFrames++;
      print(
          'WebSocket non-JSON binary frame ignored (${message.length} bytes).');
    }
    return null;
  }

  /// Whether [bytes] start with a JPEG (`FF D8`) or PNG signature.
  ///
  /// NOTE(review): ComfyUI appears to prefix binary preview frames with an
  /// 8-byte header (event type + image format). If so, the image magic
  /// starts at offset 8 and such frames are counted as ignored here. Confirm
  /// against the server before relying on the `previewFrames` stat.
  static bool _looksLikeImage(List<int> bytes) {
    const pngSignature = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
    if (bytes.length >= 2 && bytes[0] == 0xFF && bytes[1] == 0xD8) {
      return true;
    }
    if (bytes.length < pngSignature.length) return false;
    for (var i = 0; i < pngSignature.length; i++) {
      if (bytes[i] != pngSignature[i]) return false;
    }
    return true;
  }

  /// Decodes [textFrame] and returns it if its root is a JSON object.
  /// Otherwise counts it as malformed, logs it and returns `null`.
  Map<String, dynamic>? _parseJsonObject(String textFrame) {
    final Object? decoded;
    try {
      decoded = jsonDecode(textFrame);
    } catch (e) {
      _malformedFrames++;
      print('WebSocket malformed JSON frame ignored: $e');
      return null;
    }
    if (decoded is! Map<String, dynamic>) {
      _malformedFrames++;
      print(
          'WebSocket JSON root not object (type: ${decoded.runtimeType}) ignored.');
      return null;
    }
    return decoded;
  }

  /// Routes a parsed [event] to the derived streams and registered callbacks.
  ///
  /// Each callback runs in isolation: an error in one is logged and does not
  /// prevent the others from running.
  void _dispatchEvent(WebSocketEvent event) {
    final type = event.eventType;

    // Convert to more specific event types if possible
    if (type == WebSocketEventType.progress) {
      _tryCreateProgressEvent(event);
    } else if (_executionEventTypes.contains(type)) {
      _tryCreateExecutionEvent(event);
    }

    if (type == WebSocketEventType.executionInterrupted) {
      _executionInterruptedController.add(null);
    }

    if (type == WebSocketEventType.executing) {
      final node = event.data['node'];
      if (node != null) {
        final nodeId = int.tryParse(node.toString());
        if (nodeId != null) {
          _executingNodeController.add(nodeId);
        } else {
          print('Invalid node ID: $node');
        }
      }
    }

    // Iterate over snapshots so a callback may register further callbacks
    // without causing a ConcurrentModificationError.
    for (final callback in List.of(_typedEventCallbacks[type]!)) {
      _invokeGuarded('event', () => callback(event));
    }

    // execution_success marks the prompt as finished.
    if (type == WebSocketEventType.executionSuccess) {
      final promptId = event.promptId;
      if (promptId != null) {
        for (final callback in List.of(_eventCallbacks['onPromptFinished']!)) {
          _invokeGuarded('onPromptFinished', () => callback(promptId));
        }
      }
    }

    if (type == WebSocketEventType.progress) {
      // NOTE(review): _progressEventCallbacks is also handed to
      // WebSocketEventHandler.tryCreateProgressEvent above. If that helper
      // invokes them too, progress callbacks fire twice per event. Confirm.
      for (final callback in List.of(_progressEventCallbacks)) {
        _invokeGuarded(
            'progress', () => callback(ProgressEvent.fromJson(event.data)));
      }
    }
  }

  /// Runs [action], logging rather than propagating any error so that one
  /// faulty callback cannot stop the remaining callbacks.
  void _invokeGuarded(String label, void Function() action) {
    try {
      action();
    } catch (e, st) {
      print('WebSocket $label callback error: $e\n$st');
    }
  }

  /// Schedules at most one reconnect per lost connection.
  ///
  /// A channel typically reports a failure via `onError` followed by
  /// `onDone`; only the first of these starts a reconnect.
  void _scheduleReconnect() {
    if (_disposed || _reconnectPending) return;
    _reconnectPending = true;
    unawaited(_reconnect());
  }

  /// Performs one delayed reconnect attempt. Once [maxRetryAttempts] is
  /// exceeded, it marks the connection as [ConnectionState.failed] instead.
  ///
  /// The retry counter is reset to 0 when the first frame arrives on a new
  /// connection.
  Future<void> _reconnect() async {
    if (_disposed) return;

    _updateRetryAttempt(_retryAttempt + 1);

    if (_retryAttempt > _maxRetryAttempts) {
      _markFailed();
      return;
    }

    print(
        'Attempting to reconnect WebSocket (attempt $_retryAttempt/$_maxRetryAttempts) in ${_reconnectDelay.inSeconds} seconds...');
    _updateConnectionState(ConnectionState.connecting);

    await Future<void>.delayed(_reconnectDelay);

    // Abort if disposed during the delay, or if a manual connect() already
    // replaced the connection (connect() clears the pending flag).
    if (_disposed || !_reconnectPending) return;

    try {
      await connect();
      // connect() returns before the handshake completes; success is only
      // confirmed when the first frame arrives.
      print('WebSocket reconnect attempt started');
    } catch (e) {
      print('WebSocket reconnection failed: $e');
      if (_retryAttempt < _maxRetryAttempts) {
        _scheduleReconnect();
      } else {
        _markFailed();
      }
    }
  }

  /// Gives up on automatic reconnection.
  void _markFailed() {
    _reconnectPending = false;
    print('Max reconnection attempts reached');
    _updateConnectionState(ConnectionState.failed);
  }

  /// Updates the connection state and notifies listeners (unless disposed).
  void _updateConnectionState(ConnectionState state) {
    _connectionState = state;
    if (!_connectionStateController.isClosed) {
      _connectionStateController.add(state);
    }
  }

  /// Updates the retry attempt count and notifies listeners (unless
  /// disposed). It also re-emits the current connection state so UIs keyed
  /// on that stream refresh.
  void _updateRetryAttempt(int attempt) {
    _retryAttempt = attempt;
    if (!_retryAttemptController.isClosed) {
      _retryAttemptController.add(attempt);
    }
    if (!_connectionStateController.isClosed) {
      _connectionStateController.add(_connectionState);
    }
  }

  /// Manually retries the connection. Only acts when the current state is
  /// [ConnectionState.failed]; the retry counter is reset first.
  Future<void> retryConnection() async {
    if (_connectionState == ConnectionState.failed) {
      _updateRetryAttempt(0);
      await connect();
    }
  }

  /// Attempts to create a ProgressEvent from a WebSocketEvent.
  void _tryCreateProgressEvent(WebSocketEvent event) {
    WebSocketEventHandler.tryCreateProgressEvent(
      event,
      _progressEventController,
      _progressEventCallbacks,
    );
  }

  /// Attempts to create an ExecutionEvent from a WebSocketEvent.
  void _tryCreateExecutionEvent(WebSocketEvent event) {
    WebSocketEventHandler.tryCreateExecutionEvent(
      event,
      _executionEventController,
    );
  }

  /// Invokes every onPromptStart callback with [promptId]. Errors are logged
  /// per callback.
  void triggerOnPromptStart(String promptId) {
    for (final callback in List.of(_eventCallbacks['onPromptStart']!)) {
      _invokeGuarded('onPromptStart', () => callback(promptId));
    }
  }

  /// Invokes every onPromptFinished callback with [promptId]. Errors are
  /// logged per callback.
  void triggerOnPromptFinished(String promptId) {
    for (final callback in List.of(_eventCallbacks['onPromptFinished']!)) {
      _invokeGuarded('onPromptFinished', () => callback(promptId));
    }
  }

  /// Invokes every onProgressChanged callback with a [ProgressEvent] built
  /// from [progressData]. Errors (including parse errors) are logged per
  /// callback.
  void triggerOnProgressChanged(Map<String, dynamic> progressData) {
    for (final callback in List.of(_progressEventCallbacks)) {
      _invokeGuarded(
          'progress', () => callback(ProgressEvent.fromJson(progressData)));
    }
  }

  /// Closes the WebSocket connection and all streams exposed by this manager.
  ///
  /// Pending reconnects are abandoned and later [connect] calls throw a
  /// [StateError]. Calling [dispose] more than once has no further effect.
  void dispose() {
    if (_disposed) return;
    print('Disposing WebSocketManager...');
    _disposed = true;
    _reconnectPending = false;

    // Cancel the listener before closing the socket so the resulting onDone
    // cannot schedule a reconnect against closed controllers.
    _closeCurrentChannel();
    _connectionState = ConnectionState.disconnected;

    _progressController.close();
    _eventController.close();
    _progressEventController.close();
    _executionEventController.close();
    _executingNodeController.close();
    _executionInterruptedController.close();
    _previewFrameController.close();
    _connectionStateController.close();
    _retryAttemptController.close();
  }
}
|