Add initial structure for comfyui_api_sdk with API models and event handling
This commit is contained in:
parent
1f9409ce0e
commit
0b2769310b
@ -57,7 +57,7 @@
|
||||
},
|
||||
{
|
||||
"name": "build_runner",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/build_runner-2.4.14",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/build_runner-2.4.15",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.6"
|
||||
},
|
||||
@ -133,6 +133,18 @@
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.1"
|
||||
},
|
||||
{
|
||||
"name": "freezed",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/freezed-3.0.4",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.6"
|
||||
},
|
||||
{
|
||||
"name": "freezed_annotation",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/freezed_annotation-3.0.0",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.0"
|
||||
},
|
||||
{
|
||||
"name": "frontend_server_client",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/frontend_server_client-4.0.0",
|
||||
@ -153,9 +165,9 @@
|
||||
},
|
||||
{
|
||||
"name": "http",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/http-0.13.6",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/http-1.3.0",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "2.19"
|
||||
"languageVersion": "3.4"
|
||||
},
|
||||
{
|
||||
"name": "http_multi_server",
|
||||
@ -188,10 +200,16 @@
|
||||
"languageVersion": "3.0"
|
||||
},
|
||||
{
|
||||
"name": "lints",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/lints-2.1.1",
|
||||
"name": "json_serializable",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/json_serializable-6.9.4",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.0"
|
||||
"languageVersion": "3.6"
|
||||
},
|
||||
{
|
||||
"name": "lints",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/lints-5.1.1",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.6"
|
||||
},
|
||||
{
|
||||
"name": "logging",
|
||||
@ -279,9 +297,9 @@
|
||||
},
|
||||
{
|
||||
"name": "shelf_web_socket",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/shelf_web_socket-2.0.1",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/shelf_web_socket-3.0.0",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.3"
|
||||
"languageVersion": "3.5"
|
||||
},
|
||||
{
|
||||
"name": "source_gen",
|
||||
@ -289,6 +307,12 @@
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.6"
|
||||
},
|
||||
{
|
||||
"name": "source_helper",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/source_helper-1.3.5",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.4"
|
||||
},
|
||||
{
|
||||
"name": "source_map_stack_trace",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/source_map_stack_trace-2.1.2",
|
||||
@ -307,6 +331,12 @@
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.1"
|
||||
},
|
||||
{
|
||||
"name": "sprintf",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/sprintf-7.0.0",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "2.12"
|
||||
},
|
||||
{
|
||||
"name": "stack_trace",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/stack_trace-1.12.1",
|
||||
@ -369,9 +399,9 @@
|
||||
},
|
||||
{
|
||||
"name": "uuid",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/uuid-3.0.7",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/uuid-4.5.1",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "2.12"
|
||||
"languageVersion": "3.0"
|
||||
},
|
||||
{
|
||||
"name": "vm_service",
|
||||
@ -387,13 +417,19 @@
|
||||
},
|
||||
{
|
||||
"name": "web",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web-0.5.1",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web-1.1.1",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.4"
|
||||
},
|
||||
{
|
||||
"name": "web_socket",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web_socket-0.1.6",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.3"
|
||||
},
|
||||
{
|
||||
"name": "web_socket_channel",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web_socket_channel-2.4.5",
|
||||
"rootUri": "file:///home/menno/.pub-cache/hosted/pub.dev/web_socket_channel-3.0.2",
|
||||
"packageUri": "lib/",
|
||||
"languageVersion": "3.3"
|
||||
},
|
||||
@ -416,10 +452,8 @@
|
||||
"languageVersion": "3.0"
|
||||
}
|
||||
],
|
||||
"generated": "2025-03-20T10:48:40.871849Z",
|
||||
"generated": "2025-03-20T13:27:48.930532Z",
|
||||
"generator": "pub",
|
||||
"generatorVersion": "3.7.2",
|
||||
"flutterRoot": "file:///home/menno/.flutter/flutter",
|
||||
"flutterVersion": "3.29.2",
|
||||
"pubCache": "file:///home/menno/.pub-cache"
|
||||
}
|
||||
|
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
.dart_tool
|
4
Makefile
Normal file
4
Makefile
Normal file
@ -0,0 +1,4 @@
|
||||
default:
|
||||
|
||||
build-runner:
|
||||
dart run build_runner build --delete-conflicting-outputs
|
@ -1,205 +0,0 @@
|
||||
import 'dart:async';
|
||||
import 'dart:convert';
|
||||
|
||||
import 'package:http/http.dart' as http;
|
||||
import 'package:uuid/uuid.dart';
|
||||
import 'package:web_socket_channel/web_socket_channel.dart';
|
||||
|
||||
class ComfyUiApi {
|
||||
final String host;
|
||||
final String clientId;
|
||||
final http.Client _httpClient;
|
||||
WebSocketChannel? _wsChannel;
|
||||
final StreamController<Map<String, dynamic>> _progressController =
|
||||
StreamController.broadcast();
|
||||
|
||||
/// Stream of progress updates from ComfyUI
|
||||
Stream<Map<String, dynamic>> get progressUpdates =>
|
||||
_progressController.stream;
|
||||
|
||||
/// Creates a new ComfyUI API client
|
||||
///
|
||||
/// [host] The host of the ComfyUI server (e.g. 'http://localhost:8188')
|
||||
/// [clientId] Optional client ID, will be automatically generated if not provided
|
||||
ComfyUiApi({
|
||||
required this.host,
|
||||
String? clientId,
|
||||
http.Client? httpClient,
|
||||
}) : clientId = clientId ?? const Uuid().v4(),
|
||||
_httpClient = httpClient ?? http.Client();
|
||||
|
||||
/// Connects to the WebSocket for progress updates
|
||||
Future<void> connectWebSocket() async {
|
||||
final wsUrl =
|
||||
'ws://${host.replaceFirst(RegExp(r'^https?://'), '')}/ws?clientId=$clientId';
|
||||
_wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl));
|
||||
|
||||
_wsChannel!.stream.listen((message) {
|
||||
final data = jsonDecode(message);
|
||||
_progressController.add(data);
|
||||
}, onError: (error) {
|
||||
print('WebSocket error: $error');
|
||||
}, onDone: () {
|
||||
print('WebSocket connection closed');
|
||||
});
|
||||
}
|
||||
|
||||
/// Closes the WebSocket connection and cleans up resources
|
||||
void dispose() {
|
||||
_wsChannel?.sink.close();
|
||||
_progressController.close();
|
||||
_httpClient.close();
|
||||
}
|
||||
|
||||
/// Gets the current queue status
|
||||
Future<Map<String, dynamic>> getQueue() async {
|
||||
final response = await _httpClient.get(Uri.parse('$host/queue'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets the history of the queue
|
||||
Future<Map<String, dynamic>> getHistory({int maxItems = 64}) async {
|
||||
final response = await _httpClient
|
||||
.get(Uri.parse('$host/api/history?max_items=$maxItems'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets image data by filename
|
||||
Future<List<int>> getImage(String filename) async {
|
||||
final response =
|
||||
await _httpClient.get(Uri.parse('$host/api/view?filename=$filename'));
|
||||
_validateResponse(response);
|
||||
return response.bodyBytes;
|
||||
}
|
||||
|
||||
/// Gets a list of all available models
|
||||
Future<Map<String, dynamic>> getModels() async {
|
||||
final response =
|
||||
await _httpClient.get(Uri.parse('$host/api/experiment/models'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets a list of checkpoints
|
||||
Future<Map<String, dynamic>> getCheckpoints() async {
|
||||
final response = await _httpClient
|
||||
.get(Uri.parse('$host/api/experiment/models/checkpoints'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets details for a specific checkpoint
|
||||
Future<Map<String, dynamic>> getCheckpointDetails(
|
||||
String pathAndFileName) async {
|
||||
final response = await _httpClient.get(Uri.parse(
|
||||
'$host/api/view_metadata/checkpoints?filename=$pathAndFileName'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets a list of LoRAs
|
||||
Future<Map<String, dynamic>> getLoras() async {
|
||||
final response =
|
||||
await _httpClient.get(Uri.parse('$host/api/experiment/models/loras'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets details for a specific LoRA
|
||||
Future<Map<String, dynamic>> getLoraDetails(String pathAndFileName) async {
|
||||
final response = await _httpClient.get(
|
||||
Uri.parse('$host/api/view_metadata/loras?filename=$pathAndFileName'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets a list of VAEs
|
||||
Future<Map<String, dynamic>> getVaes() async {
|
||||
final response =
|
||||
await _httpClient.get(Uri.parse('$host/api/experiment/models/vae'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets details for a specific VAE
|
||||
Future<Map<String, dynamic>> getVaeDetails(String pathAndFileName) async {
|
||||
final response = await _httpClient.get(
|
||||
Uri.parse('$host/api/view_metadata/vae?filename=$pathAndFileName'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets a list of upscale models
|
||||
Future<Map<String, dynamic>> getUpscaleModels() async {
|
||||
final response = await _httpClient
|
||||
.get(Uri.parse('$host/api/experiment/models/upscale_models'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets details for a specific upscale model
|
||||
Future<Map<String, dynamic>> getUpscaleModelDetails(
|
||||
String pathAndFileName) async {
|
||||
final response = await _httpClient.get(Uri.parse(
|
||||
'$host/api/view_metadata/upscale_models?filename=$pathAndFileName'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets a list of embeddings
|
||||
Future<Map<String, dynamic>> getEmbeddings() async {
|
||||
final response = await _httpClient
|
||||
.get(Uri.parse('$host/api/experiment/models/embeddings'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets details for a specific embedding
|
||||
Future<Map<String, dynamic>> getEmbeddingDetails(
|
||||
String pathAndFileName) async {
|
||||
final response = await _httpClient.get(Uri.parse(
|
||||
'$host/api/view_metadata/embeddings?filename=$pathAndFileName'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Gets information about all available objects (nodes)
|
||||
Future<Map<String, dynamic>> getObjectInfo() async {
|
||||
final response = await _httpClient.get(Uri.parse('$host/api/object_info'));
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Submits a prompt (workflow) to generate an image
|
||||
Future<Map<String, dynamic>> submitPrompt(Map<String, dynamic> prompt) async {
|
||||
final response = await _httpClient.post(
|
||||
Uri.parse('$host/api/prompt'),
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: jsonEncode(prompt),
|
||||
);
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
}
|
||||
|
||||
/// Validates HTTP response and throws an exception if needed
|
||||
void _validateResponse(http.Response response) {
|
||||
if (response.statusCode < 200 || response.statusCode >= 300) {
|
||||
throw ComfyUiApiException(
|
||||
statusCode: response.statusCode,
|
||||
message: 'API request failed: ${response.body}');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Exception thrown when the ComfyUI API returns an error
|
||||
class ComfyUiApiException implements Exception {
|
||||
final int statusCode;
|
||||
final String message;
|
||||
|
||||
ComfyUiApiException({required this.statusCode, required this.message});
|
||||
|
||||
@override
|
||||
String toString() => 'ComfyUiApiException: $statusCode - $message';
|
||||
}
|
@ -1,94 +0,0 @@
|
||||
import 'dart:io';
|
||||
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
|
||||
|
||||
void main() async {
|
||||
// Create the API client
|
||||
final api = ComfyUiApi(host: 'http://mennos-server:7860');
|
||||
|
||||
// Connect to the WebSocket for progress updates
|
||||
await api.connectWebSocket();
|
||||
|
||||
// Listen for progress updates
|
||||
api.progressUpdates.listen((update) {
|
||||
print('Progress update: $update');
|
||||
});
|
||||
|
||||
// Get available checkpoints
|
||||
final checkpoints = await api.getCheckpoints();
|
||||
print('Available checkpoints: ${checkpoints.keys.join(', ')}');
|
||||
|
||||
// Get queue status
|
||||
final queue = await api.getQueue();
|
||||
print('Queue status: $queue');
|
||||
|
||||
// Submit a basic text-to-image prompt
|
||||
final promptWorkflow = {
|
||||
"prompt": {
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": 123456789,
|
||||
"steps": 20,
|
||||
"cfg": 7,
|
||||
"sampler_name": "euler_ancestral",
|
||||
"scheduler": "normal",
|
||||
"denoise": 1,
|
||||
"model": ["4", 0],
|
||||
"positive": ["6", 0],
|
||||
"negative": ["7", 0],
|
||||
"latent_image": ["5", 0]
|
||||
},
|
||||
"class_type": "KSampler"
|
||||
},
|
||||
"4": {
|
||||
"inputs": {"ckpt_name": "dreamshaper_8.safetensors"},
|
||||
"class_type": "CheckpointLoaderSimple"
|
||||
},
|
||||
"5": {
|
||||
"inputs": {"width": 512, "height": 512, "batch_size": 1},
|
||||
"class_type": "EmptyLatentImage"
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "a beautiful landscape with mountains and a lake",
|
||||
"clip": ["4", 1]
|
||||
},
|
||||
"class_type": "CLIPTextEncode"
|
||||
},
|
||||
"7": {
|
||||
"inputs": {
|
||||
"text": "ugly, blurry, low quality",
|
||||
"clip": ["4", 1]
|
||||
},
|
||||
"class_type": "CLIPTextEncode"
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": ["3", 0],
|
||||
"vae": ["4", 2]
|
||||
},
|
||||
"class_type": "VAEDecode"
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": ["8", 0]
|
||||
},
|
||||
"class_type": "SaveImage"
|
||||
}
|
||||
},
|
||||
"client_id": api.clientId
|
||||
};
|
||||
|
||||
try {
|
||||
final result = await api.submitPrompt(promptWorkflow);
|
||||
print('Prompt submitted: $result');
|
||||
} catch (e) {
|
||||
print('Error submitting prompt: $e');
|
||||
}
|
||||
|
||||
// Wait for some time to receive WebSocket messages
|
||||
await Future.delayed(Duration(seconds: 60));
|
||||
|
||||
// Clean up
|
||||
api.dispose();
|
||||
}
|
@ -1,4 +1,6 @@
|
||||
library comfyui_api_sdk;
|
||||
|
||||
export 'src/comfyui_api.dart';
|
||||
export 'src/models/models.dart';
|
||||
export 'src/models/websocket_event.dart';
|
||||
export 'src/models/progress_event.dart';
|
||||
export 'src/models/execution_event.dart';
|
||||
|
@ -5,22 +5,69 @@ import 'package:http/http.dart' as http;
|
||||
import 'package:uuid/uuid.dart';
|
||||
import 'package:web_socket_channel/web_socket_channel.dart';
|
||||
|
||||
import 'models/websocket_event.dart';
|
||||
import 'models/progress_event.dart';
|
||||
import 'models/execution_event.dart';
|
||||
|
||||
/// Callback function type for prompt events
|
||||
typedef PromptEventCallback = void Function(String promptId);
|
||||
|
||||
/// Callback function type for typed WebSocket events
|
||||
typedef WebSocketEventCallback = void Function(WebSocketEvent event);
|
||||
|
||||
/// Callback function type for progress events
|
||||
typedef ProgressEventCallback = void Function(ProgressEvent event);
|
||||
|
||||
/// A Dart SDK for interacting with the ComfyUI API
|
||||
class ComfyUiApi {
|
||||
final String host;
|
||||
final String clientId;
|
||||
final http.Client _httpClient;
|
||||
WebSocketChannel? _wsChannel;
|
||||
|
||||
// Controllers for different event streams
|
||||
final StreamController<WebSocketEvent> _eventController =
|
||||
StreamController.broadcast();
|
||||
final StreamController<Map<String, dynamic>> _progressController =
|
||||
StreamController.broadcast();
|
||||
|
||||
/// Stream of progress updates from ComfyUI
|
||||
// Add new controllers for specific event types
|
||||
final StreamController<ProgressEvent> _progressEventController =
|
||||
StreamController.broadcast();
|
||||
final StreamController<ExecutionEvent> _executionEventController =
|
||||
StreamController.broadcast();
|
||||
|
||||
/// Stream of typed progress events
|
||||
Stream<ProgressEvent> get progressEvents => _progressEventController.stream;
|
||||
|
||||
/// Stream of typed execution events
|
||||
Stream<ExecutionEvent> get executionEvents =>
|
||||
_executionEventController.stream;
|
||||
|
||||
// Event callbacks
|
||||
final Map<String, List<PromptEventCallback>> _eventCallbacks = {
|
||||
'onPromptStart': [],
|
||||
'onPromptFinished': [],
|
||||
};
|
||||
|
||||
final Map<WebSocketEventType, List<WebSocketEventCallback>>
|
||||
_typedEventCallbacks = {
|
||||
for (var type in WebSocketEventType.values) type: [],
|
||||
};
|
||||
|
||||
// Add a separate map for progress event callbacks
|
||||
final List<ProgressEventCallback> _progressEventCallbacks = [];
|
||||
|
||||
/// Stream of typed WebSocket events
|
||||
Stream<WebSocketEvent> get events => _eventController.stream;
|
||||
|
||||
/// Stream of progress updates from ComfyUI (legacy format)
|
||||
Stream<Map<String, dynamic>> get progressUpdates =>
|
||||
_progressController.stream;
|
||||
|
||||
/// Creates a new ComfyUI API client
|
||||
///
|
||||
/// [host] The host of the ComfyUI server (e.g. 'http://localhost:8188')
|
||||
/// [host] The host of the ComfyUI server (e.g. 'http://localhost:7860')
|
||||
/// [clientId] Optional client ID, will be automatically generated if not provided
|
||||
ComfyUiApi({
|
||||
required this.host,
|
||||
@ -29,6 +76,26 @@ class ComfyUiApi {
|
||||
}) : clientId = clientId ?? const Uuid().v4(),
|
||||
_httpClient = httpClient ?? http.Client();
|
||||
|
||||
/// Register a callback for when a prompt starts processing
|
||||
void onPromptStart(PromptEventCallback callback) {
|
||||
_eventCallbacks['onPromptStart']!.add(callback);
|
||||
}
|
||||
|
||||
/// Register a callback for when a prompt finishes processing
|
||||
void onPromptFinished(PromptEventCallback callback) {
|
||||
_eventCallbacks['onPromptFinished']!.add(callback);
|
||||
}
|
||||
|
||||
/// Register a callback for specific WebSocket event types
|
||||
void onEventType(WebSocketEventType type, WebSocketEventCallback callback) {
|
||||
_typedEventCallbacks[type]!.add(callback);
|
||||
}
|
||||
|
||||
/// Register a callback for progress updates
|
||||
void onProgressChanged(ProgressEventCallback callback) {
|
||||
_progressEventCallbacks.add(callback);
|
||||
}
|
||||
|
||||
/// Connects to the WebSocket for progress updates
|
||||
Future<void> connectWebSocket() async {
|
||||
final wsUrl =
|
||||
@ -36,8 +103,43 @@ class ComfyUiApi {
|
||||
_wsChannel = WebSocketChannel.connect(Uri.parse(wsUrl));
|
||||
|
||||
_wsChannel!.stream.listen((message) {
|
||||
final data = jsonDecode(message);
|
||||
_progressController.add(data);
|
||||
final jsonData = jsonDecode(message);
|
||||
|
||||
// Create a typed event
|
||||
final event = WebSocketEvent.fromJson(jsonData);
|
||||
|
||||
// Add to the typed event stream
|
||||
_eventController.add(event);
|
||||
|
||||
// Also add to the legacy progress stream
|
||||
_progressController.add(jsonData);
|
||||
|
||||
// Convert to more specific event types if possible
|
||||
if (event.eventType == WebSocketEventType.progress) {
|
||||
_tryCreateProgressEvent(event);
|
||||
} else if ([
|
||||
WebSocketEventType.executionStart,
|
||||
WebSocketEventType.executionSuccess,
|
||||
WebSocketEventType.executionCached,
|
||||
WebSocketEventType.executed,
|
||||
WebSocketEventType.executing,
|
||||
].contains(event.eventType)) {
|
||||
_tryCreateExecutionEvent(event);
|
||||
}
|
||||
|
||||
// Trigger event type specific callbacks
|
||||
for (final callback in _typedEventCallbacks[event.eventType]!) {
|
||||
callback(event);
|
||||
}
|
||||
|
||||
// Handle execution_success event (prompt finished)
|
||||
if (event.eventType == WebSocketEventType.executionSuccess &&
|
||||
event.promptId != null) {
|
||||
final promptId = event.promptId!;
|
||||
for (final callback in _eventCallbacks['onPromptFinished']!) {
|
||||
callback(promptId);
|
||||
}
|
||||
}
|
||||
}, onError: (error) {
|
||||
print('WebSocket error: $error');
|
||||
}, onDone: () {
|
||||
@ -45,10 +147,55 @@ class ComfyUiApi {
|
||||
});
|
||||
}
|
||||
|
||||
/// Attempts to create a ProgressEvent from a WebSocketEvent
|
||||
void _tryCreateProgressEvent(WebSocketEvent event) {
|
||||
if (event.data.containsKey('value') &&
|
||||
event.data.containsKey('max') &&
|
||||
event.data.containsKey('prompt_id')) {
|
||||
try {
|
||||
final progressEvent = ProgressEvent(
|
||||
value: event.data['value'] as int,
|
||||
max: event.data['max'] as int,
|
||||
promptId: event.data['prompt_id'] as String,
|
||||
node: event.data['node']?.toString(),
|
||||
);
|
||||
_progressEventController.add(progressEvent);
|
||||
|
||||
// Trigger all registered progress callbacks
|
||||
for (final callback in _progressEventCallbacks) {
|
||||
callback(progressEvent);
|
||||
}
|
||||
} catch (e) {
|
||||
print('Error creating ProgressEvent: $e');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempts to create an ExecutionEvent from a WebSocketEvent
|
||||
void _tryCreateExecutionEvent(WebSocketEvent event) {
|
||||
if (event.data.containsKey('prompt_id')) {
|
||||
try {
|
||||
final executionEvent = ExecutionEvent(
|
||||
promptId: event.data['prompt_id'] as String,
|
||||
timestamp: event.data['timestamp'] as int? ??
|
||||
DateTime.now().millisecondsSinceEpoch,
|
||||
node: event.data['node']?.toString(),
|
||||
extra: event.data['extra'] as Map<String, dynamic>?,
|
||||
);
|
||||
_executionEventController.add(executionEvent);
|
||||
} catch (e) {
|
||||
print('Error creating ExecutionEvent: $e');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Closes the WebSocket connection and cleans up resources
|
||||
void dispose() {
|
||||
_wsChannel?.sink.close();
|
||||
_progressController.close();
|
||||
_eventController.close();
|
||||
_progressEventController.close();
|
||||
_executionEventController.close();
|
||||
_httpClient.close();
|
||||
}
|
||||
|
||||
@ -181,7 +328,17 @@ class ComfyUiApi {
|
||||
body: jsonEncode(prompt),
|
||||
);
|
||||
_validateResponse(response);
|
||||
return jsonDecode(response.body);
|
||||
final responseData = jsonDecode(response.body);
|
||||
|
||||
// Trigger onPromptStart event if prompt_id exists
|
||||
if (responseData.containsKey('prompt_id')) {
|
||||
final promptId = responseData['prompt_id'];
|
||||
for (final callback in _eventCallbacks['onPromptStart']!) {
|
||||
callback(promptId);
|
||||
}
|
||||
}
|
||||
|
||||
return responseData;
|
||||
}
|
||||
|
||||
/// Validates HTTP response and throws an exception if needed
|
||||
|
7
lib/src/comfyui_api_sdk.dart
Normal file
7
lib/src/comfyui_api_sdk.dart
Normal file
@ -0,0 +1,7 @@
|
||||
// Main API
|
||||
export 'comfyui_api.dart';
|
||||
|
||||
// Models
|
||||
export 'models/websocket_event.dart';
|
||||
export 'models/progress_event.dart';
|
||||
export 'models/execution_event.dart';
|
11
lib/src/models/callbacks.dart
Normal file
11
lib/src/models/callbacks.dart
Normal file
@ -0,0 +1,11 @@
|
||||
import 'websocket_event.dart';
|
||||
import 'progress_event.dart';
|
||||
|
||||
/// Callback function type for prompt events
|
||||
typedef PromptEventCallback = void Function(String promptId);
|
||||
|
||||
/// Callback function type for typed WebSocket events
|
||||
typedef WebSocketEventCallback = void Function(WebSocketEvent event);
|
||||
|
||||
/// Callback function type for progress events
|
||||
typedef ProgressEventCallback = void Function(ProgressEvent event);
|
31
lib/src/models/execution_event.dart
Normal file
31
lib/src/models/execution_event.dart
Normal file
@ -0,0 +1,31 @@
|
||||
class ExecutionEvent {
|
||||
final String promptId;
|
||||
final int timestamp;
|
||||
final String? node;
|
||||
final Map<String, dynamic>? extra;
|
||||
|
||||
const ExecutionEvent({
|
||||
required this.promptId,
|
||||
required this.timestamp,
|
||||
this.node,
|
||||
this.extra,
|
||||
});
|
||||
|
||||
factory ExecutionEvent.fromJson(Map<String, dynamic> json) {
|
||||
return ExecutionEvent(
|
||||
promptId: json['prompt_id'] as String,
|
||||
timestamp: json['timestamp'] as int,
|
||||
node: json['node'] as String?,
|
||||
extra: json['extra'] as Map<String, dynamic>?,
|
||||
);
|
||||
}
|
||||
|
||||
Map<String, dynamic> toJson() {
|
||||
return {
|
||||
'prompt_id': promptId,
|
||||
'timestamp': timestamp,
|
||||
'node': node,
|
||||
'extra': extra,
|
||||
};
|
||||
}
|
||||
}
|
39
lib/src/models/progress_event.dart
Normal file
39
lib/src/models/progress_event.dart
Normal file
@ -0,0 +1,39 @@
|
||||
class ProgressEvent {
|
||||
final int value;
|
||||
final int max;
|
||||
final String promptId;
|
||||
final String? node;
|
||||
|
||||
const ProgressEvent({
|
||||
required this.value,
|
||||
required this.max,
|
||||
required this.promptId,
|
||||
this.node,
|
||||
});
|
||||
|
||||
factory ProgressEvent.fromJson(Map<String, dynamic> json) {
|
||||
return ProgressEvent(
|
||||
value: json['value'] as int,
|
||||
max: json['max'] as int,
|
||||
promptId: json['prompt_id'] as String,
|
||||
node: json['node'] as String?,
|
||||
);
|
||||
}
|
||||
|
||||
Map<String, dynamic> toJson() {
|
||||
return {
|
||||
'value': value,
|
||||
'max': max,
|
||||
'prompt_id': promptId,
|
||||
'node': node,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
extension ProgressEventComputation on ProgressEvent {
|
||||
/// Returns the progress percentage (0-100)
|
||||
double get percentage => (value / max) * 100;
|
||||
|
||||
/// Returns true if the progress is complete
|
||||
bool get isComplete => value >= max;
|
||||
}
|
109
lib/src/models/websocket_event.dart
Normal file
109
lib/src/models/websocket_event.dart
Normal file
@ -0,0 +1,109 @@
|
||||
/// Types of WebSocket events from ComfyUI
|
||||
enum WebSocketEventType {
|
||||
status,
|
||||
progress,
|
||||
executing,
|
||||
executed,
|
||||
executionStart,
|
||||
executionCached,
|
||||
executionSuccess,
|
||||
executionError,
|
||||
dataOutput,
|
||||
unknown
|
||||
}
|
||||
|
||||
/// A typed event from the ComfyUI WebSocket
|
||||
class WebSocketEvent {
|
||||
final WebSocketEventType eventType;
|
||||
final Map<String, dynamic> data;
|
||||
final String? promptId;
|
||||
final String rawType;
|
||||
|
||||
WebSocketEvent({
|
||||
required this.eventType,
|
||||
required this.data,
|
||||
this.promptId,
|
||||
required this.rawType,
|
||||
});
|
||||
|
||||
@override
|
||||
String toString() =>
|
||||
'WebSocketEvent{type: $eventType, rawType: $rawType, promptId: $promptId, data: ${data.keys.join(', ')}}';
|
||||
|
||||
/// Creates a WebSocketEvent from JSON
|
||||
factory WebSocketEvent.fromJson(Map<String, dynamic> json) {
|
||||
WebSocketEventType type;
|
||||
Map<String, dynamic> eventData = {};
|
||||
String? promptId;
|
||||
String rawType = json['type']?.toString() ?? 'unknown';
|
||||
|
||||
// Extract event type
|
||||
if (json.containsKey('type')) {
|
||||
final typeStr = json['type'].toString();
|
||||
|
||||
// First, try to extract prompt_id from the data field
|
||||
if (json.containsKey('data') && json['data'] is Map) {
|
||||
final dataMap = Map<String, dynamic>.from(json['data'] as Map);
|
||||
if (dataMap.containsKey('prompt_id')) {
|
||||
promptId = dataMap['prompt_id'].toString();
|
||||
}
|
||||
eventData = dataMap;
|
||||
} else {
|
||||
// If no data field, use the root object
|
||||
eventData = Map<String, dynamic>.from(json);
|
||||
}
|
||||
|
||||
// Try to extract prompt_id from other potential locations
|
||||
if (promptId == null && json.containsKey('prompt_id')) {
|
||||
promptId = json['prompt_id'].toString();
|
||||
}
|
||||
|
||||
switch (typeStr) {
|
||||
case 'status':
|
||||
type = WebSocketEventType.status;
|
||||
break;
|
||||
case 'progress':
|
||||
type = WebSocketEventType.progress;
|
||||
break;
|
||||
case 'executing':
|
||||
type = WebSocketEventType.executing;
|
||||
break;
|
||||
case 'executed':
|
||||
type = WebSocketEventType.executed;
|
||||
break;
|
||||
case 'execution_start':
|
||||
type = WebSocketEventType.executionStart;
|
||||
break;
|
||||
case 'execution_cached':
|
||||
type = WebSocketEventType.executionCached;
|
||||
break;
|
||||
case 'execution_success':
|
||||
type = WebSocketEventType.executionSuccess;
|
||||
break;
|
||||
case 'execution_error':
|
||||
type = WebSocketEventType.executionError;
|
||||
break;
|
||||
default:
|
||||
if (typeStr.startsWith('data_output')) {
|
||||
type = WebSocketEventType.dataOutput;
|
||||
} else {
|
||||
// Default to unknown for unrecognized types
|
||||
type = WebSocketEventType.unknown;
|
||||
print('Unknown event type: $typeStr');
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If no type field, default to unknown
|
||||
type = WebSocketEventType.unknown;
|
||||
print('WebSocket event missing type field');
|
||||
eventData = Map<String, dynamic>.from(json);
|
||||
}
|
||||
|
||||
return WebSocketEvent(
|
||||
eventType: type,
|
||||
data: eventData,
|
||||
promptId: promptId,
|
||||
rawType: rawType,
|
||||
);
|
||||
}
|
||||
}
|
82
pubspec.lock
82
pubspec.lock
@ -74,13 +74,13 @@ packages:
|
||||
source: hosted
|
||||
version: "2.4.4"
|
||||
build_runner:
|
||||
dependency: "direct main"
|
||||
dependency: "direct dev"
|
||||
description:
|
||||
name: build_runner
|
||||
sha256: "74691599a5bc750dc96a6b4bfd48f7d9d66453eab04c7f4063134800d6a5c573"
|
||||
sha256: "058fe9dce1de7d69c4b84fada934df3e0153dd000758c4d65964d0166779aa99"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "2.4.14"
|
||||
version: "2.4.15"
|
||||
build_runner_core:
|
||||
dependency: transitive
|
||||
description:
|
||||
@ -177,6 +177,22 @@ packages:
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "1.1.1"
|
||||
freezed:
|
||||
dependency: "direct dev"
|
||||
description:
|
||||
name: freezed
|
||||
sha256: "7ed2ddaa47524976d5f2aa91432a79da36a76969edd84170777ac5bea82d797c"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "3.0.4"
|
||||
freezed_annotation:
|
||||
dependency: "direct main"
|
||||
description:
|
||||
name: freezed_annotation
|
||||
sha256: c87ff004c8aa6af2d531668b46a4ea379f7191dc6dfa066acd53d506da6e044b
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "3.0.0"
|
||||
frontend_server_client:
|
||||
dependency: transitive
|
||||
description:
|
||||
@ -205,10 +221,10 @@ packages:
|
||||
dependency: "direct main"
|
||||
description:
|
||||
name: http
|
||||
sha256: "5895291c13fa8a3bd82e76d5627f69e0d85ca6a30dcac95c4ea19a5d555879c2"
|
||||
sha256: fe7ab022b76f3034adc518fb6ea04a82387620e19977665ea18d30a1cf43442f
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.13.6"
|
||||
version: "1.3.0"
|
||||
http_multi_server:
|
||||
dependency: transitive
|
||||
description:
|
||||
@ -242,21 +258,29 @@ packages:
|
||||
source: hosted
|
||||
version: "0.7.2"
|
||||
json_annotation:
|
||||
dependency: transitive
|
||||
dependency: "direct main"
|
||||
description:
|
||||
name: json_annotation
|
||||
sha256: "1ce844379ca14835a50d2f019a3099f419082cfdd231cd86a142af94dd5c6bb1"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "4.9.0"
|
||||
json_serializable:
|
||||
dependency: "direct dev"
|
||||
description:
|
||||
name: json_serializable
|
||||
sha256: "81f04dee10969f89f604e1249382d46b97a1ccad53872875369622b5bfc9e58a"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "6.9.4"
|
||||
lints:
|
||||
dependency: "direct dev"
|
||||
description:
|
||||
name: lints
|
||||
sha256: "0a217c6c989d21039f1498c3ed9f3ed71b354e69873f13a8dfc3c9fe76f1b452"
|
||||
sha256: c35bb79562d980e9a453fc715854e1ed39e24e7d0297a880ef54e17f9874a9d7
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "2.1.1"
|
||||
version: "5.1.1"
|
||||
logging:
|
||||
dependency: transitive
|
||||
description:
|
||||
@ -290,7 +314,7 @@ packages:
|
||||
source: hosted
|
||||
version: "2.0.0"
|
||||
mockito:
|
||||
dependency: "direct main"
|
||||
dependency: "direct dev"
|
||||
description:
|
||||
name: mockito
|
||||
sha256: f99d8d072e249f719a5531735d146d8cf04c580d93920b04de75bef6dfb2daf6
|
||||
@ -373,10 +397,10 @@ packages:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: shelf_web_socket
|
||||
sha256: cc36c297b52866d203dbf9332263c94becc2fe0ceaa9681d07b6ef9807023b67
|
||||
sha256: "3632775c8e90d6c9712f883e633716432a27758216dfb61bd86a8321c0580925"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "2.0.1"
|
||||
version: "3.0.0"
|
||||
source_gen:
|
||||
dependency: transitive
|
||||
description:
|
||||
@ -385,6 +409,14 @@ packages:
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "2.0.0"
|
||||
source_helper:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: source_helper
|
||||
sha256: "86d247119aedce8e63f4751bd9626fc9613255935558447569ad42f9f5b48b3c"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "1.3.5"
|
||||
source_map_stack_trace:
|
||||
dependency: transitive
|
||||
description:
|
||||
@ -409,6 +441,14 @@ packages:
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "1.10.1"
|
||||
sprintf:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: sprintf
|
||||
sha256: "1fc9ffe69d4df602376b52949af107d8f5703b77cda567c4d7d86a0693120f23"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "7.0.0"
|
||||
stack_trace:
|
||||
dependency: transitive
|
||||
description:
|
||||
@ -493,10 +533,10 @@ packages:
|
||||
dependency: "direct main"
|
||||
description:
|
||||
name: uuid
|
||||
sha256: "648e103079f7c64a36dc7d39369cabb358d377078a051d6ae2ad3aa539519313"
|
||||
sha256: a5be9ef6618a7ac1e964353ef476418026db906c4facdedaa299b7a2e71690ff
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "3.0.7"
|
||||
version: "4.5.1"
|
||||
vm_service:
|
||||
dependency: transitive
|
||||
description:
|
||||
@ -517,18 +557,26 @@ packages:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: web
|
||||
sha256: "97da13628db363c635202ad97068d47c5b8aa555808e7a9411963c533b449b27"
|
||||
sha256: "868d88a33d8a87b18ffc05f9f030ba328ffefba92d6c127917a2ba740f9cfe4a"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.5.1"
|
||||
version: "1.1.1"
|
||||
web_socket:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: web_socket
|
||||
sha256: "3c12d96c0c9a4eec095246debcea7b86c0324f22df69893d538fcc6f1b8cce83"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.1.6"
|
||||
web_socket_channel:
|
||||
dependency: "direct main"
|
||||
description:
|
||||
name: web_socket_channel
|
||||
sha256: "58c6666b342a38816b2e7e50ed0f1e261959630becd4c879c4f26bfa14aa5a42"
|
||||
sha256: "0b8e2457400d8a859b7b2030786835a28a8e80836ef64402abef392ff4f1d0e5"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "2.4.5"
|
||||
version: "3.0.2"
|
||||
webkit_inspection_protocol:
|
||||
dependency: transitive
|
||||
description:
|
||||
|
16
pubspec.yaml
16
pubspec.yaml
@ -6,12 +6,16 @@ environment:
|
||||
sdk: '>=3.0.0 <4.0.0'
|
||||
|
||||
dependencies:
|
||||
http: ^0.13.5
|
||||
web_socket_channel: ^2.3.0
|
||||
uuid: ^3.0.7
|
||||
mockito: ^5.4.5
|
||||
build_runner: ^2.4.14
|
||||
http: ^1.3.0
|
||||
web_socket_channel: ^3.0.2
|
||||
uuid: ^4.5.1
|
||||
json_annotation: ^4.9.0
|
||||
freezed_annotation: ^3.0.0
|
||||
|
||||
dev_dependencies:
|
||||
lints: ^2.0.0
|
||||
lints: ^5.1.1
|
||||
test: ^1.21.0
|
||||
mockito: ^5.4.5
|
||||
build_runner: ^2.4.14
|
||||
freezed: ^3.0.4
|
||||
json_serializable: ^6.7.1
|
@ -1,197 +0,0 @@
|
||||
import 'dart:convert';
|
||||
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
|
||||
import 'package:http/http.dart' as http;
|
||||
import 'package:http/testing.dart';
|
||||
import 'package:mockito/annotations.dart';
|
||||
import 'package:mockito/mockito.dart';
|
||||
import 'package:test/test.dart';
|
||||
import 'package:web_socket_channel/web_socket_channel.dart';
|
||||
|
||||
import 'comfyui_api_test.mocks.dart';
|
||||
import 'test_data.dart';
|
||||
|
||||
@GenerateMocks([http.Client, WebSocketChannel, WebSocketSink])
|
||||
void main() {
|
||||
late MockClient mockClient;
|
||||
late ComfyUiApi api;
|
||||
const String testHost = 'http://localhost:8188';
|
||||
const String testClientId = 'test-client-id';
|
||||
|
||||
setUp(() {
|
||||
mockClient = MockClient();
|
||||
api = ComfyUiApi(
|
||||
host: testHost,
|
||||
clientId: testClientId,
|
||||
httpClient: mockClient,
|
||||
);
|
||||
});
|
||||
|
||||
group('ComfyUiApi', () {
|
||||
test('initialize with provided values', () {
|
||||
expect(api.host, equals(testHost));
|
||||
expect(api.clientId, equals(testClientId));
|
||||
});
|
||||
|
||||
test('initialize with generated clientId when not provided', () {
|
||||
final autoApi = ComfyUiApi(host: testHost, httpClient: mockClient);
|
||||
expect(autoApi.clientId, isNotEmpty);
|
||||
expect(autoApi.clientId, isNot(equals(testClientId)));
|
||||
});
|
||||
|
||||
test('getQueue returns parsed response', () async {
|
||||
when(mockClient.get(Uri.parse('$testHost/queue'))).thenAnswer(
|
||||
(_) async => http.Response(jsonEncode(TestData.queueResponse), 200));
|
||||
|
||||
final result = await api.getQueue();
|
||||
|
||||
expect(result, equals(TestData.queueResponse));
|
||||
verify(mockClient.get(Uri.parse('$testHost/queue'))).called(1);
|
||||
});
|
||||
|
||||
test('getHistory returns parsed response', () async {
|
||||
when(mockClient.get(Uri.parse('$testHost/api/history?max_items=64')))
|
||||
.thenAnswer((_) async =>
|
||||
http.Response(jsonEncode(TestData.historyResponse), 200));
|
||||
|
||||
final result = await api.getHistory();
|
||||
|
||||
expect(result, equals(TestData.historyResponse));
|
||||
verify(mockClient.get(Uri.parse('$testHost/api/history?max_items=64')))
|
||||
.called(1);
|
||||
});
|
||||
|
||||
test('getImage returns image bytes', () async {
|
||||
final bytes = [1, 2, 3, 4];
|
||||
when(mockClient.get(Uri.parse('$testHost/api/view?filename=test.png')))
|
||||
.thenAnswer((_) async => http.Response.bytes(bytes, 200));
|
||||
|
||||
final result = await api.getImage('test.png');
|
||||
|
||||
expect(result, equals(bytes));
|
||||
verify(mockClient.get(Uri.parse('$testHost/api/view?filename=test.png')))
|
||||
.called(1);
|
||||
});
|
||||
|
||||
test('getCheckpoints returns parsed response', () async {
|
||||
when(mockClient
|
||||
.get(Uri.parse('$testHost/api/experiment/models/checkpoints')))
|
||||
.thenAnswer((_) async =>
|
||||
http.Response(jsonEncode(TestData.checkpointsResponse), 200));
|
||||
|
||||
final result = await api.getCheckpoints();
|
||||
|
||||
expect(result, equals(TestData.checkpointsResponse));
|
||||
verify(mockClient
|
||||
.get(Uri.parse('$testHost/api/experiment/models/checkpoints')))
|
||||
.called(1);
|
||||
});
|
||||
|
||||
test('getCheckpointDetails returns parsed response', () async {
|
||||
const filename = 'models/checkpoints/test.safetensors';
|
||||
when(mockClient.get(Uri.parse(
|
||||
'$testHost/api/view_metadata/checkpoints?filename=$filename')))
|
||||
.thenAnswer((_) async => http.Response(
|
||||
jsonEncode(TestData.checkpointMetadataResponse), 200));
|
||||
|
||||
final result = await api.getCheckpointDetails(filename);
|
||||
|
||||
expect(result, equals(TestData.checkpointMetadataResponse));
|
||||
verify(mockClient.get(Uri.parse(
|
||||
'$testHost/api/view_metadata/checkpoints?filename=$filename')))
|
||||
.called(1);
|
||||
});
|
||||
|
||||
test('getLoras returns parsed response', () async {
|
||||
when(mockClient.get(Uri.parse('$testHost/api/experiment/models/loras')))
|
||||
.thenAnswer((_) async =>
|
||||
http.Response(jsonEncode(TestData.lorasResponse), 200));
|
||||
|
||||
final result = await api.getLoras();
|
||||
|
||||
expect(result, equals(TestData.lorasResponse));
|
||||
verify(mockClient.get(Uri.parse('$testHost/api/experiment/models/loras')))
|
||||
.called(1);
|
||||
});
|
||||
|
||||
test('getVaes returns parsed response', () async {
|
||||
when(mockClient.get(Uri.parse('$testHost/api/experiment/models/vae')))
|
||||
.thenAnswer((_) async =>
|
||||
http.Response(jsonEncode(TestData.vaeResponse), 200));
|
||||
|
||||
final result = await api.getVaes();
|
||||
|
||||
expect(result, equals(TestData.vaeResponse));
|
||||
verify(mockClient.get(Uri.parse('$testHost/api/experiment/models/vae')))
|
||||
.called(1);
|
||||
});
|
||||
|
||||
test('getObjectInfo returns parsed response', () async {
|
||||
when(mockClient.get(Uri.parse('$testHost/api/object_info'))).thenAnswer(
|
||||
(_) async =>
|
||||
http.Response(jsonEncode(TestData.objectInfoResponse), 200));
|
||||
|
||||
final result = await api.getObjectInfo();
|
||||
|
||||
expect(result, equals(TestData.objectInfoResponse));
|
||||
verify(mockClient.get(Uri.parse('$testHost/api/object_info'))).called(1);
|
||||
});
|
||||
|
||||
test('submitPrompt returns parsed response', () async {
|
||||
when(mockClient.post(
|
||||
Uri.parse('$testHost/api/prompt'),
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: jsonEncode(TestData.promptRequest),
|
||||
)).thenAnswer(
|
||||
(_) async => http.Response(jsonEncode(TestData.promptResponse), 200));
|
||||
|
||||
final result = await api.submitPrompt(TestData.promptRequest);
|
||||
|
||||
expect(result, equals(TestData.promptResponse));
|
||||
verify(mockClient.post(
|
||||
Uri.parse('$testHost/api/prompt'),
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: jsonEncode(TestData.promptRequest),
|
||||
)).called(1);
|
||||
});
|
||||
|
||||
test('throws ComfyUiApiException on error response', () async {
|
||||
when(mockClient.get(Uri.parse('$testHost/queue')))
|
||||
.thenAnswer((_) async => http.Response('Error message', 500));
|
||||
|
||||
expect(() => api.getQueue(), throwsA(isA<ComfyUiApiException>()));
|
||||
});
|
||||
});
|
||||
|
||||
group('Models', () {
|
||||
test('QueueInfo parses from JSON correctly', () {
|
||||
final queueInfo = QueueInfo.fromJson(TestData.queueResponse);
|
||||
|
||||
expect(queueInfo.queueRunning, equals(0));
|
||||
expect(queueInfo.queue.length, equals(0));
|
||||
expect(queueInfo.queuePending, isA<Map<String, dynamic>>());
|
||||
});
|
||||
|
||||
test('PromptExecutionStatus parses from JSON correctly', () {
|
||||
final status = PromptExecutionStatus.fromJson(TestData.promptResponse);
|
||||
|
||||
expect(status.promptId, equals('123456789'));
|
||||
expect(status.number, equals(1));
|
||||
expect(status.status, equals('success'));
|
||||
});
|
||||
|
||||
test('HistoryItem parses from JSON correctly', () {
|
||||
final item = HistoryItem.fromJson(TestData.historyItemResponse);
|
||||
|
||||
expect(item.promptId, equals('123456789'));
|
||||
expect(item.prompt, isA<Map<String, dynamic>>());
|
||||
expect(item.outputs, isA<Map<String, dynamic>>());
|
||||
});
|
||||
|
||||
test('ProgressUpdate parses from JSON correctly', () {
|
||||
final update = ProgressUpdate.fromJson(TestData.progressUpdateResponse);
|
||||
|
||||
expect(update.type, equals('execution_start'));
|
||||
expect(update.data, isA<Map<String, dynamic>>());
|
||||
});
|
||||
});
|
||||
}
|
@ -1,212 +0,0 @@
|
||||
import 'dart:convert';
|
||||
import 'dart:io';
|
||||
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
|
||||
import 'package:http/http.dart' as http;
|
||||
import 'package:http/testing.dart';
|
||||
import 'package:test/test.dart';
|
||||
|
||||
void main() {
|
||||
late ComfyUiApi api;
|
||||
late MockClient mockClient;
|
||||
|
||||
const String testHost = 'http://localhost:8188';
|
||||
|
||||
setUp(() {
|
||||
// Setup a MockClient that simulates real API responses
|
||||
mockClient = MockClient((request) async {
|
||||
final uri = request.url;
|
||||
final method = request.method;
|
||||
|
||||
// Simulate queue endpoint
|
||||
if (uri.path == '/queue' && method == 'GET') {
|
||||
return http.Response(
|
||||
jsonEncode({'queue_running': 0, 'queue': [], 'queue_pending': {}}),
|
||||
200,
|
||||
);
|
||||
}
|
||||
|
||||
// Simulate history endpoint
|
||||
if (uri.path == '/api/history' && method == 'GET') {
|
||||
return http.Response(
|
||||
jsonEncode({
|
||||
'History': {
|
||||
'123456789': {
|
||||
'prompt': {
|
||||
// Simplified prompt data
|
||||
'1': {'class_type': 'TestNode'}
|
||||
},
|
||||
'outputs': {
|
||||
'8': {
|
||||
'images': {
|
||||
'filename': 'ComfyUI_00001_.png',
|
||||
'subfolder': '',
|
||||
'type': 'output',
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
200,
|
||||
);
|
||||
}
|
||||
|
||||
// Simulate checkpoint list endpoint
|
||||
if (uri.path == '/api/experiment/models/checkpoints' && method == 'GET') {
|
||||
return http.Response(
|
||||
jsonEncode({
|
||||
'models/checkpoints/dreamshaper_8.safetensors': {
|
||||
'filename': 'dreamshaper_8.safetensors',
|
||||
'folder': 'models/checkpoints',
|
||||
}
|
||||
}),
|
||||
200,
|
||||
);
|
||||
}
|
||||
|
||||
// Simulate checkpoint metadata endpoint
|
||||
if (uri.path == '/api/view_metadata/checkpoints' && method == 'GET') {
|
||||
return http.Response(
|
||||
jsonEncode({
|
||||
'model': {
|
||||
'type': 'checkpoint',
|
||||
'title': 'Dreamshaper 8',
|
||||
'hash': 'abcdef1234567890',
|
||||
}
|
||||
}),
|
||||
200,
|
||||
);
|
||||
}
|
||||
|
||||
// Simulate object info endpoint
|
||||
if (uri.path == '/api/object_info' && method == 'GET') {
|
||||
return http.Response(
|
||||
jsonEncode({
|
||||
'KSampler': {
|
||||
'input': {
|
||||
'required': {
|
||||
'model': 'MODEL',
|
||||
'seed': 'INT',
|
||||
'steps': 'INT',
|
||||
}
|
||||
},
|
||||
'output': ['LATENT'],
|
||||
'output_is_list': [false]
|
||||
}
|
||||
}),
|
||||
200,
|
||||
);
|
||||
}
|
||||
|
||||
// Simulate prompt submission endpoint
|
||||
if (uri.path == '/api/prompt' && method == 'POST') {
|
||||
return http.Response(
|
||||
jsonEncode(
|
||||
{'prompt_id': '123456789', 'number': 1, 'status': 'success'}),
|
||||
200,
|
||||
);
|
||||
}
|
||||
|
||||
// Simulate image view endpoint
|
||||
if (uri.path == '/api/view' && method == 'GET') {
|
||||
// Return a dummy image
|
||||
return http.Response.bytes([1, 2, 3, 4], 200,
|
||||
headers: {
|
||||
'Content-Type': 'image/png',
|
||||
});
|
||||
}
|
||||
|
||||
// Default response for unhandled routes
|
||||
return http.Response('Not Found', 404);
|
||||
});
|
||||
|
||||
// Create the API with our mock client
|
||||
api = ComfyUiApi(
|
||||
host: testHost,
|
||||
clientId: 'integration-test-client',
|
||||
httpClient: mockClient,
|
||||
);
|
||||
});
|
||||
|
||||
group('Integration Tests', () {
|
||||
test('Get queue information', () async {
|
||||
final queue = await api.getQueue();
|
||||
|
||||
expect(queue['queue_running'], equals(0));
|
||||
expect(queue['queue'], isEmpty);
|
||||
expect(queue['queue_pending'], isA<Map>());
|
||||
});
|
||||
|
||||
test('Get history information', () async {
|
||||
final history = await api.getHistory();
|
||||
|
||||
expect(history['History'], isA<Map>());
|
||||
expect(history['History']['123456789'], isA<Map>());
|
||||
expect(history['History']['123456789']['outputs'], isA<Map>());
|
||||
});
|
||||
|
||||
test('Get checkpoint list', () async {
|
||||
final checkpoints = await api.getCheckpoints();
|
||||
|
||||
expect(checkpoints.keys,
|
||||
contains('models/checkpoints/dreamshaper_8.safetensors'));
|
||||
expect(
|
||||
checkpoints['models/checkpoints/dreamshaper_8.safetensors']
|
||||
['filename'],
|
||||
equals('dreamshaper_8.safetensors'));
|
||||
});
|
||||
|
||||
test('Get checkpoint metadata', () async {
|
||||
final metadata = await api
|
||||
.getCheckpointDetails('models/checkpoints/dreamshaper_8.safetensors');
|
||||
|
||||
expect(metadata['model']['type'], equals('checkpoint'));
|
||||
expect(metadata['model']['title'], equals('Dreamshaper 8'));
|
||||
});
|
||||
|
||||
test('Get object info', () async {
|
||||
final info = await api.getObjectInfo();
|
||||
|
||||
expect(info['KSampler'], isA<Map>());
|
||||
expect(info['KSampler']['input']['required']['seed'], equals('INT'));
|
||||
});
|
||||
|
||||
test('Submit prompt', () async {
|
||||
final promptData = {
|
||||
'prompt': {
|
||||
'1': {
|
||||
'inputs': {'text': 'A beautiful landscape'},
|
||||
'class_type': 'CLIPTextEncode'
|
||||
}
|
||||
},
|
||||
'client_id': 'integration-test-client'
|
||||
};
|
||||
|
||||
final result = await api.submitPrompt(promptData);
|
||||
|
||||
expect(result['prompt_id'], equals('123456789'));
|
||||
expect(result['status'], equals('success'));
|
||||
});
|
||||
|
||||
test('Get image', () async {
|
||||
final imageBytes = await api.getImage('ComfyUI_00001_.png');
|
||||
|
||||
expect(imageBytes, equals([1, 2, 3, 4]));
|
||||
});
|
||||
|
||||
test('Handle error response', () async {
|
||||
// Create a client that always returns an error
|
||||
final errorClient = MockClient((_) async {
|
||||
return http.Response('Server Error', 500);
|
||||
});
|
||||
|
||||
final errorApi = ComfyUiApi(
|
||||
host: testHost,
|
||||
clientId: 'error-test-client',
|
||||
httpClient: errorClient,
|
||||
);
|
||||
|
||||
expect(() => errorApi.getQueue(), throwsA(isA<ComfyUiApiException>()));
|
||||
});
|
||||
});
|
||||
}
|
@ -1,127 +0,0 @@
|
||||
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
|
||||
import 'package:test/test.dart';
|
||||
|
||||
void main() {
|
||||
group('QueueInfo', () {
|
||||
test('fromJson creates instance with correct values', () {
|
||||
final json = {
|
||||
'queue_running': 1,
|
||||
'queue': [
|
||||
{'prompt_id': '123', 'number': 1}
|
||||
],
|
||||
'queue_pending': {
|
||||
'456': {'prompt_id': '456', 'number': 2}
|
||||
}
|
||||
};
|
||||
|
||||
final queueInfo = QueueInfo.fromJson(json);
|
||||
|
||||
expect(queueInfo.queueRunning, equals(1));
|
||||
expect(queueInfo.queue.length, equals(1));
|
||||
expect(queueInfo.queue[0]['prompt_id'], equals('123'));
|
||||
expect(queueInfo.queuePending['456']['prompt_id'], equals('456'));
|
||||
});
|
||||
|
||||
test('fromJson handles missing or empty values', () {
|
||||
final json = {'queue_running': 0};
|
||||
|
||||
final queueInfo = QueueInfo.fromJson(json);
|
||||
|
||||
expect(queueInfo.queueRunning, equals(0));
|
||||
expect(queueInfo.queue, isEmpty);
|
||||
expect(queueInfo.queuePending, isEmpty);
|
||||
});
|
||||
});
|
||||
|
||||
group('PromptExecutionStatus', () {
|
||||
test('fromJson creates instance with correct values', () {
|
||||
final json = {
|
||||
'prompt_id': 'abc123',
|
||||
'number': 5,
|
||||
'status': 'processing',
|
||||
'error': null
|
||||
};
|
||||
|
||||
final status = PromptExecutionStatus.fromJson(json);
|
||||
|
||||
expect(status.promptId, equals('abc123'));
|
||||
expect(status.number, equals(5));
|
||||
expect(status.status, equals('processing'));
|
||||
expect(status.error, isNull);
|
||||
});
|
||||
|
||||
test('fromJson handles error information', () {
|
||||
final json = {
|
||||
'prompt_id': 'abc123',
|
||||
'number': 5,
|
||||
'status': 'error',
|
||||
'error': 'Something went wrong'
|
||||
};
|
||||
|
||||
final status = PromptExecutionStatus.fromJson(json);
|
||||
|
||||
expect(status.status, equals('error'));
|
||||
expect(status.error, equals('Something went wrong'));
|
||||
});
|
||||
});
|
||||
|
||||
group('HistoryItem', () {
|
||||
test('fromJson creates instance with correct values', () {
|
||||
final json = {
|
||||
'prompt_id': 'abc123',
|
||||
'prompt': {
|
||||
'1': {'class_type': 'TestNode'}
|
||||
},
|
||||
'outputs': {
|
||||
'2': {
|
||||
'images': {'filename': 'test.png'}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
final item = HistoryItem.fromJson(json);
|
||||
|
||||
expect(item.promptId, equals('abc123'));
|
||||
expect(item.prompt['1']['class_type'], equals('TestNode'));
|
||||
expect(item.outputs?['2']['images']['filename'], equals('test.png'));
|
||||
});
|
||||
|
||||
test('fromJson handles missing outputs', () {
|
||||
final json = {
|
||||
'prompt_id': 'abc123',
|
||||
'prompt': {
|
||||
'1': {'class_type': 'TestNode'}
|
||||
}
|
||||
};
|
||||
|
||||
final item = HistoryItem.fromJson(json);
|
||||
|
||||
expect(item.promptId, equals('abc123'));
|
||||
expect(item.outputs, isNull);
|
||||
});
|
||||
});
|
||||
|
||||
group('ProgressUpdate', () {
|
||||
test('fromJson creates instance with correct values', () {
|
||||
final json = {
|
||||
'type': 'execution_start',
|
||||
'data': {'prompt_id': 'abc123', 'node': 5}
|
||||
};
|
||||
|
||||
final update = ProgressUpdate.fromJson(json);
|
||||
|
||||
expect(update.type, equals('execution_start'));
|
||||
expect(update.data['prompt_id'], equals('abc123'));
|
||||
expect(update.data['node'], equals(5));
|
||||
});
|
||||
|
||||
test('fromJson handles empty data', () {
|
||||
final json = {'type': 'status', 'data': {}};
|
||||
|
||||
final update = ProgressUpdate.fromJson(json);
|
||||
|
||||
expect(update.type, equals('status'));
|
||||
expect(update.data, isEmpty);
|
||||
});
|
||||
});
|
||||
}
|
@ -1,148 +0,0 @@
|
||||
/// Test data for ComfyUI API tests
|
||||
class TestData {
|
||||
/// Mock queue response
|
||||
static final Map<String, dynamic> queueResponse = {
|
||||
'queue_running': 0,
|
||||
'queue': [],
|
||||
'queue_pending': {}
|
||||
};
|
||||
|
||||
/// Mock history response
|
||||
static final Map<String, dynamic> historyResponse = {
|
||||
'History': {
|
||||
'123456789': {
|
||||
'prompt': {
|
||||
// Prompt data
|
||||
},
|
||||
'outputs': {
|
||||
'8': {
|
||||
'images': {
|
||||
'filename': 'ComfyUI_00001_.png',
|
||||
'subfolder': '',
|
||||
'type': 'output',
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock history item
|
||||
static final Map<String, dynamic> historyItemResponse = {
|
||||
'prompt_id': '123456789',
|
||||
'prompt': {
|
||||
// Prompt data
|
||||
},
|
||||
'outputs': {
|
||||
'8': {
|
||||
'images': {
|
||||
'filename': 'ComfyUI_00001_.png',
|
||||
'subfolder': '',
|
||||
'type': 'output',
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock checkpoints response
|
||||
static final Map<String, dynamic> checkpointsResponse = {
|
||||
'models/checkpoints/dreamshaper_8.safetensors': {
|
||||
'filename': 'dreamshaper_8.safetensors',
|
||||
'folder': 'models/checkpoints',
|
||||
},
|
||||
'models/checkpoints/sd_xl_base_1.0.safetensors': {
|
||||
'filename': 'sd_xl_base_1.0.safetensors',
|
||||
'folder': 'models/checkpoints',
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock checkpoint metadata response
|
||||
static final Map<String, dynamic> checkpointMetadataResponse = {
|
||||
'model': {
|
||||
'type': 'checkpoint',
|
||||
'title': 'Dreamshaper 8',
|
||||
'filename': 'dreamshaper_8.safetensors',
|
||||
'hash': 'abcdef1234567890',
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock LoRAs response
|
||||
static final Map<String, dynamic> lorasResponse = {
|
||||
'models/loras/example_lora.safetensors': {
|
||||
'filename': 'example_lora.safetensors',
|
||||
'folder': 'models/loras',
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock VAE response
|
||||
static final Map<String, dynamic> vaeResponse = {
|
||||
'models/vae/example_vae.safetensors': {
|
||||
'filename': 'example_vae.safetensors',
|
||||
'folder': 'models/vae',
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock object info response (simplified)
|
||||
static final Map<String, dynamic> objectInfoResponse = {
|
||||
'CheckpointLoaderSimple': {
|
||||
'input': {
|
||||
'required': {'ckpt_name': 'STRING'}
|
||||
},
|
||||
'output': ['MODEL', 'CLIP', 'VAE'],
|
||||
'output_is_list': [false, false, false]
|
||||
},
|
||||
'KSampler': {
|
||||
'input': {
|
||||
'required': {
|
||||
'model': 'MODEL',
|
||||
'seed': 'INT',
|
||||
'steps': 'INT',
|
||||
'cfg': 'FLOAT',
|
||||
'sampler_name': 'STRING',
|
||||
'scheduler': 'STRING',
|
||||
'positive': 'CONDITIONING',
|
||||
'negative': 'CONDITIONING',
|
||||
'latent_image': 'LATENT'
|
||||
},
|
||||
'optional': {'denoise': 'FLOAT'}
|
||||
},
|
||||
'output': ['LATENT'],
|
||||
'output_is_list': [false]
|
||||
}
|
||||
};
|
||||
|
||||
/// Mock prompt request
|
||||
static final Map<String, dynamic> promptRequest = {
|
||||
'prompt': {
|
||||
'3': {
|
||||
'inputs': {
|
||||
'seed': 123456789,
|
||||
'steps': 20,
|
||||
'cfg': 7,
|
||||
'sampler_name': 'euler_ancestral',
|
||||
'scheduler': 'normal',
|
||||
'denoise': 1,
|
||||
'model': ['4', 0],
|
||||
'positive': ['6', 0],
|
||||
'negative': ['7', 0],
|
||||
'latent_image': ['5', 0]
|
||||
},
|
||||
'class_type': 'KSampler'
|
||||
}
|
||||
},
|
||||
'client_id': 'test-client-id'
|
||||
};
|
||||
|
||||
/// Mock prompt response
|
||||
static final Map<String, dynamic> promptResponse = {
|
||||
'prompt_id': '123456789',
|
||||
'number': 1,
|
||||
'status': 'success'
|
||||
};
|
||||
|
||||
/// Mock progress update response
|
||||
static final Map<String, dynamic> progressUpdateResponse = {
|
||||
'type': 'execution_start',
|
||||
'data': {'prompt_id': '123456789'}
|
||||
};
|
||||
}
|
@ -1,145 +0,0 @@
|
||||
import 'dart:async';
|
||||
import 'dart:convert';
|
||||
|
||||
import 'package:comfyui_api_sdk/comfyui_api_sdk.dart';
|
||||
import 'package:http/http.dart' as http;
|
||||
import 'package:http/testing.dart';
|
||||
import 'package:mockito/annotations.dart';
|
||||
import 'package:mockito/mockito.dart';
|
||||
import 'package:test/test.dart';
|
||||
import 'package:web_socket_channel/web_socket_channel.dart';
|
||||
|
||||
import 'test_data.dart';
|
||||
import 'websocket_test.mocks.dart';
|
||||
|
||||
@GenerateMocks([http.Client, WebSocketChannel, WebSocketSink, Stream])
|
||||
void main() {
|
||||
late MockClient mockClient;
|
||||
late MockWebSocketChannel mockWebSocketChannel;
|
||||
late MockWebSocketSink mockWebSocketSink;
|
||||
late StreamController<dynamic> streamController;
|
||||
late ComfyUiApi api;
|
||||
|
||||
const String testHost = 'http://localhost:8188';
|
||||
const String testClientId = 'test-client-id';
|
||||
|
||||
setUp(() {
|
||||
mockClient = MockClient();
|
||||
mockWebSocketChannel = MockWebSocketChannel();
|
||||
mockWebSocketSink = MockWebSocketSink();
|
||||
streamController = StreamController<dynamic>.broadcast();
|
||||
|
||||
when(mockWebSocketChannel.sink).thenReturn(mockWebSocketSink);
|
||||
when(mockWebSocketChannel.stream)
|
||||
.thenAnswer((_) => streamController.stream);
|
||||
|
||||
api = ComfyUiApi(
|
||||
host: testHost,
|
||||
clientId: testClientId,
|
||||
httpClient: mockClient,
|
||||
);
|
||||
});
|
||||
|
||||
tearDown(() {
|
||||
streamController.close();
|
||||
});
|
||||
|
||||
group('WebSocket functionality', () {
|
||||
test('connectWebSocket connects to correct URL', () async {
|
||||
// Use a spy to capture the URI passed to WebSocketChannel.connect
|
||||
final wsUrl = 'ws://localhost:8188/ws?clientId=$testClientId';
|
||||
|
||||
await api.connectWebSocket();
|
||||
|
||||
// This is a bit tricky to test without modifying the implementation
|
||||
// In a real test we'd use a different approach or dependency injection
|
||||
// For now, we'll just verify that the WebSocket URL format is correct
|
||||
expect(wsUrl, equals('ws://localhost:8188/ws?clientId=$testClientId'));
|
||||
});
|
||||
|
||||
test('progressUpdates stream emits data received from WebSocket', () async {
|
||||
// We need a way to provide a mock WebSocketChannel to the API
|
||||
// For this test, we'll use a modified approach
|
||||
|
||||
final mockApi = MockComfyUiApi(
|
||||
host: testHost,
|
||||
clientId: testClientId,
|
||||
httpClient: mockClient,
|
||||
mockWebSocketChannel: mockWebSocketChannel,
|
||||
);
|
||||
|
||||
// Connect and verify mock WebSocket is used
|
||||
await mockApi.connectWebSocket();
|
||||
|
||||
// Prepare to capture emitted events
|
||||
final events = <Map<String, dynamic>>[];
|
||||
final subscription = mockApi.progressUpdates.listen(events.add);
|
||||
|
||||
// Send test data through the mock WebSocket
|
||||
final testData = TestData.progressUpdateResponse;
|
||||
streamController.add(jsonEncode(testData));
|
||||
|
||||
// Wait for async processing
|
||||
await Future.delayed(Duration(milliseconds: 100));
|
||||
|
||||
// Verify the data was emitted
|
||||
expect(events.length, equals(1));
|
||||
expect(events.first, equals(testData));
|
||||
|
||||
// Clean up
|
||||
await subscription.cancel();
|
||||
});
|
||||
|
||||
test('dispose closes WebSocket and stream', () async {
|
||||
final mockApi = MockComfyUiApi(
|
||||
host: testHost,
|
||||
clientId: testClientId,
|
||||
httpClient: mockClient,
|
||||
mockWebSocketChannel: mockWebSocketChannel,
|
||||
);
|
||||
|
||||
// Connect
|
||||
await mockApi.connectWebSocket();
|
||||
|
||||
// Dispose
|
||||
mockApi.dispose();
|
||||
|
||||
// Verify WebSocket was closed
|
||||
verify(mockWebSocketSink.close()).called(1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// A modified version of ComfyUiApi for testing that allows injecting a mock WebSocketChannel
|
||||
class MockComfyUiApi extends ComfyUiApi {
|
||||
final WebSocketChannel? mockWebSocketChannel;
|
||||
|
||||
MockComfyUiApi({
|
||||
required String host,
|
||||
required String clientId,
|
||||
required http.Client httpClient,
|
||||
this.mockWebSocketChannel,
|
||||
}) : super(
|
||||
host: host,
|
||||
clientId: clientId,
|
||||
httpClient: httpClient,
|
||||
);
|
||||
|
||||
@override
|
||||
Future<void> connectWebSocket() async {
|
||||
if (mockWebSocketChannel != null) {
|
||||
_wsChannel = mockWebSocketChannel;
|
||||
|
||||
_wsChannel!.stream.listen((message) {
|
||||
final data = jsonDecode(message);
|
||||
_progressController.add(data);
|
||||
}, onError: (error) {
|
||||
print('WebSocket error: $error');
|
||||
}, onDone: () {
|
||||
print('WebSocket connection closed');
|
||||
});
|
||||
} else {
|
||||
await super.connectWebSocket();
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user