diff --git a/lib/src/comfyui_api.dart b/lib/src/comfyui_api.dart index 98d76e7..dc0bf1f 100644 --- a/lib/src/comfyui_api.dart +++ b/lib/src/comfyui_api.dart @@ -8,15 +8,13 @@ import 'package:web_socket_channel/web_socket_channel.dart'; import 'models/websocket_event.dart'; import 'models/progress_event.dart'; import 'models/execution_event.dart'; - -/// Callback function type for prompt events -typedef PromptEventCallback = void Function(String promptId); - -/// Callback function type for typed WebSocket events -typedef WebSocketEventCallback = void Function(WebSocketEvent event); - -/// Callback function type for progress events -typedef ProgressEventCallback = void Function(ProgressEvent event); +import 'exceptions/comfyui_api_exception.dart'; +import 'types/callback_types.dart'; +import 'utils/websocket_event_handler.dart'; +import 'models/history_response.dart'; +import 'models/checkpoint.dart'; +import 'models/vae.dart'; +import 'models/lora.dart'; /// A Dart SDK for interacting with the ComfyUI API class ComfyUiApi { @@ -149,44 +147,19 @@ class ComfyUiApi { /// Attempts to create a ProgressEvent from a WebSocketEvent void _tryCreateProgressEvent(WebSocketEvent event) { - if (event.data.containsKey('value') && - event.data.containsKey('max') && - event.data.containsKey('prompt_id')) { - try { - final progressEvent = ProgressEvent( - value: event.data['value'] as int, - max: event.data['max'] as int, - promptId: event.data['prompt_id'] as String, - node: event.data['node']?.toString(), - ); - _progressEventController.add(progressEvent); - - // Trigger all registered progress callbacks - for (final callback in _progressEventCallbacks) { - callback(progressEvent); - } - } catch (e) { - print('Error creating ProgressEvent: $e'); - } - } + WebSocketEventHandler.tryCreateProgressEvent( + event, + _progressEventController, + _progressEventCallbacks, + ); } /// Attempts to create an ExecutionEvent from a WebSocketEvent void _tryCreateExecutionEvent(WebSocketEvent event) { - if (event.data.containsKey('prompt_id')) { - try { - final executionEvent = ExecutionEvent( - promptId: event.data['prompt_id'] as String, - timestamp: event.data['timestamp'] as int? ?? - DateTime.now().millisecondsSinceEpoch, - node: event.data['node']?.toString(), - extra: event.data['extra'] as Map?, - ); - _executionEventController.add(executionEvent); - } catch (e) { - print('Error creating ExecutionEvent: $e'); - } - } + WebSocketEventHandler.tryCreateExecutionEvent( + event, + _executionEventController, + ); } /// Closes the WebSocket connection and cleans up resources @@ -207,11 +180,11 @@ class ComfyUiApi { } /// Gets the history of the queue - Future> getHistory({int maxItems = 64}) async { + Future getHistory({int maxItems = 64}) async { final response = await _httpClient .get(Uri.parse('$host/api/history?max_items=$maxItems')); _validateResponse(response); - return jsonDecode(response.body); + return HistoryResponse.fromJson(jsonDecode(response.body)); } /// Gets image data by filename @@ -231,11 +204,12 @@ class ComfyUiApi { } /// Gets a list of checkpoints - Future> getCheckpoints() async { + Future> getCheckpoints() async { final response = await _httpClient .get(Uri.parse('$host/api/experiment/models/checkpoints')); _validateResponse(response); - return jsonDecode(response.body); + final List jsonData = jsonDecode(response.body); + return jsonData.map((item) => Checkpoint.fromJson(item)).toList(); } /// Gets details for a specific checkpoint @@ -248,11 +222,12 @@ class ComfyUiApi { } /// Gets a list of LoRAs - Future> getLoras() async { + Future> getLoras() async { final response = await _httpClient.get(Uri.parse('$host/api/experiment/models/loras')); _validateResponse(response); - return jsonDecode(response.body); + final List jsonData = jsonDecode(response.body); + return jsonData.map((item) => Lora.fromJson(item)).toList(); } /// Gets details for a specific LoRA @@ -264,11 +239,12 @@ class ComfyUiApi { } /// Gets a list of VAEs - Future> getVaes() async { + Future> getVaes() async { final response = await _httpClient.get(Uri.parse('$host/api/experiment/models/vae')); _validateResponse(response); - return jsonDecode(response.body); + final List jsonData = jsonDecode(response.body); + return jsonData.map((item) => Vae.fromJson(item)).toList(); } /// Gets details for a specific VAE @@ -350,14 +326,3 @@ class ComfyUiApi { } } } - -/// Exception thrown when the ComfyUI API returns an error -class ComfyUiApiException implements Exception { - final int statusCode; - final String message; - - ComfyUiApiException({required this.statusCode, required this.message}); - - @override - String toString() => 'ComfyUiApiException: $statusCode - $message'; -} diff --git a/lib/src/exceptions/comfyui_api_exception.dart b/lib/src/exceptions/comfyui_api_exception.dart new file mode 100644 index 0000000..4c9906d --- /dev/null +++ b/lib/src/exceptions/comfyui_api_exception.dart @@ -0,0 +1,10 @@ +/// Exception thrown when the ComfyUI API returns an error +class ComfyUiApiException implements Exception { + final int statusCode; + final String message; + + ComfyUiApiException({required this.statusCode, required this.message}); + + @override + String toString() => 'ComfyUiApiException: $statusCode - $message'; +} diff --git a/lib/src/models/base_model.dart b/lib/src/models/base_model.dart new file mode 100644 index 0000000..c8e7a5e --- /dev/null +++ b/lib/src/models/base_model.dart @@ -0,0 +1,6 @@ +abstract class BaseModel { + final String name; + final int pathIndex; + + BaseModel({required this.name, required this.pathIndex}); +} diff --git a/lib/src/models/checkpoint.dart b/lib/src/models/checkpoint.dart new file mode 100644 index 0000000..e21b177 --- /dev/null +++ b/lib/src/models/checkpoint.dart @@ -0,0 +1,13 @@ +import 'base_model.dart'; + +class Checkpoint extends BaseModel { + Checkpoint({required String name, required int pathIndex}) + : super(name: name, pathIndex: pathIndex); + + factory Checkpoint.fromJson(Map json) { + return Checkpoint( + name: json['name'] as String, + pathIndex: json['pathIndex'] as int, + ); + } +} diff --git a/lib/src/models/history_response.dart b/lib/src/models/history_response.dart new file mode 100644 index 0000000..d0b87c2 --- /dev/null +++ b/lib/src/models/history_response.dart @@ -0,0 +1,228 @@ +class HistoryResponse { + final Map items; + + HistoryResponse({required this.items}); + + factory HistoryResponse.fromJson(Map json) { + return HistoryResponse( + items: + json.map((key, value) => MapEntry(key, HistoryItem.fromJson(value))), + ); + } +} + +class HistoryItem { + final Prompt prompt; + final Outputs outputs; + final Status status; + final Map meta; + + HistoryItem({ + required this.prompt, + required this.outputs, + required this.status, + required this.meta, + }); + + factory HistoryItem.fromJson(Map json) { + return HistoryItem( + prompt: Prompt.fromJson(json['prompt']), + outputs: Outputs.fromJson(json['outputs']), + status: Status.fromJson(json['status']), + meta: (json['meta'] as Map).map( + (key, value) => MapEntry(key, Meta.fromJson(value)), + ), + ); + } +} + +class Prompt { + final int id; + final String promptId; + final Map nodes; + final ExtraPngInfo extraPngInfo; + + Prompt({ + required this.id, + required this.promptId, + required this.nodes, + required this.extraPngInfo, + }); + + factory Prompt.fromJson(List json) { + return Prompt( + id: json[0] as int, + promptId: json[1] as String, + nodes: (json[2] as Map).map( + (key, value) => MapEntry(key, Node.fromJson(value)), + ), + extraPngInfo: ExtraPngInfo.fromJson(json[3]['extra_pnginfo']), + ); + } +} + +class Node { + final Map inputs; + final String classType; + final Meta meta; + + Node({ + required this.inputs, + required this.classType, + required this.meta, + }); + + factory Node.fromJson(Map json) { + return Node( + inputs: json['inputs'] as Map, + classType: json['class_type'] as String, + meta: Meta.fromJson(json['_meta']), + ); + } +} + +class ExtraPngInfo { + final Workflow workflow; + + ExtraPngInfo({required this.workflow}); + + factory ExtraPngInfo.fromJson(Map json) { + return ExtraPngInfo( + workflow: Workflow.fromJson(json['workflow']), + ); + } +} + +class Workflow { + final List nodes; + final List links; + + Workflow({ + required this.nodes, + required this.links, + }); + + factory Workflow.fromJson(Map json) { + return Workflow( + nodes: (json['nodes'] as List) + .map((e) => NodeInfo.fromJson(e)) + .toList(), + links: (json['links'] as List) + .map((e) => Link.fromJson(e)) + .toList(), + ); + } +} + +class NodeInfo { + final int id; + final String type; + + NodeInfo({ + required this.id, + required this.type, + }); + + factory NodeInfo.fromJson(Map json) { + return NodeInfo( + id: json['id'] as int, + type: json['type'] as String, + ); + } +} + +class Link { + final int id; + final int sourceNodeId; + final int targetNodeId; + + Link({ + required this.id, + required this.sourceNodeId, + required this.targetNodeId, + }); + + factory Link.fromJson(List json) { + return Link( + id: json[0] as int, + sourceNodeId: json[1] as int, + targetNodeId: json[2] as int, + ); + } +} + +class Outputs { + final Map nodes; + + Outputs({required this.nodes}); + + factory Outputs.fromJson(Map json) { + return Outputs( + nodes: + json.map((key, value) => MapEntry(key, OutputNode.fromJson(value))), + ); + } +} + +class OutputNode { + final List images; + + OutputNode({required this.images}); + + factory OutputNode.fromJson(Map json) { + return OutputNode( + images: (json['images'] as List) + .map((e) => Image.fromJson(e)) + .toList(), + ); + } +} + +class Image { + final String filename; + final String subfolder; + final String type; + + Image({ + required this.filename, + required this.subfolder, + required this.type, + }); + + factory Image.fromJson(Map json) { + return Image( + filename: json['filename'] as String, + subfolder: json['subfolder'] as String, + type: json['type'] as String, + ); + } +} + +class Status { + final String statusStr; + final bool completed; + + Status({ + required this.statusStr, + required this.completed, + }); + + factory Status.fromJson(Map json) { + return Status( + statusStr: json['status_str'] as String, + completed: json['completed'] as bool, + ); + } +} + +class Meta { + final String? nodeId; + + Meta({this.nodeId}); + + factory Meta.fromJson(Map json) { + return Meta( + nodeId: json['node_id'] as String?, + ); + } +} diff --git a/lib/src/models/lora.dart b/lib/src/models/lora.dart new file mode 100644 index 0000000..7a21778 --- /dev/null +++ b/lib/src/models/lora.dart @@ -0,0 +1,13 @@ +import 'base_model.dart'; + +class Lora extends BaseModel { + Lora({required String name, required int pathIndex}) + : super(name: name, pathIndex: pathIndex); + + factory Lora.fromJson(Map json) { + return Lora( + name: json['name'] as String, + pathIndex: json['pathIndex'] as int, + ); + } +} diff --git a/lib/src/models/models.dart b/lib/src/models/models.dart deleted file mode 100644 index 11d391d..0000000 --- a/lib/src/models/models.dart +++ /dev/null @@ -1,87 +0,0 @@ -/// Models that represent ComfyUI API responses - -/// Represents queue information from ComfyUI -class QueueInfo { - final int queueRunning; - final List> queue; - final Map queuePending; - - QueueInfo({ - required this.queueRunning, - required this.queue, - required this.queuePending, - }); - - factory QueueInfo.fromJson(Map json) { - return QueueInfo( - queueRunning: json['queue_running'] as int, - queue: List>.from(json['queue'] ?? []), - queuePending: Map.from(json['queue_pending'] ?? {}), - ); - } -} - -/// Represents a prompt execution status -class PromptExecutionStatus { - final String? promptId; - final int? number; - final String? status; - final dynamic error; - - PromptExecutionStatus({ - this.promptId, - this.number, - this.status, - this.error, - }); - - factory PromptExecutionStatus.fromJson(Map json) { - return PromptExecutionStatus( - promptId: json['prompt_id'] as String?, - number: json['number'] as int?, - status: json['status'] as String?, - error: json['error'], - ); - } -} - -/// Represents history data -class HistoryItem { - final String promptId; - final Map prompt; - final Map? outputs; - - HistoryItem({ - required this.promptId, - required this.prompt, - this.outputs, - }); - - factory HistoryItem.fromJson(Map json) { - return HistoryItem( - promptId: json['prompt_id'] as String, - prompt: Map.from(json['prompt'] ?? {}), - outputs: json['outputs'] != null - ? Map.from(json['outputs']) - : null, - ); - } -} - -/// Represents a progress update received via WebSocket -class ProgressUpdate { - final String type; - final Map data; - - ProgressUpdate({ - required this.type, - required this.data, - }); - - factory ProgressUpdate.fromJson(Map json) { - return ProgressUpdate( - type: json['type'] as String, - data: Map.from(json['data'] ?? {}), - ); - } -} diff --git a/lib/src/models/vae.dart b/lib/src/models/vae.dart new file mode 100644 index 0000000..6d06c58 --- /dev/null +++ b/lib/src/models/vae.dart @@ -0,0 +1,13 @@ +import 'base_model.dart'; + +class Vae extends BaseModel { + Vae({required String name, required int pathIndex}) + : super(name: name, pathIndex: pathIndex); + + factory Vae.fromJson(Map json) { + return Vae( + name: json['name'] as String, + pathIndex: json['pathIndex'] as int, + ); + } +} diff --git a/lib/src/types/callback_types.dart b/lib/src/types/callback_types.dart new file mode 100644 index 0000000..aa59c2c --- /dev/null +++ b/lib/src/types/callback_types.dart @@ -0,0 +1,10 @@ +import 'package:comfyui_api_sdk/comfyui_api_sdk.dart'; + +/// Callback function type for prompt events +typedef PromptEventCallback = void Function(String promptId); + +/// Callback function type for typed WebSocket events +typedef WebSocketEventCallback = void Function(WebSocketEvent event); + +/// Callback function type for progress events +typedef ProgressEventCallback = void Function(ProgressEvent event); diff --git a/lib/src/utils/websocket_event_handler.dart b/lib/src/utils/websocket_event_handler.dart new file mode 100644 index 0000000..e261765 --- /dev/null +++ b/lib/src/utils/websocket_event_handler.dart @@ -0,0 +1,54 @@ +import 'dart:async'; + +import '../models/websocket_event.dart'; +import '../models/progress_event.dart'; +import '../models/execution_event.dart'; + +class WebSocketEventHandler { + static void tryCreateProgressEvent( + WebSocketEvent event, + StreamController progressEventController, + List progressEventCallbacks, + ) { + if (event.data.containsKey('value') && + event.data.containsKey('max') && + event.data.containsKey('prompt_id')) { + try { + final progressEvent = ProgressEvent( + value: event.data['value'] as int, + max: event.data['max'] as int, + promptId: event.data['prompt_id'] as String, + node: event.data['node']?.toString(), + ); + progressEventController.add(progressEvent); + + // Trigger all registered progress callbacks + for (final callback in progressEventCallbacks) { + callback(progressEvent); + } + } catch (e) { + print('Error creating ProgressEvent: $e'); + } + } + } + + static void tryCreateExecutionEvent( + WebSocketEvent event, + StreamController executionEventController, + ) { + if (event.data.containsKey('prompt_id')) { + try { + final executionEvent = ExecutionEvent( + promptId: event.data['prompt_id'] as String, + timestamp: event.data['timestamp'] as int? ?? + DateTime.now().millisecondsSinceEpoch, + node: event.data['node']?.toString(), + extra: event.data['extra'] as Map?, + ); + executionEventController.add(executionEvent); + } catch (e) { + print('Error creating ExecutionEvent: $e'); + } + } + } +}