From 0eee5dfd597817819dbe0463a63f671fde53432f Mon Sep 17 00:00:00 2001 From: Danny Martini Date: Wed, 9 Oct 2024 09:34:26 +0200 Subject: [PATCH] fix(core): Fix AI nodes not working with new partial execution flow (#11055) --- .../PartialExecutionUtils/DirectedGraph.ts | 63 ++++++++-- .../__tests__/DirectedGraph.test.ts | 54 +++++++++ .../__tests__/findSubgraph.test.ts | 108 ++++++++++++++++++ .../PartialExecutionUtils/findStartNodes.ts | 2 +- .../src/PartialExecutionUtils/findSubgraph.ts | 48 +++++++- .../recreateNodeExecutionStack.ts | 4 +- 6 files changed, 260 insertions(+), 19 deletions(-) diff --git a/packages/core/src/PartialExecutionUtils/DirectedGraph.ts b/packages/core/src/PartialExecutionUtils/DirectedGraph.ts index 33cc114698..606f624d02 100644 --- a/packages/core/src/PartialExecutionUtils/DirectedGraph.ts +++ b/packages/core/src/PartialExecutionUtils/DirectedGraph.ts @@ -12,6 +12,11 @@ export type GraphConnection = { // fromName-outputType-outputIndex-inputIndex-toName type DirectedGraphKey = `${string}-${NodeConnectionType}-${number}-${number}-${string}`; +type RemoveNodeBaseOptions = { + reconnectConnections: boolean; + skipConnectionFn?: (connection: GraphConnection) => boolean; +}; + /** * Represents a directed graph as an adjacency list, e.g. one list for the * vertices and one list for the edges. @@ -77,17 +82,34 @@ export class DirectedGraph { * connections making sure all parent nodes are connected to all child nodes * and return the new connections. */ - removeNode(node: INode, options?: { reconnectConnections: true }): GraphConnection[]; - removeNode(node: INode, options?: { reconnectConnections: false }): undefined; - removeNode(node: INode, { reconnectConnections = false } = {}): undefined | GraphConnection[] { - if (reconnectConnections) { - const incomingConnections = this.getDirectParents(node); - const outgoingConnections = this.getDirectChildren(node); + removeNode( + node: INode, + options?: { reconnectConnections: true } & RemoveNodeBaseOptions, + ): GraphConnection[]; + removeNode( + node: INode, + options?: { reconnectConnections: false } & RemoveNodeBaseOptions, + ): undefined; + removeNode( + node: INode, + options: RemoveNodeBaseOptions = { reconnectConnections: false }, + ): undefined | GraphConnection[] { + if (options.reconnectConnections) { + const incomingConnections = this.getDirectParentConnections(node); + const outgoingConnections = this.getDirectChildConnections(node); const newConnections: GraphConnection[] = []; for (const incomingConnection of incomingConnections) { + if (options.skipConnectionFn && options.skipConnectionFn(incomingConnection)) { + continue; + } + for (const outgoingConnection of outgoingConnections) { + if (options.skipConnectionFn && options.skipConnectionFn(outgoingConnection)) { + continue; + } + const newConnection = { ...incomingConnection, to: outgoingConnection.to, @@ -165,7 +187,7 @@ export class DirectedGraph { return this; } - getDirectChildren(node: INode) { + getDirectChildConnections(node: INode) { const nodeExists = this.nodes.get(node.name) === node; a.ok(nodeExists); @@ -183,7 +205,7 @@ export class DirectedGraph { } private getChildrenRecursive(node: INode, children: Set) { - const directChildren = this.getDirectChildren(node); + const directChildren = this.getDirectChildConnections(node); for (const directChild of directChildren) { // Break out if we found a cycle. @@ -202,13 +224,13 @@ export class DirectedGraph { * argument. * * If the node being passed in is a child of itself (e.g. is part of a - * cylce), the return set will contain it as well. + * cycle), the return set will contain it as well. */ getChildren(node: INode) { return this.getChildrenRecursive(node, new Set()); } - getDirectParents(node: INode) { + getDirectParentConnections(node: INode) { const nodeExists = this.nodes.get(node.name) === node; a.ok(nodeExists); @@ -225,6 +247,27 @@ export class DirectedGraph { return directParents; } + private getParentConnectionsRecursive(node: INode, connections: Set) { + const parentConnections = this.getDirectParentConnections(node); + + for (const connection of parentConnections) { + // break out of cycles + if (connections.has(connection)) { + continue; + } + + connections.add(connection); + + this.getParentConnectionsRecursive(connection.from, connections); + } + + return connections; + } + + getParentConnections(node: INode) { + return this.getParentConnectionsRecursive(node, new Set()); + } + getConnection( from: INode, outputIndex: number, diff --git a/packages/core/src/PartialExecutionUtils/__tests__/DirectedGraph.test.ts b/packages/core/src/PartialExecutionUtils/__tests__/DirectedGraph.test.ts index 426a5405c7..d6eedf416d 100644 --- a/packages/core/src/PartialExecutionUtils/__tests__/DirectedGraph.test.ts +++ b/packages/core/src/PartialExecutionUtils/__tests__/DirectedGraph.test.ts @@ -89,6 +89,60 @@ describe('DirectedGraph', () => { }); }); + describe('getParentConnections', () => { + // ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ + // │node1├──►│node2├──►│node3│──►│node4│ + // └─────┘ └─────┘ └─────┘ └─────┘ + test('returns all parent connections', () => { + // ARRANGE + const node1 = createNodeData({ name: 'Node1' }); + const node2 = createNodeData({ name: 'Node2' }); + const node3 = createNodeData({ name: 'Node3' }); + const node4 = createNodeData({ name: 'Node4' }); + const graph = new DirectedGraph() + .addNodes(node1, node2, node3, node4) + .addConnections( + { from: node1, to: node2 }, + { from: node2, to: node3 }, + { from: node3, to: node4 }, + ); + + // ACT + const connections = graph.getParentConnections(node3); + + // ASSERT + const expectedConnections = graph.getConnections().filter((c) => c.to !== node4); + expect(connections.size).toBe(2); + expect(connections).toEqual(new Set(expectedConnections)); + }); + + // ┌─────┐ ┌─────┐ ┌─────┐ + // ┌─►│node1├───►│node2├──►│node3├─┐ + // │ └─────┘ └─────┘ └─────┘ │ + // │ │ + // └───────────────────────────────┘ + test('terminates when finding a cycle', () => { + // ARRANGE + const node1 = createNodeData({ name: 'Node1' }); + const node2 = createNodeData({ name: 'Node2' }); + const node3 = createNodeData({ name: 'Node3' }); + const graph = new DirectedGraph() + .addNodes(node1, node2, node3) + .addConnections( + { from: node1, to: node2 }, + { from: node2, to: node3 }, + { from: node3, to: node1 }, + ); + + // ACT + const connections = graph.getParentConnections(node3); + + // ASSERT + expect(connections.size).toBe(3); + expect(connections).toEqual(new Set(graph.getConnections())); + }); + }); + describe('removeNode', () => { // XX // ┌─────┐ ┌─────┐ ┌─────┐ diff --git a/packages/core/src/PartialExecutionUtils/__tests__/findSubgraph.test.ts b/packages/core/src/PartialExecutionUtils/__tests__/findSubgraph.test.ts index a5187eacf4..76a1afcc69 100644 --- a/packages/core/src/PartialExecutionUtils/__tests__/findSubgraph.test.ts +++ b/packages/core/src/PartialExecutionUtils/__tests__/findSubgraph.test.ts @@ -9,6 +9,8 @@ // XX denotes that the node is disabled // PD denotes that the node has pinned data +import { NodeConnectionType } from 'n8n-workflow'; + import { createNodeData } from './helpers'; import { DirectedGraph } from '../DirectedGraph'; import { findSubgraph } from '../findSubgraph'; @@ -222,4 +224,110 @@ describe('findSubgraph', () => { .addConnections({ from: trigger, to: destination }), ); }); + + describe('root nodes', () => { + // ►► + // ┌───────┐ ┌───────────┐ + // │trigger├─────►│destination│ + // └───────┘ └──▲────────┘ + // │AiLanguageModel + // ┌┴──────┐ + // │aiModel│ + // └───────┘ + test('always retain connections that have a different type than `NodeConnectionType.Main`', () => { + // ARRANGE + const trigger = createNodeData({ name: 'trigger' }); + const destination = createNodeData({ name: 'destination' }); + const aiModel = createNodeData({ name: 'ai_model' }); + + const graph = new DirectedGraph() + .addNodes(trigger, destination, aiModel) + .addConnections( + { from: trigger, to: destination }, + { from: aiModel, type: NodeConnectionType.AiLanguageModel, to: destination }, + ); + + // ACT + const subgraph = findSubgraph(graph, destination, trigger); + + // ASSERT + expect(subgraph).toEqual(graph); + }); + + // This graph is not possible, it's only here to make sure `findSubgraph` + // does not follow non-Main connections. + // + // ┌────┐ ┌───────────┐ + // │root┼───►destination│ + // └──▲─┘ └───────────┘ + // │AiLanguageModel + // ┌┴──────┐ + // │aiModel│ + // └▲──────┘ + // ┌┴──────┐ + // │trigger│ + // └───────┘ + // turns into an empty graph, because there is no `Main` typed connection + // connecting destination and trigger. + test('skip non-Main connection types', () => { + // ARRANGE + const trigger = createNodeData({ name: 'trigger' }); + const root = createNodeData({ name: 'root' }); + const aiModel = createNodeData({ name: 'aiModel' }); + const destination = createNodeData({ name: 'destination' }); + const graph = new DirectedGraph() + .addNodes(trigger, root, aiModel, destination) + .addConnections( + { from: trigger, to: aiModel }, + { from: aiModel, type: NodeConnectionType.AiLanguageModel, to: root }, + { from: root, to: destination }, + ); + + // ACT + const subgraph = findSubgraph(graph, destination, trigger); + + // ASSERT + expect(subgraph.getConnections()).toHaveLength(0); + expect(subgraph.getNodes().size).toBe(0); + }); + + // + // XX + // ┌───────┐ ┌────┐ ┌───────────┐ + // │trigger├───►root├───►destination│ + // └───────┘ └──▲─┘ └───────────┘ + // │AiLanguageModel + // ┌┴──────┐ + // │aiModel│ + // └───────┘ + // turns into + // ┌───────┐ ┌───────────┐ + // │trigger├────────────►destination│ + // └───────┘ └───────────┘ + test('skip disabled root nodes', () => { + // ARRANGE + const trigger = createNodeData({ name: 'trigger' }); + const root = createNodeData({ name: 'root', disabled: true }); + const aiModel = createNodeData({ name: 'ai_model' }); + const destination = createNodeData({ name: 'destination' }); + + const graph = new DirectedGraph() + .addNodes(trigger, root, aiModel, destination) + .addConnections( + { from: trigger, to: root }, + { from: aiModel, type: NodeConnectionType.AiLanguageModel, to: root }, + { from: root, to: destination }, + ); + + // ACT + const subgraph = findSubgraph(graph, root, trigger); + + // ASSERT + expect(subgraph).toEqual( + new DirectedGraph() + .addNodes(trigger, destination) + .addConnections({ from: trigger, to: destination }), + ); + }); + }); }); diff --git a/packages/core/src/PartialExecutionUtils/findStartNodes.ts b/packages/core/src/PartialExecutionUtils/findStartNodes.ts index 12a9688c1c..28772bfc9a 100644 --- a/packages/core/src/PartialExecutionUtils/findStartNodes.ts +++ b/packages/core/src/PartialExecutionUtils/findStartNodes.ts @@ -80,7 +80,7 @@ function findStartNodesRecursive( } // Recurse with every direct child that is part of the sub graph. - const outGoingConnections = graph.getDirectChildren(current); + const outGoingConnections = graph.getDirectChildConnections(current); for (const outGoingConnection of outGoingConnections) { const nodeRunData = getIncomingData( runData, diff --git a/packages/core/src/PartialExecutionUtils/findSubgraph.ts b/packages/core/src/PartialExecutionUtils/findSubgraph.ts index ea1df91840..d05561e31a 100644 --- a/packages/core/src/PartialExecutionUtils/findSubgraph.ts +++ b/packages/core/src/PartialExecutionUtils/findSubgraph.ts @@ -1,4 +1,4 @@ -import type { INode } from 'n8n-workflow'; +import { NodeConnectionType, type INode } from 'n8n-workflow'; import type { GraphConnection } from './DirectedGraph'; import { DirectedGraph } from './DirectedGraph'; @@ -21,7 +21,7 @@ function findSubgraphRecursive( return; } - let parentConnections = graph.getDirectParents(current); + let parentConnections = graph.getDirectParentConnections(current); // If the current node has no parents, don’t keep this branch. if (parentConnections.length === 0) { @@ -58,11 +58,24 @@ function findSubgraphRecursive( // The node is replaced by a set of new connections, connecting the parents // and children of it directly. In the recursive call below we'll follow // them further. - parentConnections = graph.removeNode(current, { reconnectConnections: true }); + parentConnections = graph.removeNode(current, { + reconnectConnections: true, + // If the node has non-Main connections we don't want to rewire those. + // Otherwise we'd end up connecting AI utilities to nodes that don't + // support them. + skipConnectionFn: (c) => c.type !== NodeConnectionType.Main, + }); } // Recurse on each parent. for (const parentConnection of parentConnections) { + // Skip parents that are connected via non-Main connection types. They are + // only utility nodes for AI and are not part of the data or control flow + // and can never lead too the trigger. + if (parentConnection.type !== NodeConnectionType.Main) { + continue; + } + findSubgraphRecursive(graph, destinationNode, parentConnection.from, trigger, newGraph, [ ...currentBranch, parentConnection, @@ -87,15 +100,38 @@ function findSubgraphRecursive( * - take every incoming connection and connect it to every node that is * connected to the current node’s first output * 6. Recurse on each parent + * 7. Re-add all connections that don't use the `Main` connections type. + * Theses are used by nodes called root nodes and they are not part of the + * dataflow in the graph they are utility nodes, like the AI model used in a + * lang chain node. */ export function findSubgraph( graph: DirectedGraph, destinationNode: INode, trigger: INode, ): DirectedGraph { - const newGraph = new DirectedGraph(); + const subgraph = new DirectedGraph(); - findSubgraphRecursive(graph, destinationNode, destinationNode, trigger, newGraph, []); + findSubgraphRecursive(graph, destinationNode, destinationNode, trigger, subgraph, []); - return newGraph; + // For each node in the subgraph, if it has parent connections of a type that + // is not `Main` in the input graph, add the connections and the nodes + // connected to it to the subgraph + // + // Without this all AI related workflows would not work when executed + // partially, because all utility nodes would be missing. + for (const node of subgraph.getNodes().values()) { + const parentConnections = graph.getParentConnections(node); + + for (const connection of parentConnections) { + if (connection.type === NodeConnectionType.Main) { + continue; + } + + subgraph.addNodes(connection.from, connection.to); + subgraph.addConnection(connection); + } + } + + return subgraph; } diff --git a/packages/core/src/PartialExecutionUtils/recreateNodeExecutionStack.ts b/packages/core/src/PartialExecutionUtils/recreateNodeExecutionStack.ts index f2f1f4af68..4926becb79 100644 --- a/packages/core/src/PartialExecutionUtils/recreateNodeExecutionStack.ts +++ b/packages/core/src/PartialExecutionUtils/recreateNodeExecutionStack.ts @@ -64,7 +64,7 @@ export function recreateNodeExecutionStack( for (const startNode of startNodes) { const incomingStartNodeConnections = graph - .getDirectParents(startNode) + .getDirectParentConnections(startNode) .filter((c) => c.type === NodeConnectionType.Main); let incomingData: INodeExecutionData[][] = []; @@ -135,7 +135,7 @@ export function recreateNodeExecutionStack( // Check if the destinationNode has to be added as waiting // because some input data is already fully available const incomingDestinationNodeConnections = graph - .getDirectParents(destinationNode) + .getDirectParentConnections(destinationNode) .filter((c) => c.type === NodeConnectionType.Main); if (incomingDestinationNodeConnections !== undefined) { for (const connection of incomingDestinationNodeConnections) {