diff --git a/lib/src/prompt_builder.dart b/lib/src/prompt_builder.dart index c3b9898..1b6c9c8 100644 --- a/lib/src/prompt_builder.dart +++ b/lib/src/prompt_builder.dart @@ -101,6 +101,133 @@ class PromptBuilder { ..addAll(reorderedNodes); } + /// Inserts a node between an existing node's output and its consumers. + /// + /// - [classType]: The class type of the node to insert. + /// - [inputs]: The specific inputs for the new node (excluding connections being intercepted). + /// - [targetNodeId]: The ID of the node *before* the insertion point. + /// - [targetOutputIndices]: The output indices of the [targetNodeId] to intercept (e.g., [0, 1] for MODEL/CLIP). + /// - [newNodeOutputTags]: A map defining the symbolic tags for the *new* node's outputs, keyed by the output name (e.g., {"MODEL": "new_model_tag", "CLIP": "new_clip_tag"}). + /// - [title]: Optional title for the new node. + /// + /// Returns the ID of the newly inserted node. + String insertNode({ + required String classType, + required Map inputs, + required String targetNodeId, + required List targetOutputIndices, + required Map newNodeOutputTags, + String? title, + }) { + if (!_nodes.containsKey(targetNodeId)) { + throw Exception("Target node with ID $targetNodeId does not exist."); + } + print("[insertNode] Inserting $classType after target $targetNodeId"); // Log 1 + + // 1. Find original symbolic tags for the target node's outputs being intercepted + final Map originalTargetTags = {}; // { outputIndex: tag } + _outputToNode.forEach((tag, info) { + if (info['nodeId'] == targetNodeId && + targetOutputIndices.contains(int.parse(info['outputIndex']))) { + originalTargetTags[int.parse(info['outputIndex'])] = tag; + } + }); + print("[insertNode] Found original target tags: $originalTargetTags"); // Log 2 + + if (originalTargetTags.length != targetOutputIndices.length) { + throw Exception( + "Could not find all specified original output tags for target node $targetNodeId."); + } + + // 2. Find consumers connected to these original tags + final Map> consumers = {}; // { consumerNodeId: { inputKey: originalTargetOutputIndex } } + print("[insertNode] Searching for consumers connected to tags: ${originalTargetTags.values}"); // Log 3 + _nodes.forEach((consumerNodeId, consumerNodeData) { + final consumerInputs = consumerNodeData['inputs'] as Map? ?? {}; + consumerInputs.forEach((inputKey, inputValue) { + if (inputValue is List && inputValue.length == 2 && inputValue[0] is String) { + final sourceTag = inputValue[0] as String; + final sourceIndex = inputValue[1] as int; + // Check if this input is connected to one of the tags we are intercepting + // Added check for null safety on originalTargetTags[sourceIndex] + if (originalTargetTags.containsKey(sourceIndex) && originalTargetTags[sourceIndex] == sourceTag) { + consumers.putIfAbsent(consumerNodeId, () => {})[inputKey] = sourceIndex; + } + } + }); + }); + print("[insertNode] Found consumers: $consumers"); // Log 4 + + // 3. Prepare inputs for the new node (combining provided inputs and intercepted connections) + final Map newNodeInputs = Map.from(inputs); + originalTargetTags.forEach((outputIndex, tag) { + // Determine the input key for the new node based on the output type (assuming MODEL/CLIP convention) + // This might need refinement if inserting other node types + final inputKey = (outputIndex == 0) ? 'model' : (outputIndex == 1) ? 'clip' : null; + if (inputKey != null) { + newNodeInputs[inputKey] = [tag, outputIndex]; + } else { + print("Warning: Could not determine input key for intercepted output index $outputIndex"); + } + }); + print("[insertNode] Prepared new node inputs: $newNodeInputs"); // Log 5 + + + // 4. Add the new node using the existing addNode logic + // Note: We need to ensure the outputTags passed to addNode match the structure it expects (keyed by default output name) + final defaultOutputs = _getDefaultOutputs(classType); // e.g., ["MODEL", "CLIP"] + final Map internalOutputTags = {}; + newNodeOutputTags.forEach((key, value) { + if (defaultOutputs.contains(key)) { + internalOutputTags[key] = value; + } + }); + print("[insertNode] Adding node with internalOutputTags: $internalOutputTags"); // Log 6 + + final newNodeId = addNode( + classType, + newNodeInputs, + title: title, + outputTags: internalOutputTags, // Pass the structured tags + ); + print("[insertNode] Added node with ID: $newNodeId"); // Log 7 + + // 5. Rewire consumers to point to the new node's output tags + print("[insertNode] Starting consumer rewiring..."); // Log 8 + consumers.forEach((consumerNodeId, inputMap) { + print("[insertNode] Rewiring consumer: $consumerNodeId"); // Log 9 + inputMap.forEach((inputKey, originalTargetOutputIndex) { + print("[insertNode] Input: $inputKey (original index: $originalTargetOutputIndex)"); // Log 10 + // Find the corresponding *new* output tag based on the original index + final outputName = (originalTargetOutputIndex == 0) ? "MODEL" : (originalTargetOutputIndex == 1) ? "CLIP" : null; + if (outputName != null && newNodeOutputTags.containsKey(outputName)) { + final newTag = newNodeOutputTags[outputName]!; + final newConnection = [newTag, originalTargetOutputIndex]; + print("[insertNode] New connection: $newConnection (using tag '$newTag')"); // Log 11 + + // Get the consumer's current inputs + final currentConsumerInputs = Map.from( + _nodes[consumerNodeId]?['inputs'] ?? {}); + print("[insertNode] Current inputs before edit: $currentConsumerInputs"); // Log 12 (Changed name slightly) + + // Update only the specific input being rewired + currentConsumerInputs[inputKey] = newConnection; + print("[insertNode] Updated inputs map: $currentConsumerInputs"); // Log 13 (Changed name slightly) + + // Use editNode with the full, updated inputs map + editNode(consumerNodeId, newInputs: currentConsumerInputs); + print("[insertNode] Called editNode for $consumerNodeId"); // Log 14 + } else { + print("Warning: Could not find new output tag or determine output name for consumer $consumerNodeId.$inputKey (original index: $originalTargetOutputIndex)"); + } + }); + }); + print("[insertNode] Finished consumer rewiring."); // Log 15 + + return newNodeId; + } + + /// Generates the final workflow map Map build() { final resolvedNodes = {};