Add initial structure for comfyui_api_sdk with API models and event handling

This commit is contained in:
Menno van Leeuwen 2025-03-20 14:27:29 +00:00
parent 1f9409ce0e
commit 0b2769310b
19 changed files with 491 additions and 1172 deletions

View File

@ -57,7 +57,7 @@
},
{
"name": "build_runner",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/build_runner-2.4.14",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/build_runner-2.4.15",
"packageUri": "lib/",
"languageVersion": "3.6"
},
@ -133,6 +133,18 @@
"packageUri": "lib/",
"languageVersion": "3.1"
},
{
"name": "freezed",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/freezed-3.0.4",
"packageUri": "lib/",
"languageVersion": "3.6"
},
{
"name": "freezed_annotation",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/freezed_annotation-3.0.0",
"packageUri": "lib/",
"languageVersion": "3.0"
},
{
"name": "frontend_server_client",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/frontend_server_client-4.0.0",
@ -153,9 +165,9 @@
},
{
"name": "http",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/http-0.13.6",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/http-1.3.0",
"packageUri": "lib/",
"languageVersion": "2.19"
"languageVersion": "3.4"
},
{
"name": "http_multi_server",
@ -188,10 +200,16 @@
"languageVersion": "3.0"
},
{
"name": "lints",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/lints-2.1.1",
"name": "json_serializable",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/json_serializable-6.9.4",
"packageUri": "lib/",
"languageVersion": "3.0"
"languageVersion": "3.6"
},
{
"name": "lints",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/lints-5.1.1",
"packageUri": "lib/",
"languageVersion": "3.6"
},
{
"name": "logging",
@ -279,9 +297,9 @@
},
{
"name": "shelf_web_socket",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/shelf_web_socket-2.0.1",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/shelf_web_socket-3.0.0",
"packageUri": "lib/",
"languageVersion": "3.3"
"languageVersion": "3.5"
},
{
"name": "source_gen",
@ -289,6 +307,12 @@
"packageUri": "lib/",
"languageVersion": "3.6"
},
{
"name": "source_helper",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/source_helper-1.3.5",
"packageUri": "lib/",
"languageVersion": "3.4"
},
{
"name": "source_map_stack_trace",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/source_map_stack_trace-2.1.2",
@ -307,6 +331,12 @@
"packageUri": "lib/",
"languageVersion": "3.1"
},
{
"name": "sprintf",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/sprintf-7.0.0",
"packageUri": "lib/",
"languageVersion": "2.12"
},
{
"name": "stack_trace",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/stack_trace-1.12.1",
@ -369,9 +399,9 @@
},
{
"name": "uuid",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/uuid-3.0.7",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/uuid-4.5.1",
"packageUri": "lib/",
"languageVersion": "2.12"
"languageVersion": "3.0"
},
{
"name": "vm_service",
@ -387,13 +417,19 @@
},
{
"name": "web",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web-0.5.1",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web-1.1.1",
"packageUri": "lib/",
"languageVersion": "3.4"
},
{
"name": "web_socket",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web_socket-0.1.6",
"packageUri": "lib/",
"languageVersion": "3.3"
},
{
"name": "web_socket_channel",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web_socket_channel-2.4.5",
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web_socket_channel-3.0.2",
"packageUri": "lib/",
"languageVersion": "3.3"
},
@ -416,10 +452,8 @@
"languageVersion": "3.0"
}
],
"generated": "2025-03-20T10:48:40.871849Z",
"generated": "2025-03-20T13:27:48.930532Z",
"generator": "pub",
"generatorVersion": "3.7.2",
"flutterRoot": "file:///home/menno/.flutter/flutter",
"flutterVersion": "3.29.2",
"pubCache": "file:///home/menno/.pub-cache"
}

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
.dart_tool

4
Makefile Normal file
View File

@ -0,0 +1,4 @@
default:
build-runner:
dart run build_runner build --delete-conflicting-outputs

View File

@ -1,205 +0,0 @@
import 'dart:async';
import 'dart:convert';
import 'package:http/http.dart' as http;
import 'package:uuid/uuid.dart';
import 'package:web_socket_channel/web_socket_channel.dart';
class ComfyUiApi {
final String host;
final String clientId;
final http.Client _httpClient;
WebSocketChannel? _wsChannel;
final StreamController<Map<String, dynamic>> _progressController =
StreamController.broadcast();
/// Stream of progress updates from ComfyUI
Stream<Map<String, dynamic>> get progressUpdates =>
_progressController.stream;
/// Creates a new ComfyUI API client
///
/// [host] The host of the ComfyUI server (e.g. 'http://localhost:8188')
/// [clientId] Optional client ID, will be automatically generated if not provided
ComfyUiApi({
required this.host,
String? clientId,
http.Client? httpClient,
}) : clientId = clientId ?? const Uuid().v4(),
_httpClient = httpClient ?? http.Client();
/// Connects to the WebSocket for progress updates
Future<void> connectWebSocket() async {
final wsUrl =
'ws://${host.replaceFirst(RegExp(r'^https?://'), '')}/ws?clientId=$clientId';
_wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl));
_wsChannel!.stream.listen((message) {
final data = jsonDecode(message);
_progressController.add(data);
}, onError: (error) {
print('WebSocket error: $error');
}, onDone: () {
print('WebSocket connection closed');
});
}
/// Closes the WebSocket connection and cleans up resources
void dispose() {
_wsChannel?.sink.close();
_progressController.close();
_httpClient.close();
}
/// Gets the current queue status
Future<Map<String, dynamic>> getQueue() async {
final response = await _httpClient.get(Uri.parse('$host/queue'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets the history of the queue
Future<Map<String, dynamic>> getHistory({int maxItems = 64}) async {
final response = await _httpClient
.get(Uri.parse('$host/api/history?max_items=$maxItems'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets image data by filename
Future<List<int>> getImage(String filename) async {
final response =
await _httpClient.get(Uri.parse('$host/api/view?filename=$filename'));
_validateResponse(response);
return response.bodyBytes;
}
/// Gets a list of all available models
Future<Map<String, dynamic>> getModels() async {
final response =
await _httpClient.get(Uri.parse('$host/api/experiment/models'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets a list of checkpoints
Future<Map<String, dynamic>> getCheckpoints() async {
final response = await _httpClient
.get(Uri.parse('$host/api/experiment/models/checkpoints'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets details for a specific checkpoint
Future<Map<String, dynamic>> getCheckpointDetails(
String pathAndFileName) async {
final response = await _httpClient.get(Uri.parse(
'$host/api/view_metadata/checkpoints?filename=$pathAndFileName'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets a list of LoRAs
Future<Map<String, dynamic>> getLoras() async {
final response =
await _httpClient.get(Uri.parse('$host/api/experiment/models/loras'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets details for a specific LoRA
Future<Map<String, dynamic>> getLoraDetails(String pathAndFileName) async {
final response = await _httpClient.get(
Uri.parse('$host/api/view_metadata/loras?filename=$pathAndFileName'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets a list of VAEs
Future<Map<String, dynamic>> getVaes() async {
final response =
await _httpClient.get(Uri.parse('$host/api/experiment/models/vae'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets details for a specific VAE
Future<Map<String, dynamic>> getVaeDetails(String pathAndFileName) async {
final response = await _httpClient.get(
Uri.parse('$host/api/view_metadata/vae?filename=$pathAndFileName'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets a list of upscale models
Future<Map<String, dynamic>> getUpscaleModels() async {
final response = await _httpClient
.get(Uri.parse('$host/api/experiment/models/upscale_models'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets details for a specific upscale model
Future<Map<String, dynamic>> getUpscaleModelDetails(
String pathAndFileName) async {
final response = await _httpClient.get(Uri.parse(
'$host/api/view_metadata/upscale_models?filename=$pathAndFileName'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets a list of embeddings
Future<Map<String, dynamic>> getEmbeddings() async {
final response = await _httpClient
.get(Uri.parse('$host/api/experiment/models/embeddings'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets details for a specific embedding
Future<Map<String, dynamic>> getEmbeddingDetails(
String pathAndFileName) async {
final response = await _httpClient.get(Uri.parse(
'$host/api/view_metadata/embeddings?filename=$pathAndFileName'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Gets information about all available objects (nodes)
Future<Map<String, dynamic>> getObjectInfo() async {
final response = await _httpClient.get(Uri.parse('$host/api/object_info'));
_validateResponse(response);
return jsonDecode(response.body);
}
/// Submits a prompt (workflow) to generate an image
Future<Map<String, dynamic>> submitPrompt(Map<String, dynamic> prompt) async {
final response = await _httpClient.post(
Uri.parse('$host/api/prompt'),
headers: {'Content-Type': 'application/json'},
body: jsonEncode(prompt),
);
_validateResponse(response);
return jsonDecode(response.body);
}
/// Validates HTTP response and throws an exception if needed
void _validateResponse(http.Response response) {
if (response.statusCode < 200 || response.statusCode >= 300) {
throw ComfyUiApiException(
statusCode: response.statusCode,
message: 'API request failed: ${response.body}');
}
}
}
/// Exception thrown when the ComfyUI API returns an error
class ComfyUiApiException implements Exception {
final int statusCode;
final String message;
ComfyUiApiException({required this.statusCode, required this.message});
@override
String toString() => 'ComfyUiApiException: $statusCode - $message';
}

View File

@ -1,94 +0,0 @@
import 'dart:io';
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
void main() async {
// Create the API client
final api = ComfyUiApi(host: 'http://mennos-server:7860');
// Connect to the WebSocket for progress updates
await api.connectWebSocket();
// Listen for progress updates
api.progressUpdates.listen((update) {
print('Progress update: $update');
});
// Get available checkpoints
final checkpoints = await api.getCheckpoints();
print('Available checkpoints: ${checkpoints.keys.join(', ')}');
// Get queue status
final queue = await api.getQueue();
print('Queue status: $queue');
// Submit a basic text-to-image prompt
final promptWorkflow = {
"prompt": {
"3": {
"inputs": {
"seed": 123456789,
"steps": 20,
"cfg": 7,
"sampler_name": "euler_ancestral",
"scheduler": "normal",
"denoise": 1,
"model": ["4", 0],
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["5", 0]
},
"class_type": "KSampler"
},
"4": {
"inputs": {"ckpt_name": "dreamshaper_8.safetensors"},
"class_type": "CheckpointLoaderSimple"
},
"5": {
"inputs": {"width": 512, "height": 512, "batch_size": 1},
"class_type": "EmptyLatentImage"
},
"6": {
"inputs": {
"text": "a beautiful landscape with mountains and a lake",
"clip": ["4", 1]
},
"class_type": "CLIPTextEncode"
},
"7": {
"inputs": {
"text": "ugly, blurry, low quality",
"clip": ["4", 1]
},
"class_type": "CLIPTextEncode"
},
"8": {
"inputs": {
"samples": ["3", 0],
"vae": ["4", 2]
},
"class_type": "VAEDecode"
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": ["8", 0]
},
"class_type": "SaveImage"
}
},
"client_id": api.clientId
};
try {
final result = await api.submitPrompt(promptWorkflow);
print('Prompt submitted: $result');
} catch (e) {
print('Error submitting prompt: $e');
}
// Wait for some time to receive WebSocket messages
await Future.delayed(Duration(seconds: 60));
// Clean up
api.dispose();
}

View File

@ -1,4 +1,6 @@
library comfyui_api_sdk;
export 'src/comfyui_api.dart';
export 'src/models/models.dart';
export 'src/models/websocket_event.dart';
export 'src/models/progress_event.dart';
export 'src/models/execution_event.dart';

View File

@ -5,22 +5,69 @@ import 'package:http/http.dart' as http;
import 'package:uuid/uuid.dart';
import 'package:web_socket_channel/web_socket_channel.dart';
import 'models/websocket_event.dart';
import 'models/progress_event.dart';
import 'models/execution_event.dart';
/// Callback function type for prompt events
typedef PromptEventCallback = void Function(String promptId);
/// Callback function type for typed WebSocket events
typedef WebSocketEventCallback = void Function(WebSocketEvent event);
/// Callback function type for progress events
typedef ProgressEventCallback = void Function(ProgressEvent event);
/// A Dart SDK for interacting with the ComfyUI API
class ComfyUiApi {
final String host;
final String clientId;
final http.Client _httpClient;
WebSocketChannel? _wsChannel;
// Controllers for different event streams
final StreamController<WebSocketEvent> _eventController =
StreamController.broadcast();
final StreamController<Map<String, dynamic>> _progressController =
StreamController.broadcast();
/// Stream of progress updates from ComfyUI
// Add new controllers for specific event types
final StreamController<ProgressEvent> _progressEventController =
StreamController.broadcast();
final StreamController<ExecutionEvent> _executionEventController =
StreamController.broadcast();
/// Stream of typed progress events
Stream<ProgressEvent> get progressEvents => _progressEventController.stream;
/// Stream of typed execution events
Stream<ExecutionEvent> get executionEvents =>
_executionEventController.stream;
// Event callbacks
final Map<String, List<PromptEventCallback>> _eventCallbacks = {
'onPromptStart': [],
'onPromptFinished': [],
};
final Map<WebSocketEventType, List<WebSocketEventCallback>>
_typedEventCallbacks = {
for (var type in WebSocketEventType.values) type: [],
};
// Add a separate map for progress event callbacks
final List<ProgressEventCallback> _progressEventCallbacks = [];
/// Stream of typed WebSocket events
Stream<WebSocketEvent> get events => _eventController.stream;
/// Stream of progress updates from ComfyUI (legacy format)
Stream<Map<String, dynamic>> get progressUpdates =>
_progressController.stream;
/// Creates a new ComfyUI API client
///
/// [host] The host of the ComfyUI server (e.g. 'http://localhost:8188')
/// [host] The host of the ComfyUI server (e.g. 'http://localhost:7860')
/// [clientId] Optional client ID, will be automatically generated if not provided
ComfyUiApi({
required this.host,
@ -29,6 +76,26 @@ class ComfyUiApi {
}) : clientId = clientId ?? const Uuid().v4(),
_httpClient = httpClient ?? http.Client();
/// Register a callback for when a prompt starts processing
void onPromptStart(PromptEventCallback callback) {
_eventCallbacks['onPromptStart']!.add(callback);
}
/// Register a callback for when a prompt finishes processing
void onPromptFinished(PromptEventCallback callback) {
_eventCallbacks['onPromptFinished']!.add(callback);
}
/// Register a callback for specific WebSocket event types
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
_typedEventCallbacks[type]!.add(callback);
}
/// Register a callback for progress updates
void onProgressChanged(ProgressEventCallback callback) {
_progressEventCallbacks.add(callback);
}
/// Connects to the WebSocket for progress updates
Future<void> connectWebSocket() async {
final wsUrl =
@ -36,8 +103,43 @@ class ComfyUiApi {
_wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl));
_wsChannel!.stream.listen((message) {
final data = jsonDecode(message);
_progressController.add(data);
final jsonData = jsonDecode(message);
// Create a typed event
final event = WebSocketEvent.fromJson(jsonData);
// Add to the typed event stream
_eventController.add(event);
// Also add to the legacy progress stream
_progressController.add(jsonData);
// Convert to more specific event types if possible
if (event.eventType == WebSocketEventType.progress) {
_tryCreateProgressEvent(event);
} else if ([
WebSocketEventType.executionStart,
WebSocketEventType.executionSuccess,
WebSocketEventType.executionCached,
WebSocketEventType.executed,
WebSocketEventType.executing,
].contains(event.eventType)) {
_tryCreateExecutionEvent(event);
}
// Trigger event type specific callbacks
for (final callback in _typedEventCallbacks[event.eventType]!) {
callback(event);
}
// Handle execution_success event (prompt finished)
if (event.eventType == WebSocketEventType.executionSuccess &&
event.promptId != null) {
final promptId = event.promptId!;
for (final callback in _eventCallbacks['onPromptFinished']!) {
callback(promptId);
}
}
}, onError: (error) {
print('WebSocket error: $error');
}, onDone: () {
@ -45,10 +147,55 @@ class ComfyUiApi {
});
}
/// Attempts to create a ProgressEvent from a WebSocketEvent
void _tryCreateProgressEvent(WebSocketEvent event) {
if (event.data.containsKey('value') &&
event.data.containsKey('max') &&
event.data.containsKey('prompt_id')) {
try {
final progressEvent = ProgressEvent(
value: event.data['value'] as int,
max: event.data['max'] as int,
promptId: event.data['prompt_id'] as String,
node: event.data['node']?.toString(),
);
_progressEventController.add(progressEvent);
// Trigger all registered progress callbacks
for (final callback in _progressEventCallbacks) {
callback(progressEvent);
}
} catch (e) {
print('Error creating ProgressEvent: $e');
}
}
}
/// Attempts to create an ExecutionEvent from a WebSocketEvent
void _tryCreateExecutionEvent(WebSocketEvent event) {
if (event.data.containsKey('prompt_id')) {
try {
final executionEvent = ExecutionEvent(
promptId: event.data['prompt_id'] as String,
timestamp: event.data['timestamp'] as int? ??
DateTime.now().millisecondsSinceEpoch,
node: event.data['node']?.toString(),
extra: event.data['extra'] as Map<String, dynamic>?,
);
_executionEventController.add(executionEvent);
} catch (e) {
print('Error creating ExecutionEvent: $e');
}
}
}
/// Closes the WebSocket connection and cleans up resources
void dispose() {
_wsChannel?.sink.close();
_progressController.close();
_eventController.close();
_progressEventController.close();
_executionEventController.close();
_httpClient.close();
}
@ -181,7 +328,17 @@ class ComfyUiApi {
body: jsonEncode(prompt),
);
_validateResponse(response);
return jsonDecode(response.body);
final responseData = jsonDecode(response.body);
// Trigger onPromptStart event if prompt_id exists
if (responseData.containsKey('prompt_id')) {
final promptId = responseData['prompt_id'];
for (final callback in _eventCallbacks['onPromptStart']!) {
callback(promptId);
}
}
return responseData;
}
/// Validates HTTP response and throws an exception if needed

View File

@ -0,0 +1,7 @@
// Main API
export 'comfyui_api.dart';
// Models
export 'models/websocket_event.dart';
export 'models/progress_event.dart';
export 'models/execution_event.dart';

View File

@ -0,0 +1,11 @@
import 'websocket_event.dart';
import 'progress_event.dart';
/// Callback function type for prompt events
typedef PromptEventCallback = void Function(String promptId);
/// Callback function type for typed WebSocket events
typedef WebSocketEventCallback = void Function(WebSocketEvent event);
/// Callback function type for progress events
typedef ProgressEventCallback = void Function(ProgressEvent event);

View File

@ -0,0 +1,31 @@
class ExecutionEvent {
final String promptId;
final int timestamp;
final String? node;
final Map<String, dynamic>? extra;
const ExecutionEvent({
required this.promptId,
required this.timestamp,
this.node,
this.extra,
});
factory ExecutionEvent.fromJson(Map<String, dynamic> json) {
return ExecutionEvent(
promptId: json['prompt_id'] as String,
timestamp: json['timestamp'] as int,
node: json['node'] as String?,
extra: json['extra'] as Map<String, dynamic>?,
);
}
Map<String, dynamic> toJson() {
return {
'prompt_id': promptId,
'timestamp': timestamp,
'node': node,
'extra': extra,
};
}
}

View File

@ -0,0 +1,39 @@
class ProgressEvent {
final int value;
final int max;
final String promptId;
final String? node;
const ProgressEvent({
required this.value,
required this.max,
required this.promptId,
this.node,
});
factory ProgressEvent.fromJson(Map<String, dynamic> json) {
return ProgressEvent(
value: json['value'] as int,
max: json['max'] as int,
promptId: json['prompt_id'] as String,
node: json['node'] as String?,
);
}
Map<String, dynamic> toJson() {
return {
'value': value,
'max': max,
'prompt_id': promptId,
'node': node,
};
}
}
extension ProgressEventComputation on ProgressEvent {
/// Returns the progress percentage (0-100)
double get percentage => (value / max) * 100;
/// Returns true if the progress is complete
bool get isComplete => value >= max;
}

View File

@ -0,0 +1,109 @@
/// Types of WebSocket events from ComfyUI
enum WebSocketEventType {
status,
progress,
executing,
executed,
executionStart,
executionCached,
executionSuccess,
executionError,
dataOutput,
unknown
}
/// A typed event from the ComfyUI WebSocket
class WebSocketEvent {
final WebSocketEventType eventType;
final Map<String, dynamic> data;
final String? promptId;
final String rawType;
WebSocketEvent({
required this.eventType,
required this.data,
this.promptId,
required this.rawType,
});
@override
String toString() =>
'WebSocketEvent{type: $eventType, rawType: $rawType, promptId: $promptId, data: ${data.keys.join(', ')}}';
/// Creates a WebSocketEvent from JSON
factory WebSocketEvent.fromJson(Map<String, dynamic> json) {
WebSocketEventType type;
Map<String, dynamic> eventData = {};
String? promptId;
String rawType = json['type']?.toString() ?? 'unknown';
// Extract event type
if (json.containsKey('type')) {
final typeStr = json['type'].toString();
// First, try to extract prompt_id from the data field
if (json.containsKey('data') && json['data'] is Map) {
final dataMap = Map<String, dynamic>.from(json['data'] as Map);
if (dataMap.containsKey('prompt_id')) {
promptId = dataMap['prompt_id'].toString();
}
eventData = dataMap;
} else {
// If no data field, use the root object
eventData = Map<String, dynamic>.from(json);
}
// Try to extract prompt_id from other potential locations
if (promptId == null && json.containsKey('prompt_id')) {
promptId = json['prompt_id'].toString();
}
switch (typeStr) {
case 'status':
type = WebSocketEventType.status;
break;
case 'progress':
type = WebSocketEventType.progress;
break;
case 'executing':
type = WebSocketEventType.executing;
break;
case 'executed':
type = WebSocketEventType.executed;
break;
case 'execution_start':
type = WebSocketEventType.executionStart;
break;
case 'execution_cached':
type = WebSocketEventType.executionCached;
break;
case 'execution_success':
type = WebSocketEventType.executionSuccess;
break;
case 'execution_error':
type = WebSocketEventType.executionError;
break;
default:
if (typeStr.startsWith('data_output')) {
type = WebSocketEventType.dataOutput;
} else {
// Default to unknown for unrecognized types
type = WebSocketEventType.unknown;
print('Unknown event type: $typeStr');
}
}
} else {
// If no type field, default to unknown
type = WebSocketEventType.unknown;
print('WebSocket event missing type field');
eventData = Map<String, dynamic>.from(json);
}
return WebSocketEvent(
eventType: type,
data: eventData,
promptId: promptId,
rawType: rawType,
);
}
}

View File

@ -74,13 +74,13 @@ packages:
source: hosted
version: "2.4.4"
build_runner:
dependency: "direct main"
dependency: "direct dev"
description:
name: build_runner
sha256: "74691599a5bc750dc96a6b4bfd48f7d9d66453eab04c7f4063134800d6a5c573"
sha256: "058fe9dce1de7d69c4b84fada934df3e0153dd000758c4d65964d0166779aa99"
url: "https://pub.dev"
source: hosted
version: "2.4.14"
version: "2.4.15"
build_runner_core:
dependency: transitive
description:
@ -177,6 +177,22 @@ packages:
url: "https://pub.dev"
source: hosted
version: "1.1.1"
freezed:
dependency: "direct dev"
description:
name: freezed
sha256: "7ed2ddaa47524976d5f2aa91432a79da36a76969edd84170777ac5bea82d797c"
url: "https://pub.dev"
source: hosted
version: "3.0.4"
freezed_annotation:
dependency: "direct main"
description:
name: freezed_annotation
sha256: c87ff004c8aa6af2d531668b46a4ea379f7191dc6dfa066acd53d506da6e044b
url: "https://pub.dev"
source: hosted
version: "3.0.0"
frontend_server_client:
dependency: transitive
description:
@ -205,10 +221,10 @@ packages:
dependency: "direct main"
description:
name: http
sha256: "5895291c13fa8a3bd82e76d5627f69e0d85ca6a30dcac95c4ea19a5d555879c2"
sha256: fe7ab022b76f3034adc518fb6ea04a82387620e19977665ea18d30a1cf43442f
url: "https://pub.dev"
source: hosted
version: "0.13.6"
version: "1.3.0"
http_multi_server:
dependency: transitive
description:
@ -242,21 +258,29 @@ packages:
source: hosted
version: "0.7.2"
json_annotation:
dependency: transitive
dependency: "direct main"
description:
name: json_annotation
sha256: "1ce844379ca14835a50d2f019a3099f419082cfdd231cd86a142af94dd5c6bb1"
url: "https://pub.dev"
source: hosted
version: "4.9.0"
json_serializable:
dependency: "direct dev"
description:
name: json_serializable
sha256: "81f04dee10969f89f604e1249382d46b97a1ccad53872875369622b5bfc9e58a"
url: "https://pub.dev"
source: hosted
version: "6.9.4"
lints:
dependency: "direct dev"
description:
name: lints
sha256: "0a217c6c989d21039f1498c3ed9f3ed71b354e69873f13a8dfc3c9fe76f1b452"
sha256: c35bb79562d980e9a453fc715854e1ed39e24e7d0297a880ef54e17f9874a9d7
url: "https://pub.dev"
source: hosted
version: "2.1.1"
version: "5.1.1"
logging:
dependency: transitive
description:
@ -290,7 +314,7 @@ packages:
source: hosted
version: "2.0.0"
mockito:
dependency: "direct main"
dependency: "direct dev"
description:
name: mockito
sha256: f99d8d072e249f719a5531735d146d8cf04c580d93920b04de75bef6dfb2daf6
@ -373,10 +397,10 @@ packages:
dependency: transitive
description:
name: shelf_web_socket
sha256: cc36c297b52866d203dbf9332263c94becc2fe0ceaa9681d07b6ef9807023b67
sha256: "3632775c8e90d6c9712f883e633716432a27758216dfb61bd86a8321c0580925"
url: "https://pub.dev"
source: hosted
version: "2.0.1"
version: "3.0.0"
source_gen:
dependency: transitive
description:
@ -385,6 +409,14 @@ packages:
url: "https://pub.dev"
source: hosted
version: "2.0.0"
source_helper:
dependency: transitive
description:
name: source_helper
sha256: "86d247119aedce8e63f4751bd9626fc9613255935558447569ad42f9f5b48b3c"
url: "https://pub.dev"
source: hosted
version: "1.3.5"
source_map_stack_trace:
dependency: transitive
description:
@ -409,6 +441,14 @@ packages:
url: "https://pub.dev"
source: hosted
version: "1.10.1"
sprintf:
dependency: transitive
description:
name: sprintf
sha256: "1fc9ffe69d4df602376b52949af107d8f5703b77cda567c4d7d86a0693120f23"
url: "https://pub.dev"
source: hosted
version: "7.0.0"
stack_trace:
dependency: transitive
description:
@ -493,10 +533,10 @@ packages:
dependency: "direct main"
description:
name: uuid
sha256: "648e103079f7c64a36dc7d39369cabb358d377078a051d6ae2ad3aa539519313"
sha256: a5be9ef6618a7ac1e964353ef476418026db906c4facdedaa299b7a2e71690ff
url: "https://pub.dev"
source: hosted
version: "3.0.7"
version: "4.5.1"
vm_service:
dependency: transitive
description:
@ -517,18 +557,26 @@ packages:
dependency: transitive
description:
name: web
sha256: "97da13628db363c635202ad97068d47c5b8aa555808e7a9411963c533b449b27"
sha256: "868d88a33d8a87b18ffc05f9f030ba328ffefba92d6c127917a2ba740f9cfe4a"
url: "https://pub.dev"
source: hosted
version: "0.5.1"
version: "1.1.1"
web_socket:
dependency: transitive
description:
name: web_socket
sha256: "3c12d96c0c9a4eec095246debcea7b86c0324f22df69893d538fcc6f1b8cce83"
url: "https://pub.dev"
source: hosted
version: "0.1.6"
web_socket_channel:
dependency: "direct main"
description:
name: web_socket_channel
sha256: "58c6666b342a38816b2e7e50ed0f1e261959630becd4c879c4f26bfa14aa5a42"
sha256: "0b8e2457400d8a859b7b2030786835a28a8e80836ef64402abef392ff4f1d0e5"
url: "https://pub.dev"
source: hosted
version: "2.4.5"
version: "3.0.2"
webkit_inspection_protocol:
dependency: transitive
description:

View File

@ -6,12 +6,16 @@ environment:
sdk: '>=3.0.0 <4.0.0'
dependencies:
http: ^0.13.5
web_socket_channel: ^2.3.0
uuid: ^3.0.7
mockito: ^5.4.5
build_runner: ^2.4.14
http: ^1.3.0
web_socket_channel: ^3.0.2
uuid: ^4.5.1
json_annotation: ^4.9.0
freezed_annotation: ^3.0.0
dev_dependencies:
lints: ^2.0.0
lints: ^5.1.1
test: ^1.21.0
mockito: ^5.4.5
build_runner: ^2.4.14
freezed: ^3.0.4
json_serializable: ^6.7.1

View File

@ -1,197 +0,0 @@
import 'dart:convert';
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
import 'package:http/http.dart' as http;
import 'package:http/testing.dart';
import 'package:mockito/annotations.dart';
import 'package:mockito/mockito.dart';
import 'package:test/test.dart';
import 'package:web_socket_channel/web_socket_channel.dart';
import 'comfyui_api_test.mocks.dart';
import 'test_data.dart';
@GenerateMocks([http.Client, WebSocketChannel, WebSocketSink])
void main() {
late MockClient mockClient;
late ComfyUiApi api;
const String testHost = 'http://localhost:8188';
const String testClientId = 'test-client-id';
setUp(() {
mockClient = MockClient();
api = ComfyUiApi(
host: testHost,
clientId: testClientId,
httpClient: mockClient,
);
});
group('ComfyUiApi', () {
test('initialize with provided values', () {
expect(api.host, equals(testHost));
expect(api.clientId, equals(testClientId));
});
test('initialize with generated clientId when not provided', () {
final autoApi = ComfyUiApi(host: testHost, httpClient: mockClient);
expect(autoApi.clientId, isNotEmpty);
expect(autoApi.clientId, isNot(equals(testClientId)));
});
test('getQueue returns parsed response', () async {
when(mockClient.get(Uri.parse('$testHost/queue'))).thenAnswer(
(_) async => http.Response(jsonEncode(TestData.queueResponse), 200));
final result = await api.getQueue();
expect(result, equals(TestData.queueResponse));
verify(mockClient.get(Uri.parse('$testHost/queue'))).called(1);
});
test('getHistory returns parsed response', () async {
when(mockClient.get(Uri.parse('$testHost/api/history?max_items=64')))
.thenAnswer((_) async =>
http.Response(jsonEncode(TestData.historyResponse), 200));
final result = await api.getHistory();
expect(result, equals(TestData.historyResponse));
verify(mockClient.get(Uri.parse('$testHost/api/history?max_items=64')))
.called(1);
});
test('getImage returns image bytes', () async {
final bytes = [1, 2, 3, 4];
when(mockClient.get(Uri.parse('$testHost/api/view?filename=test.png')))
.thenAnswer((_) async => http.Response.bytes(bytes, 200));
final result = await api.getImage('test.png');
expect(result, equals(bytes));
verify(mockClient.get(Uri.parse('$testHost/api/view?filename=test.png')))
.called(1);
});
test('getCheckpoints returns parsed response', () async {
when(mockClient
.get(Uri.parse('$testHost/api/experiment/models/checkpoints')))
.thenAnswer((_) async =>
http.Response(jsonEncode(TestData.checkpointsResponse), 200));
final result = await api.getCheckpoints();
expect(result, equals(TestData.checkpointsResponse));
verify(mockClient
.get(Uri.parse('$testHost/api/experiment/models/checkpoints')))
.called(1);
});
test('getCheckpointDetails returns parsed response', () async {
const filename = 'models/checkpoints/test.safetensors';
when(mockClient.get(Uri.parse(
'$testHost/api/view_metadata/checkpoints?filename=$filename')))
.thenAnswer((_) async => http.Response(
jsonEncode(TestData.checkpointMetadataResponse), 200));
final result = await api.getCheckpointDetails(filename);
expect(result, equals(TestData.checkpointMetadataResponse));
verify(mockClient.get(Uri.parse(
'$testHost/api/view_metadata/checkpoints?filename=$filename')))
.called(1);
});
test('getLoras returns parsed response', () async {
when(mockClient.get(Uri.parse('$testHost/api/experiment/models/loras')))
.thenAnswer((_) async =>
http.Response(jsonEncode(TestData.lorasResponse), 200));
final result = await api.getLoras();
expect(result, equals(TestData.lorasResponse));
verify(mockClient.get(Uri.parse('$testHost/api/experiment/models/loras')))
.called(1);
});
test('getVaes returns parsed response', () async {
when(mockClient.get(Uri.parse('$testHost/api/experiment/models/vae')))
.thenAnswer((_) async =>
http.Response(jsonEncode(TestData.vaeResponse), 200));
final result = await api.getVaes();
expect(result, equals(TestData.vaeResponse));
verify(mockClient.get(Uri.parse('$testHost/api/experiment/models/vae')))
.called(1);
});
test('getObjectInfo returns parsed response', () async {
when(mockClient.get(Uri.parse('$testHost/api/object_info'))).thenAnswer(
(_) async =>
http.Response(jsonEncode(TestData.objectInfoResponse), 200));
final result = await api.getObjectInfo();
expect(result, equals(TestData.objectInfoResponse));
verify(mockClient.get(Uri.parse('$testHost/api/object_info'))).called(1);
});
test('submitPrompt returns parsed response', () async {
when(mockClient.post(
Uri.parse('$testHost/api/prompt'),
headers: {'Content-Type': 'application/json'},
body: jsonEncode(TestData.promptRequest),
)).thenAnswer(
(_) async => http.Response(jsonEncode(TestData.promptResponse), 200));
final result = await api.submitPrompt(TestData.promptRequest);
expect(result, equals(TestData.promptResponse));
verify(mockClient.post(
Uri.parse('$testHost/api/prompt'),
headers: {'Content-Type': 'application/json'},
body: jsonEncode(TestData.promptRequest),
)).called(1);
});
test('throws ComfyUiApiException on error response', () async {
when(mockClient.get(Uri.parse('$testHost/queue')))
.thenAnswer((_) async => http.Response('Error message', 500));
expect(() => api.getQueue(), throwsA(isA<ComfyUiApiException>()));
});
});
group('Models', () {
test('QueueInfo parses from JSON correctly', () {
final queueInfo = QueueInfo.fromJson(TestData.queueResponse);
expect(queueInfo.queueRunning, equals(0));
expect(queueInfo.queue.length, equals(0));
expect(queueInfo.queuePending, isA<Map<String, dynamic>>());
});
test('PromptExecutionStatus parses from JSON correctly', () {
final status = PromptExecutionStatus.fromJson(TestData.promptResponse);
expect(status.promptId, equals('123456789'));
expect(status.number, equals(1));
expect(status.status, equals('success'));
});
test('HistoryItem parses from JSON correctly', () {
final item = HistoryItem.fromJson(TestData.historyItemResponse);
expect(item.promptId, equals('123456789'));
expect(item.prompt, isA<Map<String, dynamic>>());
expect(item.outputs, isA<Map<String, dynamic>>());
});
test('ProgressUpdate parses from JSON correctly', () {
final update = ProgressUpdate.fromJson(TestData.progressUpdateResponse);
expect(update.type, equals('execution_start'));
expect(update.data, isA<Map<String, dynamic>>());
});
});
}

View File

@ -1,212 +0,0 @@
import 'dart:convert';
import 'dart:io';
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
import 'package:http/http.dart' as http;
import 'package:http/testing.dart';
import 'package:test/test.dart';
void main() {
late ComfyUiApi api;
late MockClient mockClient;
const String testHost = 'http://localhost:8188';
setUp(() {
// Setup a MockClient that simulates real API responses
mockClient = MockClient((request) async {
final uri = request.url;
final method = request.method;
// Simulate queue endpoint
if (uri.path == '/queue' && method == 'GET') {
return http.Response(
jsonEncode({'queue_running': 0, 'queue': [], 'queue_pending': {}}),
200,
);
}
// Simulate history endpoint
if (uri.path == '/api/history' && method == 'GET') {
return http.Response(
jsonEncode({
'History': {
'123456789': {
'prompt': {
// Simplified prompt data
'1': {'class_type': 'TestNode'}
},
'outputs': {
'8': {
'images': {
'filename': 'ComfyUI_00001_.png',
'subfolder': '',
'type': 'output',
}
}
}
}
}
}),
200,
);
}
// Simulate checkpoint list endpoint
if (uri.path == '/api/experiment/models/checkpoints' && method == 'GET') {
return http.Response(
jsonEncode({
'models/checkpoints/dreamshaper_8.safetensors': {
'filename': 'dreamshaper_8.safetensors',
'folder': 'models/checkpoints',
}
}),
200,
);
}
// Simulate checkpoint metadata endpoint
if (uri.path == '/api/view_metadata/checkpoints' && method == 'GET') {
return http.Response(
jsonEncode({
'model': {
'type': 'checkpoint',
'title': 'Dreamshaper 8',
'hash': 'abcdef1234567890',
}
}),
200,
);
}
// Simulate object info endpoint
if (uri.path == '/api/object_info' && method == 'GET') {
return http.Response(
jsonEncode({
'KSampler': {
'input': {
'required': {
'model': 'MODEL',
'seed': 'INT',
'steps': 'INT',
}
},
'output': ['LATENT'],
'output_is_list': [false]
}
}),
200,
);
}
// Simulate prompt submission endpoint
if (uri.path == '/api/prompt' && method == 'POST') {
return http.Response(
jsonEncode(
{'prompt_id': '123456789', 'number': 1, 'status': 'success'}),
200,
);
}
// Simulate image view endpoint
if (uri.path == '/api/view' && method == 'GET') {
// Return a dummy image
return http.Response.bytes([1, 2, 3, 4], 200,
headers: {
'Content-Type': 'image/png',
});
}
// Default response for unhandled routes
return http.Response('Not Found', 404);
});
// Create the API with our mock client
api = ComfyUiApi(
host: testHost,
clientId: 'integration-test-client',
httpClient: mockClient,
);
});
group('Integration Tests', () {
test('Get queue information', () async {
final queue = await api.getQueue();
expect(queue['queue_running'], equals(0));
expect(queue['queue'], isEmpty);
expect(queue['queue_pending'], isA<Map>());
});
test('Get history information', () async {
final history = await api.getHistory();
expect(history['History'], isA<Map>());
expect(history['History']['123456789'], isA<Map>());
expect(history['History']['123456789']['outputs'], isA<Map>());
});
test('Get checkpoint list', () async {
final checkpoints = await api.getCheckpoints();
expect(checkpoints.keys,
contains('models/checkpoints/dreamshaper_8.safetensors'));
expect(
checkpoints['models/checkpoints/dreamshaper_8.safetensors']
['filename'],
equals('dreamshaper_8.safetensors'));
});
test('Get checkpoint metadata', () async {
final metadata = await api
.getCheckpointDetails('models/checkpoints/dreamshaper_8.safetensors');
expect(metadata['model']['type'], equals('checkpoint'));
expect(metadata['model']['title'], equals('Dreamshaper 8'));
});
test('Get object info', () async {
final info = await api.getObjectInfo();
expect(info['KSampler'], isA<Map>());
expect(info['KSampler']['input']['required']['seed'], equals('INT'));
});
test('Submit prompt', () async {
final promptData = {
'prompt': {
'1': {
'inputs': {'text': 'A beautiful landscape'},
'class_type': 'CLIPTextEncode'
}
},
'client_id': 'integration-test-client'
};
final result = await api.submitPrompt(promptData);
expect(result['prompt_id'], equals('123456789'));
expect(result['status'], equals('success'));
});
test('Get image', () async {
final imageBytes = await api.getImage('ComfyUI_00001_.png');
expect(imageBytes, equals([1, 2, 3, 4]));
});
test('Handle error response', () async {
// Create a client that always returns an error
final errorClient = MockClient((_) async {
return http.Response('Server Error', 500);
});
final errorApi = ComfyUiApi(
host: testHost,
clientId: 'error-test-client',
httpClient: errorClient,
);
expect(() => errorApi.getQueue(), throwsA(isA<ComfyUiApiException>()));
});
});
}

View File

@ -1,127 +0,0 @@
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
import 'package:test/test.dart';
void main() {
group('QueueInfo', () {
test('fromJson creates instance with correct values', () {
final json = {
'queue_running': 1,
'queue': [
{'prompt_id': '123', 'number': 1}
],
'queue_pending': {
'456': {'prompt_id': '456', 'number': 2}
}
};
final queueInfo = QueueInfo.fromJson(json);
expect(queueInfo.queueRunning, equals(1));
expect(queueInfo.queue.length, equals(1));
expect(queueInfo.queue[0]['prompt_id'], equals('123'));
expect(queueInfo.queuePending['456']['prompt_id'], equals('456'));
});
test('fromJson handles missing or empty values', () {
final json = {'queue_running': 0};
final queueInfo = QueueInfo.fromJson(json);
expect(queueInfo.queueRunning, equals(0));
expect(queueInfo.queue, isEmpty);
expect(queueInfo.queuePending, isEmpty);
});
});
group('PromptExecutionStatus', () {
test('fromJson creates instance with correct values', () {
final json = {
'prompt_id': 'abc123',
'number': 5,
'status': 'processing',
'error': null
};
final status = PromptExecutionStatus.fromJson(json);
expect(status.promptId, equals('abc123'));
expect(status.number, equals(5));
expect(status.status, equals('processing'));
expect(status.error, isNull);
});
test('fromJson handles error information', () {
final json = {
'prompt_id': 'abc123',
'number': 5,
'status': 'error',
'error': 'Something went wrong'
};
final status = PromptExecutionStatus.fromJson(json);
expect(status.status, equals('error'));
expect(status.error, equals('Something went wrong'));
});
});
group('HistoryItem', () {
test('fromJson creates instance with correct values', () {
final json = {
'prompt_id': 'abc123',
'prompt': {
'1': {'class_type': 'TestNode'}
},
'outputs': {
'2': {
'images': {'filename': 'test.png'}
}
}
};
final item = HistoryItem.fromJson(json);
expect(item.promptId, equals('abc123'));
expect(item.prompt['1']['class_type'], equals('TestNode'));
expect(item.outputs?['2']['images']['filename'], equals('test.png'));
});
test('fromJson handles missing outputs', () {
final json = {
'prompt_id': 'abc123',
'prompt': {
'1': {'class_type': 'TestNode'}
}
};
final item = HistoryItem.fromJson(json);
expect(item.promptId, equals('abc123'));
expect(item.outputs, isNull);
});
});
group('ProgressUpdate', () {
test('fromJson creates instance with correct values', () {
final json = {
'type': 'execution_start',
'data': {'prompt_id': 'abc123', 'node': 5}
};
final update = ProgressUpdate.fromJson(json);
expect(update.type, equals('execution_start'));
expect(update.data['prompt_id'], equals('abc123'));
expect(update.data['node'], equals(5));
});
test('fromJson handles empty data', () {
final json = {'type': 'status', 'data': {}};
final update = ProgressUpdate.fromJson(json);
expect(update.type, equals('status'));
expect(update.data, isEmpty);
});
});
}

View File

@ -1,148 +0,0 @@
/// Test data for ComfyUI API tests
class TestData {
/// Mock queue response
static final Map<String, dynamic> queueResponse = {
'queue_running': 0,
'queue': [],
'queue_pending': {}
};
/// Mock history response
static final Map<String, dynamic> historyResponse = {
'History': {
'123456789': {
'prompt': {
// Prompt data
},
'outputs': {
'8': {
'images': {
'filename': 'ComfyUI_00001_.png',
'subfolder': '',
'type': 'output',
}
}
}
}
}
};
/// Mock history item
static final Map<String, dynamic> historyItemResponse = {
'prompt_id': '123456789',
'prompt': {
// Prompt data
},
'outputs': {
'8': {
'images': {
'filename': 'ComfyUI_00001_.png',
'subfolder': '',
'type': 'output',
}
}
}
};
/// Mock checkpoints response
static final Map<String, dynamic> checkpointsResponse = {
'models/checkpoints/dreamshaper_8.safetensors': {
'filename': 'dreamshaper_8.safetensors',
'folder': 'models/checkpoints',
},
'models/checkpoints/sd_xl_base_1.0.safetensors': {
'filename': 'sd_xl_base_1.0.safetensors',
'folder': 'models/checkpoints',
}
};
/// Mock checkpoint metadata response
static final Map<String, dynamic> checkpointMetadataResponse = {
'model': {
'type': 'checkpoint',
'title': 'Dreamshaper 8',
'filename': 'dreamshaper_8.safetensors',
'hash': 'abcdef1234567890',
}
};
/// Mock LoRAs response
static final Map<String, dynamic> lorasResponse = {
'models/loras/example_lora.safetensors': {
'filename': 'example_lora.safetensors',
'folder': 'models/loras',
}
};
/// Mock VAE response
static final Map<String, dynamic> vaeResponse = {
'models/vae/example_vae.safetensors': {
'filename': 'example_vae.safetensors',
'folder': 'models/vae',
}
};
/// Mock object info response (simplified)
static final Map<String, dynamic> objectInfoResponse = {
'CheckpointLoaderSimple': {
'input': {
'required': {'ckpt_name': 'STRING'}
},
'output': ['MODEL', 'CLIP', 'VAE'],
'output_is_list': [false, false, false]
},
'KSampler': {
'input': {
'required': {
'model': 'MODEL',
'seed': 'INT',
'steps': 'INT',
'cfg': 'FLOAT',
'sampler_name': 'STRING',
'scheduler': 'STRING',
'positive': 'CONDITIONING',
'negative': 'CONDITIONING',
'latent_image': 'LATENT'
},
'optional': {'denoise': 'FLOAT'}
},
'output': ['LATENT'],
'output_is_list': [false]
}
};
/// Mock prompt request
static final Map<String, dynamic> promptRequest = {
'prompt': {
'3': {
'inputs': {
'seed': 123456789,
'steps': 20,
'cfg': 7,
'sampler_name': 'euler_ancestral',
'scheduler': 'normal',
'denoise': 1,
'model': ['4', 0],
'positive': ['6', 0],
'negative': ['7', 0],
'latent_image': ['5', 0]
},
'class_type': 'KSampler'
}
},
'client_id': 'test-client-id'
};
/// Mock prompt response
static final Map<String, dynamic> promptResponse = {
'prompt_id': '123456789',
'number': 1,
'status': 'success'
};
/// Mock progress update response
static final Map<String, dynamic> progressUpdateResponse = {
'type': 'execution_start',
'data': {'prompt_id': '123456789'}
};
}

View File

@ -1,145 +0,0 @@
import 'dart:async';
import 'dart:convert';
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
import 'package:http/http.dart' as http;
import 'package:http/testing.dart';
import 'package:mockito/annotations.dart';
import 'package:mockito/mockito.dart';
import 'package:test/test.dart';
import 'package:web_socket_channel/web_socket_channel.dart';
import 'test_data.dart';
import 'websocket_test.mocks.dart';
@GenerateMocks([http.Client, WebSocketChannel, WebSocketSink, Stream])
void main() {
late MockClient mockClient;
late MockWebSocketChannel mockWebSocketChannel;
late MockWebSocketSink mockWebSocketSink;
late StreamController<dynamic> streamController;
late ComfyUiApi api;
const String testHost = 'http://localhost:8188';
const String testClientId = 'test-client-id';
setUp(() {
mockClient = MockClient();
mockWebSocketChannel = MockWebSocketChannel();
mockWebSocketSink = MockWebSocketSink();
streamController = StreamController<dynamic>.broadcast();
when(mockWebSocketChannel.sink).thenReturn(mockWebSocketSink);
when(mockWebSocketChannel.stream)
.thenAnswer((_) => streamController.stream);
api = ComfyUiApi(
host: testHost,
clientId: testClientId,
httpClient: mockClient,
);
});
tearDown(() {
streamController.close();
});
group('WebSocket functionality', () {
test('connectWebSocket connects to correct URL', () async {
// Use a spy to capture the URI passed to WebSocketChannel.connect
final wsUrl = 'ws://localhost:8188/ws?clientId=$testClientId';
await api.connectWebSocket();
// This is a bit tricky to test without modifying the implementation
// In a real test we'd use a different approach or dependency injection
// For now, we'll just verify that the WebSocket URL format is correct
expect(wsUrl, equals('ws://localhost:8188/ws?clientId=$testClientId'));
});
test('progressUpdates stream emits data received from WebSocket', () async {
// We need a way to provide a mock WebSocketChannel to the API
// For this test, we'll use a modified approach
final mockApi = MockComfyUiApi(
host: testHost,
clientId: testClientId,
httpClient: mockClient,
mockWebSocketChannel: mockWebSocketChannel,
);
// Connect and verify mock WebSocket is used
await mockApi.connectWebSocket();
// Prepare to capture emitted events
final events = <Map<String, dynamic>>[];
final subscription = mockApi.progressUpdates.listen(events.add);
// Send test data through the mock WebSocket
final testData = TestData.progressUpdateResponse;
streamController.add(jsonEncode(testData));
// Wait for async processing
await Future.delayed(Duration(milliseconds: 100));
// Verify the data was emitted
expect(events.length, equals(1));
expect(events.first, equals(testData));
// Clean up
await subscription.cancel();
});
test('dispose closes WebSocket and stream', () async {
final mockApi = MockComfyUiApi(
host: testHost,
clientId: testClientId,
httpClient: mockClient,
mockWebSocketChannel: mockWebSocketChannel,
);
// Connect
await mockApi.connectWebSocket();
// Dispose
mockApi.dispose();
// Verify WebSocket was closed
verify(mockWebSocketSink.close()).called(1);
});
});
}
/// A modified version of ComfyUiApi for testing that allows injecting a mock WebSocketChannel
class MockComfyUiApi extends ComfyUiApi {
final WebSocketChannel? mockWebSocketChannel;
MockComfyUiApi({
required String host,
required String clientId,
required http.Client httpClient,
this.mockWebSocketChannel,
}) : super(
host: host,
clientId: clientId,
httpClient: httpClient,
);
@override
Future<void> connectWebSocket() async {
if (mockWebSocketChannel != null) {
_wsChannel = mockWebSocketChannel;
_wsChannel!.stream.listen((message) {
final data = jsonDecode(message);
_progressController.add(data);
}, onError: (error) {
print('WebSocket error: $error');
}, onDone: () {
print('WebSocket connection closed');
});
} else {
await super.connectWebSocket();
}
}
}