fixes to websocket type casting
Signed-off-by: Menno van Leeuwen <menno@vleeuwen.me>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
|
||||
import '../models/websocket_event.dart';
|
||||
import '../models/progress_event.dart';
|
||||
|
||||
/// Callback function type for prompt events
|
||||
typedef PromptEventCallback = void Function(String promptId);
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
import 'dart:async';
|
||||
import 'dart:convert';
|
||||
import 'dart:typed_data';
|
||||
|
||||
import 'package:web_socket_channel/web_socket_channel.dart';
|
||||
|
||||
// Use relative imports to avoid duplicate library instances when this package
|
||||
// is consumed via a path or symlink (prevents distinct "same" types).
|
||||
import 'models/websocket_event.dart';
|
||||
import 'models/progress_event.dart';
|
||||
import 'models/execution_event.dart';
|
||||
import 'types/callback_types.dart';
|
||||
import 'utils/websocket_event_handler.dart';
|
||||
|
||||
/// Enum representing the connection state of the WebSocket
|
||||
/// Connection states for the ComfyUI WebSocket
|
||||
enum ConnectionState { connected, connecting, disconnected, failed }
|
||||
|
||||
class WebSocketManager {
|
||||
@@ -41,17 +44,27 @@ class WebSocketManager {
|
||||
final StreamController<void> _executionInterruptedController =
|
||||
StreamController.broadcast();
|
||||
|
||||
// Optional future binary preview frames
|
||||
final StreamController<Uint8List> _previewFrameController =
|
||||
StreamController.broadcast();
|
||||
|
||||
// Event callbacks
|
||||
final Map<WebSocketEventType, List<WebSocketEventCallback>>
|
||||
_typedEventCallbacks = {
|
||||
for (var type in WebSocketEventType.values) type: [],
|
||||
for (var type in WebSocketEventType.values) type: <WebSocketEventCallback>[],
|
||||
};
|
||||
final List<ProgressEventCallback> _progressEventCallbacks = [];
|
||||
final Map<String, List<PromptEventCallback>> _eventCallbacks = {
|
||||
'onPromptStart': [],
|
||||
'onPromptFinished': [],
|
||||
'onPromptStart': <PromptEventCallback>[],
|
||||
'onPromptFinished': <PromptEventCallback>[],
|
||||
};
|
||||
|
||||
// Frame / parse statistics
|
||||
int _ignoredBinaryFrames = 0;
|
||||
int _malformedFrames = 0;
|
||||
int _previewFrames = 0;
|
||||
int _textFrames = 0;
|
||||
|
||||
WebSocketManager({required this.host, required this.clientId});
|
||||
|
||||
/// Stream of typed WebSocket events
|
||||
@@ -91,6 +104,18 @@ class WebSocketManager {
|
||||
Stream<void> get executionInterruptedStream =>
|
||||
_executionInterruptedController.stream;
|
||||
|
||||
/// Stream of (future) binary preview frames
|
||||
Stream<Uint8List> get previewFrames => _previewFrameController.stream;
|
||||
|
||||
/// Stats about received frames
|
||||
Map<String, int> get stats => {
|
||||
'ignoredBinaryFrames': _ignoredBinaryFrames,
|
||||
'malformedFrames': _malformedFrames,
|
||||
'previewFrames': _previewFrames,
|
||||
'textFrames': _textFrames,
|
||||
'retryAttempts': _retryAttempt,
|
||||
};
|
||||
|
||||
/// Register a callback for specific WebSocket event types
|
||||
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
|
||||
_typedEventCallbacks[type]!.add(callback);
|
||||
@@ -122,18 +147,100 @@ class WebSocketManager {
|
||||
|
||||
print('WebSocket connecting to $wsUrl');
|
||||
|
||||
_wsChannel!.stream.listen((message) {
|
||||
_wsChannel!.stream.listen((dynamic message) {
|
||||
// Successfully connected
|
||||
if (_connectionState != ConnectionState.connected) {
|
||||
_updateConnectionState(ConnectionState.connected);
|
||||
_updateRetryAttempt(0); // Reset retry counter on successful connection
|
||||
}
|
||||
final jsonData = jsonDecode(message);
|
||||
|
||||
print('WebSocket message: $jsonData');
|
||||
// Determine frame type
|
||||
String? textFrame;
|
||||
if (message is String) {
|
||||
textFrame = message;
|
||||
_textFrames++;
|
||||
} else if (message is List<int>) {
|
||||
// Attempt to interpret as UTF8 JSON text
|
||||
try {
|
||||
final decoded = utf8.decode(message, allowMalformed: true);
|
||||
final trimmed = decoded.trimLeft();
|
||||
// Heuristic: treat as JSON if it starts with '{' or '['
|
||||
if (trimmed.startsWith('{') || trimmed.startsWith('[')) {
|
||||
textFrame = decoded;
|
||||
_textFrames++;
|
||||
} else {
|
||||
// Detect common image headers for potential future preview streaming
|
||||
final isJpeg =
|
||||
message.length >= 2 && message[0] == 0xFF && message[1] == 0xD8;
|
||||
final isPng = message.length >= 8 &&
|
||||
message[0] == 0x89 &&
|
||||
message[1] == 0x50 &&
|
||||
message[2] == 0x4E &&
|
||||
message[3] == 0x47 &&
|
||||
message[4] == 0x0D &&
|
||||
message[5] == 0x0A &&
|
||||
message[6] == 0x1A &&
|
||||
message[7] == 0x0A;
|
||||
|
||||
// Create a typed event
|
||||
final event = WebSocketEvent.fromJson(jsonData);
|
||||
if (isJpeg || isPng) {
|
||||
// Future: push to preview consumers
|
||||
_previewFrames++;
|
||||
// _previewFrameController.add(Uint8List.fromList(message));
|
||||
print(
|
||||
'WebSocket binary preview frame received (${message.length} bytes) ignored (preview streaming not enabled).');
|
||||
} else {
|
||||
_ignoredBinaryFrames++;
|
||||
print(
|
||||
'WebSocket non-JSON binary frame ignored (${message.length} bytes).');
|
||||
}
|
||||
return; // Do not proceed to JSON decode
|
||||
}
|
||||
} catch (_) {
|
||||
// Could not decode as UTF8
|
||||
_ignoredBinaryFrames++;
|
||||
print(
|
||||
'WebSocket binary frame ignored (UTF8 decode failed, ${message.length} bytes).');
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
_ignoredBinaryFrames++;
|
||||
print(
|
||||
'WebSocket unsupported frame type ignored (${message.runtimeType}).');
|
||||
return;
|
||||
}
|
||||
|
||||
if (textFrame == null) {
|
||||
_malformedFrames++;
|
||||
print('WebSocket frame had no decodable text content.');
|
||||
return;
|
||||
}
|
||||
|
||||
dynamic jsonData;
|
||||
try {
|
||||
jsonData = jsonDecode(textFrame);
|
||||
} catch (e) {
|
||||
_malformedFrames++;
|
||||
print('WebSocket malformed JSON frame ignored: $e');
|
||||
return;
|
||||
}
|
||||
|
||||
if (jsonData is! Map<String, dynamic>) {
|
||||
_malformedFrames++;
|
||||
print(
|
||||
'WebSocket JSON root not object (type: ${jsonData.runtimeType}) ignored.');
|
||||
return;
|
||||
}
|
||||
|
||||
print('WebSocket message (event): $jsonData');
|
||||
|
||||
WebSocketEvent event;
|
||||
try {
|
||||
event = WebSocketEvent.fromJson(jsonData);
|
||||
} catch (e) {
|
||||
_malformedFrames++;
|
||||
print('WebSocket event parse failed: $e');
|
||||
return;
|
||||
}
|
||||
|
||||
// Add to the typed event stream
|
||||
_eventController.add(event);
|
||||
@@ -142,55 +249,71 @@ class WebSocketManager {
|
||||
_progressController.add(jsonData);
|
||||
|
||||
// Convert to more specific event types if possible
|
||||
if (event.eventType == WebSocketEventType.progress) {
|
||||
_tryCreateProgressEvent(event);
|
||||
} else if ([
|
||||
WebSocketEventType.executionStart,
|
||||
WebSocketEventType.executionSuccess,
|
||||
WebSocketEventType.executionCached,
|
||||
WebSocketEventType.executed,
|
||||
WebSocketEventType.executing,
|
||||
WebSocketEventType.executionInterrupted,
|
||||
WebSocketEventType.status, // Add status event type
|
||||
].contains(event.eventType)) {
|
||||
_tryCreateExecutionEvent(event);
|
||||
}
|
||||
|
||||
// Handle "execution_interrupted" event
|
||||
if (event.eventType == WebSocketEventType.executionInterrupted) {
|
||||
_executionInterruptedController.add(null);
|
||||
}
|
||||
|
||||
// Handle "executing" event
|
||||
if (event.eventType == WebSocketEventType.executing &&
|
||||
event.data['node'] != null) {
|
||||
final nodeId = int.tryParse(event.data['node'].toString());
|
||||
if (nodeId != null) {
|
||||
_executingNodeController.add(nodeId);
|
||||
} else {
|
||||
print('Invalid node ID: ${event.data['node']}');
|
||||
try {
|
||||
if (event.eventType == WebSocketEventType.progress) {
|
||||
_tryCreateProgressEvent(event);
|
||||
} else if ([
|
||||
WebSocketEventType.executionStart,
|
||||
WebSocketEventType.executionSuccess,
|
||||
WebSocketEventType.executionCached,
|
||||
WebSocketEventType.executed,
|
||||
WebSocketEventType.executing,
|
||||
WebSocketEventType.executionInterrupted,
|
||||
WebSocketEventType.status, // Add status event type
|
||||
].contains(event.eventType)) {
|
||||
_tryCreateExecutionEvent(event);
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger event type specific callbacks
|
||||
for (final callback in _typedEventCallbacks[event.eventType]!) {
|
||||
callback(event);
|
||||
}
|
||||
|
||||
// Handle execution_success event (prompt finished)
|
||||
if (event.eventType == WebSocketEventType.executionSuccess &&
|
||||
event.promptId != null) {
|
||||
final promptId = event.promptId!;
|
||||
for (final callback in _eventCallbacks['onPromptFinished']!) {
|
||||
callback(promptId);
|
||||
// Handle "execution_interrupted" event
|
||||
if (event.eventType == WebSocketEventType.executionInterrupted) {
|
||||
_executionInterruptedController.add(null);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle progress updates
|
||||
if (event.eventType == WebSocketEventType.progress) {
|
||||
for (final callback in _progressEventCallbacks) {
|
||||
callback(ProgressEvent.fromJson(event.data));
|
||||
// Handle "executing" event
|
||||
if (event.eventType == WebSocketEventType.executing &&
|
||||
event.data['node'] != null) {
|
||||
final nodeId = int.tryParse(event.data['node'].toString());
|
||||
if (nodeId != null) {
|
||||
_executingNodeController.add(nodeId);
|
||||
} else {
|
||||
print('Invalid node ID: ${event.data['node']}');
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger event type specific callbacks
|
||||
for (final callback in _typedEventCallbacks[event.eventType]!) {
|
||||
try {
|
||||
callback(event);
|
||||
} catch (e, st) {
|
||||
print('WebSocket event callback error: $e\n$st');
|
||||
}
|
||||
}
|
||||
|
||||
// Handle execution_success event (prompt finished)
|
||||
if (event.eventType == WebSocketEventType.executionSuccess &&
|
||||
event.promptId != null) {
|
||||
final promptId = event.promptId!;
|
||||
for (final callback in _eventCallbacks['onPromptFinished']!) {
|
||||
try {
|
||||
callback(promptId);
|
||||
} catch (e, st) {
|
||||
print('WebSocket onPromptFinished callback error: $e\n$st');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle progress updates
|
||||
if (event.eventType == WebSocketEventType.progress) {
|
||||
for (final callback in _progressEventCallbacks) {
|
||||
try {
|
||||
callback(ProgressEvent.fromJson(event.data));
|
||||
} catch (e, st) {
|
||||
print('WebSocket progress callback error: $e\n$st');
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e, st) {
|
||||
print('WebSocket event dispatch error: $e\n$st');
|
||||
}
|
||||
}, onError: (error) {
|
||||
print('WebSocket error: $error');
|
||||
@@ -211,7 +334,7 @@ class WebSocketManager {
|
||||
'Attempting to reconnect WebSocket (attempt $_retryAttempt/$_maxRetryAttempts) in 5 seconds...');
|
||||
_updateConnectionState(ConnectionState.connecting);
|
||||
|
||||
await Future.delayed(Duration(seconds: 5));
|
||||
await Future.delayed(const Duration(seconds: 5));
|
||||
try {
|
||||
await connect();
|
||||
print('WebSocket reconnected successfully');
|
||||
@@ -302,6 +425,7 @@ class WebSocketManager {
|
||||
_executionEventController.close();
|
||||
_executingNodeController.close();
|
||||
_executionInterruptedController.close();
|
||||
_previewFrameController.close();
|
||||
_connectionStateController.close();
|
||||
_retryAttemptController.close();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user