Add initial structure for comfyui_api_sdk with API models and event handling

This commit is contained in:
2025-03-20 14:27:29 +00:00
parent 1f9409ce0e
commit 0b2769310b
19 changed files with 491 additions and 1172 deletions

View File

@@ -5,22 +5,69 @@ import 'package:http/http.dart' as http;
import 'package:uuid/uuid.dart';
import 'package:web_socket_channel/web_socket_channel.dart';
import 'models/websocket_event.dart';
import 'models/progress_event.dart';
import 'models/execution_event.dart';
/// Callback function type for prompt events
typedef PromptEventCallback = void Function(String promptId);
/// Callback function type for typed WebSocket events
typedef WebSocketEventCallback = void Function(WebSocketEvent event);
/// Callback function type for progress events
typedef ProgressEventCallback = void Function(ProgressEvent event);
/// A Dart SDK for interacting with the ComfyUI API
class ComfyUiApi {
final String host;
final String clientId;
final http.Client _httpClient;
WebSocketChannel? _wsChannel;
// Controllers for different event streams
final StreamController<WebSocketEvent> _eventController =
StreamController.broadcast();
final StreamController<Map<String, dynamic>> _progressController =
StreamController.broadcast();
/// Stream of progress updates from ComfyUI
// Add new controllers for specific event types
final StreamController<ProgressEvent> _progressEventController =
StreamController.broadcast();
final StreamController<ExecutionEvent> _executionEventController =
StreamController.broadcast();
/// Stream of typed progress events
Stream<ProgressEvent> get progressEvents => _progressEventController.stream;
/// Stream of typed execution events
Stream<ExecutionEvent> get executionEvents =>
_executionEventController.stream;
// Event callbacks
final Map<String, List<PromptEventCallback>> _eventCallbacks = {
'onPromptStart': [],
'onPromptFinished': [],
};
final Map<WebSocketEventType, List<WebSocketEventCallback>>
_typedEventCallbacks = {
for (var type in WebSocketEventType.values) type: [],
};
// Add a separate map for progress event callbacks
final List<ProgressEventCallback> _progressEventCallbacks = [];
/// Stream of typed WebSocket events
Stream<WebSocketEvent> get events => _eventController.stream;
/// Stream of progress updates from ComfyUI (legacy format)
Stream<Map<String, dynamic>> get progressUpdates =>
_progressController.stream;
/// Creates a new ComfyUI API client
///
/// [host] The host of the ComfyUI server (e.g. 'http://localhost:8188')
/// [host] The host of the ComfyUI server (e.g. 'http://localhost:7860')
/// [clientId] Optional client ID, will be automatically generated if not provided
ComfyUiApi({
required this.host,
@@ -29,6 +76,26 @@ class ComfyUiApi {
}) : clientId = clientId ?? const Uuid().v4(),
_httpClient = httpClient ?? http.Client();
/// Register a callback for when a prompt starts processing
void onPromptStart(PromptEventCallback callback) {
_eventCallbacks['onPromptStart']!.add(callback);
}
/// Register a callback for when a prompt finishes processing
void onPromptFinished(PromptEventCallback callback) {
_eventCallbacks['onPromptFinished']!.add(callback);
}
/// Register a callback for specific WebSocket event types
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
_typedEventCallbacks[type]!.add(callback);
}
/// Register a callback for progress updates
void onProgressChanged(ProgressEventCallback callback) {
_progressEventCallbacks.add(callback);
}
/// Connects to the WebSocket for progress updates
Future<void> connectWebSocket() async {
final wsUrl =
@@ -36,8 +103,43 @@ class ComfyUiApi {
_wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl));
_wsChannel!.stream.listen((message) {
final data = jsonDecode(message);
_progressController.add(data);
final jsonData = jsonDecode(message);
// Create a typed event
final event = WebSocketEvent.fromJson(jsonData);
// Add to the typed event stream
_eventController.add(event);
// Also add to the legacy progress stream
_progressController.add(jsonData);
// Convert to more specific event types if possible
if (event.eventType == WebSocketEventType.progress) {
_tryCreateProgressEvent(event);
} else if ([
WebSocketEventType.executionStart,
WebSocketEventType.executionSuccess,
WebSocketEventType.executionCached,
WebSocketEventType.executed,
WebSocketEventType.executing,
].contains(event.eventType)) {
_tryCreateExecutionEvent(event);
}
// Trigger event type specific callbacks
for (final callback in _typedEventCallbacks[event.eventType]!) {
callback(event);
}
// Handle execution_success event (prompt finished)
if (event.eventType == WebSocketEventType.executionSuccess &&
event.promptId != null) {
final promptId = event.promptId!;
for (final callback in _eventCallbacks['onPromptFinished']!) {
callback(promptId);
}
}
}, onError: (error) {
print('WebSocket error: $error');
}, onDone: () {
@@ -45,10 +147,55 @@ class ComfyUiApi {
});
}
/// Attempts to create a ProgressEvent from a WebSocketEvent
void _tryCreateProgressEvent(WebSocketEvent event) {
if (event.data.containsKey('value') &&
event.data.containsKey('max') &&
event.data.containsKey('prompt_id')) {
try {
final progressEvent = ProgressEvent(
value: event.data['value'] as int,
max: event.data['max'] as int,
promptId: event.data['prompt_id'] as String,
node: event.data['node']?.toString(),
);
_progressEventController.add(progressEvent);
// Trigger all registered progress callbacks
for (final callback in _progressEventCallbacks) {
callback(progressEvent);
}
} catch (e) {
print('Error creating ProgressEvent: $e');
}
}
}
/// Attempts to create an ExecutionEvent from a WebSocketEvent
void _tryCreateExecutionEvent(WebSocketEvent event) {
if (event.data.containsKey('prompt_id')) {
try {
final executionEvent = ExecutionEvent(
promptId: event.data['prompt_id'] as String,
timestamp: event.data['timestamp'] as int? ??
DateTime.now().millisecondsSinceEpoch,
node: event.data['node']?.toString(),
extra: event.data['extra'] as Map<String, dynamic>?,
);
_executionEventController.add(executionEvent);
} catch (e) {
print('Error creating ExecutionEvent: $e');
}
}
}
/// Closes the WebSocket connection and cleans up resources
void dispose() {
_wsChannel?.sink.close();
_progressController.close();
_eventController.close();
_progressEventController.close();
_executionEventController.close();
_httpClient.close();
}
@@ -181,7 +328,17 @@ class ComfyUiApi {
body: jsonEncode(prompt),
);
_validateResponse(response);
return jsonDecode(response.body);
final responseData = jsonDecode(response.body);
// Trigger onPromptStart event if prompt_id exists
if (responseData.containsKey('prompt_id')) {
final promptId = responseData['prompt_id'];
for (final callback in _eventCallbacks['onPromptStart']!) {
callback(promptId);
}
}
return responseData;
}
/// Validates HTTP response and throws an exception if needed

View File

@@ -0,0 +1,7 @@
// Main API
export 'comfyui_api.dart';
// Models
export 'models/websocket_event.dart';
export 'models/progress_event.dart';
export 'models/execution_event.dart';

View File

@@ -0,0 +1,11 @@
import 'websocket_event.dart';
import 'progress_event.dart';
/// Callback function type for prompt events
typedef PromptEventCallback = void Function(String promptId);
/// Callback function type for typed WebSocket events
typedef WebSocketEventCallback = void Function(WebSocketEvent event);
/// Callback function type for progress events
typedef ProgressEventCallback = void Function(ProgressEvent event);

View File

@@ -0,0 +1,31 @@
class ExecutionEvent {
final String promptId;
final int timestamp;
final String? node;
final Map<String, dynamic>? extra;
const ExecutionEvent({
required this.promptId,
required this.timestamp,
this.node,
this.extra,
});
factory ExecutionEvent.fromJson(Map<String, dynamic> json) {
return ExecutionEvent(
promptId: json['prompt_id'] as String,
timestamp: json['timestamp'] as int,
node: json['node'] as String?,
extra: json['extra'] as Map<String, dynamic>?,
);
}
Map<String, dynamic> toJson() {
return {
'prompt_id': promptId,
'timestamp': timestamp,
'node': node,
'extra': extra,
};
}
}

View File

@@ -0,0 +1,39 @@
class ProgressEvent {
final int value;
final int max;
final String promptId;
final String? node;
const ProgressEvent({
required this.value,
required this.max,
required this.promptId,
this.node,
});
factory ProgressEvent.fromJson(Map<String, dynamic> json) {
return ProgressEvent(
value: json['value'] as int,
max: json['max'] as int,
promptId: json['prompt_id'] as String,
node: json['node'] as String?,
);
}
Map<String, dynamic> toJson() {
return {
'value': value,
'max': max,
'prompt_id': promptId,
'node': node,
};
}
}
extension ProgressEventComputation on ProgressEvent {
/// Returns the progress percentage (0-100)
double get percentage => (value / max) * 100;
/// Returns true if the progress is complete
bool get isComplete => value >= max;
}

View File

@@ -0,0 +1,109 @@
/// Types of WebSocket events from ComfyUI
enum WebSocketEventType {
status,
progress,
executing,
executed,
executionStart,
executionCached,
executionSuccess,
executionError,
dataOutput,
unknown
}
/// A typed event from the ComfyUI WebSocket
class WebSocketEvent {
final WebSocketEventType eventType;
final Map<String, dynamic> data;
final String? promptId;
final String rawType;
WebSocketEvent({
required this.eventType,
required this.data,
this.promptId,
required this.rawType,
});
@override
String toString() =>
'WebSocketEvent{type: $eventType, rawType: $rawType, promptId: $promptId, data: ${data.keys.join(', ')}}';
/// Creates a WebSocketEvent from JSON
factory WebSocketEvent.fromJson(Map<String, dynamic> json) {
WebSocketEventType type;
Map<String, dynamic> eventData = {};
String? promptId;
String rawType = json['type']?.toString() ?? 'unknown';
// Extract event type
if (json.containsKey('type')) {
final typeStr = json['type'].toString();
// First, try to extract prompt_id from the data field
if (json.containsKey('data') && json['data'] is Map) {
final dataMap = Map<String, dynamic>.from(json['data'] as Map);
if (dataMap.containsKey('prompt_id')) {
promptId = dataMap['prompt_id'].toString();
}
eventData = dataMap;
} else {
// If no data field, use the root object
eventData = Map<String, dynamic>.from(json);
}
// Try to extract prompt_id from other potential locations
if (promptId == null && json.containsKey('prompt_id')) {
promptId = json['prompt_id'].toString();
}
switch (typeStr) {
case 'status':
type = WebSocketEventType.status;
break;
case 'progress':
type = WebSocketEventType.progress;
break;
case 'executing':
type = WebSocketEventType.executing;
break;
case 'executed':
type = WebSocketEventType.executed;
break;
case 'execution_start':
type = WebSocketEventType.executionStart;
break;
case 'execution_cached':
type = WebSocketEventType.executionCached;
break;
case 'execution_success':
type = WebSocketEventType.executionSuccess;
break;
case 'execution_error':
type = WebSocketEventType.executionError;
break;
default:
if (typeStr.startsWith('data_output')) {
type = WebSocketEventType.dataOutput;
} else {
// Default to unknown for unrecognized types
type = WebSocketEventType.unknown;
print('Unknown event type: $typeStr');
}
}
} else {
// If no type field, default to unknown
type = WebSocketEventType.unknown;
print('WebSocket event missing type field');
eventData = Map<String, dynamic>.from(json);
}
return WebSocketEvent(
eventType: type,
data: eventData,
promptId: promptId,
rawType: rawType,
);
}
}