fixes to websocket type casting

Signed-off-by: Menno van Leeuwen <menno@vleeuwen.me>
This commit is contained in:
2025-08-11 15:43:17 +02:00
parent 64687545ee
commit 34a930f7c8
3 changed files with 289 additions and 89 deletions

View File

@@ -1,4 +1,5 @@
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
import '../models/websocket_event.dart';
import '../models/progress_event.dart';
/// Callback function type for prompt events
typedef PromptEventCallback = void Function(String promptId);

View File

@@ -1,15 +1,18 @@
import 'dart:async';
import 'dart:convert';
import 'dart:typed_data';
import 'package:web_socket_channel/web_socket_channel.dart';
// Use relative imports to avoid duplicate library instances when this package
// is consumed via a path or symlink (prevents distinct "same" types).
import 'models/websocket_event.dart';
import 'models/progress_event.dart';
import 'models/execution_event.dart';
import 'types/callback_types.dart';
import 'utils/websocket_event_handler.dart';
/// Enum representing the connection state of the WebSocket
/// Connection states for the ComfyUI WebSocket
enum ConnectionState { connected, connecting, disconnected, failed }
class WebSocketManager {
@@ -41,17 +44,27 @@ class WebSocketManager {
final StreamController<void> _executionInterruptedController =
StreamController.broadcast();
// Optional future binary preview frames
final StreamController<Uint8List> _previewFrameController =
StreamController.broadcast();
// Event callbacks
final Map<WebSocketEventType, List<WebSocketEventCallback>>
_typedEventCallbacks = {
for (var type in WebSocketEventType.values) type: [],
for (var type in WebSocketEventType.values) type: <WebSocketEventCallback>[],
};
final List<ProgressEventCallback> _progressEventCallbacks = [];
final Map<String, List<PromptEventCallback>> _eventCallbacks = {
'onPromptStart': [],
'onPromptFinished': [],
'onPromptStart': <PromptEventCallback>[],
'onPromptFinished': <PromptEventCallback>[],
};
// Frame / parse statistics
int _ignoredBinaryFrames = 0;
int _malformedFrames = 0;
int _previewFrames = 0;
int _textFrames = 0;
WebSocketManager({required this.host, required this.clientId});
/// Stream of typed WebSocket events
@@ -91,6 +104,18 @@ class WebSocketManager {
Stream<void> get executionInterruptedStream =>
_executionInterruptedController.stream;
/// Stream of (future) binary preview frames
Stream<Uint8List> get previewFrames => _previewFrameController.stream;
/// Stats about received frames
Map<String, int> get stats => {
'ignoredBinaryFrames': _ignoredBinaryFrames,
'malformedFrames': _malformedFrames,
'previewFrames': _previewFrames,
'textFrames': _textFrames,
'retryAttempts': _retryAttempt,
};
/// Register a callback for specific WebSocket event types
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
_typedEventCallbacks[type]!.add(callback);
@@ -122,18 +147,100 @@ class WebSocketManager {
print('WebSocket connecting to $wsUrl');
_wsChannel!.stream.listen((message) {
_wsChannel!.stream.listen((dynamic message) {
// Successfully connected
if (_connectionState != ConnectionState.connected) {
_updateConnectionState(ConnectionState.connected);
_updateRetryAttempt(0); // Reset retry counter on successful connection
}
final jsonData = jsonDecode(message);
print('WebSocket message: $jsonData');
// Determine frame type
String? textFrame;
if (message is String) {
textFrame = message;
_textFrames++;
} else if (message is List<int>) {
// Attempt to interpret as UTF8 JSON text
try {
final decoded = utf8.decode(message, allowMalformed: true);
final trimmed = decoded.trimLeft();
// Heuristic: treat as JSON if it starts with '{' or '['
if (trimmed.startsWith('{') || trimmed.startsWith('[')) {
textFrame = decoded;
_textFrames++;
} else {
// Detect common image headers for potential future preview streaming
final isJpeg =
message.length >= 2 && message[0] == 0xFF && message[1] == 0xD8;
final isPng = message.length >= 8 &&
message[0] == 0x89 &&
message[1] == 0x50 &&
message[2] == 0x4E &&
message[3] == 0x47 &&
message[4] == 0x0D &&
message[5] == 0x0A &&
message[6] == 0x1A &&
message[7] == 0x0A;
// Create a typed event
final event = WebSocketEvent.fromJson(jsonData);
if (isJpeg || isPng) {
// Future: push to preview consumers
_previewFrames++;
// _previewFrameController.add(Uint8List.fromList(message));
print(
'WebSocket binary preview frame received (${message.length} bytes) ignored (preview streaming not enabled).');
} else {
_ignoredBinaryFrames++;
print(
'WebSocket non-JSON binary frame ignored (${message.length} bytes).');
}
return; // Do not proceed to JSON decode
}
} catch (_) {
// Could not decode as UTF8
_ignoredBinaryFrames++;
print(
'WebSocket binary frame ignored (UTF8 decode failed, ${message.length} bytes).');
return;
}
} else {
_ignoredBinaryFrames++;
print(
'WebSocket unsupported frame type ignored (${message.runtimeType}).');
return;
}
if (textFrame == null) {
_malformedFrames++;
print('WebSocket frame had no decodable text content.');
return;
}
dynamic jsonData;
try {
jsonData = jsonDecode(textFrame);
} catch (e) {
_malformedFrames++;
print('WebSocket malformed JSON frame ignored: $e');
return;
}
if (jsonData is! Map<String, dynamic>) {
_malformedFrames++;
print(
'WebSocket JSON root not object (type: ${jsonData.runtimeType}) ignored.');
return;
}
print('WebSocket message (event): $jsonData');
WebSocketEvent event;
try {
event = WebSocketEvent.fromJson(jsonData);
} catch (e) {
_malformedFrames++;
print('WebSocket event parse failed: $e');
return;
}
// Add to the typed event stream
_eventController.add(event);
@@ -142,55 +249,71 @@ class WebSocketManager {
_progressController.add(jsonData);
// Convert to more specific event types if possible
if (event.eventType == WebSocketEventType.progress) {
_tryCreateProgressEvent(event);
} else if ([
WebSocketEventType.executionStart,
WebSocketEventType.executionSuccess,
WebSocketEventType.executionCached,
WebSocketEventType.executed,
WebSocketEventType.executing,
WebSocketEventType.executionInterrupted,
WebSocketEventType.status, // Add status event type
].contains(event.eventType)) {
_tryCreateExecutionEvent(event);
}
// Handle "execution_interrupted" event
if (event.eventType == WebSocketEventType.executionInterrupted) {
_executionInterruptedController.add(null);
}
// Handle "executing" event
if (event.eventType == WebSocketEventType.executing &&
event.data['node'] != null) {
final nodeId = int.tryParse(event.data['node'].toString());
if (nodeId != null) {
_executingNodeController.add(nodeId);
} else {
print('Invalid node ID: ${event.data['node']}');
try {
if (event.eventType == WebSocketEventType.progress) {
_tryCreateProgressEvent(event);
} else if ([
WebSocketEventType.executionStart,
WebSocketEventType.executionSuccess,
WebSocketEventType.executionCached,
WebSocketEventType.executed,
WebSocketEventType.executing,
WebSocketEventType.executionInterrupted,
WebSocketEventType.status, // Add status event type
].contains(event.eventType)) {
_tryCreateExecutionEvent(event);
}
}
// Trigger event type specific callbacks
for (final callback in _typedEventCallbacks[event.eventType]!) {
callback(event);
}
// Handle execution_success event (prompt finished)
if (event.eventType == WebSocketEventType.executionSuccess &&
event.promptId != null) {
final promptId = event.promptId!;
for (final callback in _eventCallbacks['onPromptFinished']!) {
callback(promptId);
// Handle "execution_interrupted" event
if (event.eventType == WebSocketEventType.executionInterrupted) {
_executionInterruptedController.add(null);
}
}
// Handle progress updates
if (event.eventType == WebSocketEventType.progress) {
for (final callback in _progressEventCallbacks) {
callback(ProgressEvent.fromJson(event.data));
// Handle "executing" event
if (event.eventType == WebSocketEventType.executing &&
event.data['node'] != null) {
final nodeId = int.tryParse(event.data['node'].toString());
if (nodeId != null) {
_executingNodeController.add(nodeId);
} else {
print('Invalid node ID: ${event.data['node']}');
}
}
// Trigger event type specific callbacks
for (final callback in _typedEventCallbacks[event.eventType]!) {
try {
callback(event);
} catch (e, st) {
print('WebSocket event callback error: $e\n$st');
}
}
// Handle execution_success event (prompt finished)
if (event.eventType == WebSocketEventType.executionSuccess &&
event.promptId != null) {
final promptId = event.promptId!;
for (final callback in _eventCallbacks['onPromptFinished']!) {
try {
callback(promptId);
} catch (e, st) {
print('WebSocket onPromptFinished callback error: $e\n$st');
}
}
}
// Handle progress updates
if (event.eventType == WebSocketEventType.progress) {
for (final callback in _progressEventCallbacks) {
try {
callback(ProgressEvent.fromJson(event.data));
} catch (e, st) {
print('WebSocket progress callback error: $e\n$st');
}
}
}
} catch (e, st) {
print('WebSocket event dispatch error: $e\n$st');
}
}, onError: (error) {
print('WebSocket error: $error');
@@ -211,7 +334,7 @@ class WebSocketManager {
'Attempting to reconnect WebSocket (attempt $_retryAttempt/$_maxRetryAttempts) in 5 seconds...');
_updateConnectionState(ConnectionState.connecting);
await Future.delayed(Duration(seconds: 5));
await Future.delayed(const Duration(seconds: 5));
try {
await connect();
print('WebSocket reconnected successfully');
@@ -302,6 +425,7 @@ class WebSocketManager {
_executionEventController.close();
_executingNodeController.close();
_executionInterruptedController.close();
_previewFrameController.close();
_connectionStateController.close();
_retryAttemptController.close();
}