fix(core): Fix AI nodes not working with new partial execution flow (#11055)

This commit is contained in:
Danny Martini 2024-10-09 09:34:26 +02:00 committed by GitHub
parent b559352036
commit 0eee5dfd59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 260 additions and 19 deletions

View file

@ -12,6 +12,11 @@ export type GraphConnection = {
// fromName-outputType-outputIndex-inputIndex-toName
type DirectedGraphKey = `${string}-${NodeConnectionType}-${number}-${number}-${string}`;
type RemoveNodeBaseOptions = {
reconnectConnections: boolean;
skipConnectionFn?: (connection: GraphConnection) => boolean;
};
/**
* Represents a directed graph as an adjacency list, e.g. one list for the
* vertices and one list for the edges.
@ -77,17 +82,34 @@ export class DirectedGraph {
* connections making sure all parent nodes are connected to all child nodes
* and return the new connections.
*/
removeNode(node: INode, options?: { reconnectConnections: true }): GraphConnection[];
removeNode(node: INode, options?: { reconnectConnections: false }): undefined;
removeNode(node: INode, { reconnectConnections = false } = {}): undefined | GraphConnection[] {
if (reconnectConnections) {
const incomingConnections = this.getDirectParents(node);
const outgoingConnections = this.getDirectChildren(node);
removeNode(
node: INode,
options?: { reconnectConnections: true } & RemoveNodeBaseOptions,
): GraphConnection[];
removeNode(
node: INode,
options?: { reconnectConnections: false } & RemoveNodeBaseOptions,
): undefined;
removeNode(
node: INode,
options: RemoveNodeBaseOptions = { reconnectConnections: false },
): undefined | GraphConnection[] {
if (options.reconnectConnections) {
const incomingConnections = this.getDirectParentConnections(node);
const outgoingConnections = this.getDirectChildConnections(node);
const newConnections: GraphConnection[] = [];
for (const incomingConnection of incomingConnections) {
if (options.skipConnectionFn && options.skipConnectionFn(incomingConnection)) {
continue;
}
for (const outgoingConnection of outgoingConnections) {
if (options.skipConnectionFn && options.skipConnectionFn(outgoingConnection)) {
continue;
}
const newConnection = {
...incomingConnection,
to: outgoingConnection.to,
@ -165,7 +187,7 @@ export class DirectedGraph {
return this;
}
getDirectChildren(node: INode) {
getDirectChildConnections(node: INode) {
const nodeExists = this.nodes.get(node.name) === node;
a.ok(nodeExists);
@ -183,7 +205,7 @@ export class DirectedGraph {
}
private getChildrenRecursive(node: INode, children: Set<INode>) {
const directChildren = this.getDirectChildren(node);
const directChildren = this.getDirectChildConnections(node);
for (const directChild of directChildren) {
// Break out if we found a cycle.
@ -202,13 +224,13 @@ export class DirectedGraph {
* argument.
*
* If the node being passed in is a child of itself (e.g. is part of a
* cylce), the return set will contain it as well.
* cycle), the return set will contain it as well.
*/
getChildren(node: INode) {
return this.getChildrenRecursive(node, new Set());
}
getDirectParents(node: INode) {
getDirectParentConnections(node: INode) {
const nodeExists = this.nodes.get(node.name) === node;
a.ok(nodeExists);
@ -225,6 +247,27 @@ export class DirectedGraph {
return directParents;
}
private getParentConnectionsRecursive(node: INode, connections: Set<GraphConnection>) {
const parentConnections = this.getDirectParentConnections(node);
for (const connection of parentConnections) {
// break out of cycles
if (connections.has(connection)) {
continue;
}
connections.add(connection);
this.getParentConnectionsRecursive(connection.from, connections);
}
return connections;
}
getParentConnections(node: INode) {
return this.getParentConnectionsRecursive(node, new Set());
}
getConnection(
from: INode,
outputIndex: number,

View file

@ -89,6 +89,60 @@ describe('DirectedGraph', () => {
});
});
describe('getParentConnections', () => {
// ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
// │node1├──►│node2├──►│node3│──►│node4│
// └─────┘ └─────┘ └─────┘ └─────┘
test('returns all parent connections', () => {
// ARRANGE
const node1 = createNodeData({ name: 'Node1' });
const node2 = createNodeData({ name: 'Node2' });
const node3 = createNodeData({ name: 'Node3' });
const node4 = createNodeData({ name: 'Node4' });
const graph = new DirectedGraph()
.addNodes(node1, node2, node3, node4)
.addConnections(
{ from: node1, to: node2 },
{ from: node2, to: node3 },
{ from: node3, to: node4 },
);
// ACT
const connections = graph.getParentConnections(node3);
// ASSERT
const expectedConnections = graph.getConnections().filter((c) => c.to !== node4);
expect(connections.size).toBe(2);
expect(connections).toEqual(new Set(expectedConnections));
});
// ┌─────┐ ┌─────┐ ┌─────┐
// ┌─►│node1├───►│node2├──►│node3├─┐
// │ └─────┘ └─────┘ └─────┘ │
// │ │
// └───────────────────────────────┘
test('terminates when finding a cycle', () => {
// ARRANGE
const node1 = createNodeData({ name: 'Node1' });
const node2 = createNodeData({ name: 'Node2' });
const node3 = createNodeData({ name: 'Node3' });
const graph = new DirectedGraph()
.addNodes(node1, node2, node3)
.addConnections(
{ from: node1, to: node2 },
{ from: node2, to: node3 },
{ from: node3, to: node1 },
);
// ACT
const connections = graph.getParentConnections(node3);
// ASSERT
expect(connections.size).toBe(3);
expect(connections).toEqual(new Set(graph.getConnections()));
});
});
describe('removeNode', () => {
// XX
// ┌─────┐ ┌─────┐ ┌─────┐

View file

@ -9,6 +9,8 @@
// XX denotes that the node is disabled
// PD denotes that the node has pinned data
import { NodeConnectionType } from 'n8n-workflow';
import { createNodeData } from './helpers';
import { DirectedGraph } from '../DirectedGraph';
import { findSubgraph } from '../findSubgraph';
@ -222,4 +224,110 @@ describe('findSubgraph', () => {
.addConnections({ from: trigger, to: destination }),
);
});
describe('root nodes', () => {
// ►►
// ┌───────┐ ┌───────────┐
// │trigger├─────►│destination│
// └───────┘ └──▲────────┘
// │AiLanguageModel
// ┌┴──────┐
// │aiModel│
// └───────┘
test('always retain connections that have a different type than `NodeConnectionType.Main`', () => {
// ARRANGE
const trigger = createNodeData({ name: 'trigger' });
const destination = createNodeData({ name: 'destination' });
const aiModel = createNodeData({ name: 'ai_model' });
const graph = new DirectedGraph()
.addNodes(trigger, destination, aiModel)
.addConnections(
{ from: trigger, to: destination },
{ from: aiModel, type: NodeConnectionType.AiLanguageModel, to: destination },
);
// ACT
const subgraph = findSubgraph(graph, destination, trigger);
// ASSERT
expect(subgraph).toEqual(graph);
});
// This graph is not possible, it's only here to make sure `findSubgraph`
// does not follow non-Main connections.
//
// ┌────┐ ┌───────────┐
// │root┼───►destination│
// └──▲─┘ └───────────┘
// │AiLanguageModel
// ┌┴──────┐
// │aiModel│
// └▲──────┘
// ┌┴──────┐
// │trigger│
// └───────┘
// turns into an empty graph, because there is no `Main` typed connection
// connecting destination and trigger.
test('skip non-Main connection types', () => {
// ARRANGE
const trigger = createNodeData({ name: 'trigger' });
const root = createNodeData({ name: 'root' });
const aiModel = createNodeData({ name: 'aiModel' });
const destination = createNodeData({ name: 'destination' });
const graph = new DirectedGraph()
.addNodes(trigger, root, aiModel, destination)
.addConnections(
{ from: trigger, to: aiModel },
{ from: aiModel, type: NodeConnectionType.AiLanguageModel, to: root },
{ from: root, to: destination },
);
// ACT
const subgraph = findSubgraph(graph, destination, trigger);
// ASSERT
expect(subgraph.getConnections()).toHaveLength(0);
expect(subgraph.getNodes().size).toBe(0);
});
//
// XX
// ┌───────┐ ┌────┐ ┌───────────┐
// │trigger├───►root├───►destination│
// └───────┘ └──▲─┘ └───────────┘
// │AiLanguageModel
// ┌┴──────┐
// │aiModel│
// └───────┘
// turns into
// ┌───────┐ ┌───────────┐
// │trigger├────────────►destination│
// └───────┘ └───────────┘
test('skip disabled root nodes', () => {
// ARRANGE
const trigger = createNodeData({ name: 'trigger' });
const root = createNodeData({ name: 'root', disabled: true });
const aiModel = createNodeData({ name: 'ai_model' });
const destination = createNodeData({ name: 'destination' });
const graph = new DirectedGraph()
.addNodes(trigger, root, aiModel, destination)
.addConnections(
{ from: trigger, to: root },
{ from: aiModel, type: NodeConnectionType.AiLanguageModel, to: root },
{ from: root, to: destination },
);
// ACT
const subgraph = findSubgraph(graph, root, trigger);
// ASSERT
expect(subgraph).toEqual(
new DirectedGraph()
.addNodes(trigger, destination)
.addConnections({ from: trigger, to: destination }),
);
});
});
});

View file

@ -80,7 +80,7 @@ function findStartNodesRecursive(
}
// Recurse with every direct child that is part of the sub graph.
const outGoingConnections = graph.getDirectChildren(current);
const outGoingConnections = graph.getDirectChildConnections(current);
for (const outGoingConnection of outGoingConnections) {
const nodeRunData = getIncomingData(
runData,

View file

@ -1,4 +1,4 @@
import type { INode } from 'n8n-workflow';
import { NodeConnectionType, type INode } from 'n8n-workflow';
import type { GraphConnection } from './DirectedGraph';
import { DirectedGraph } from './DirectedGraph';
@ -21,7 +21,7 @@ function findSubgraphRecursive(
return;
}
let parentConnections = graph.getDirectParents(current);
let parentConnections = graph.getDirectParentConnections(current);
// If the current node has no parents, dont keep this branch.
if (parentConnections.length === 0) {
@ -58,11 +58,24 @@ function findSubgraphRecursive(
// The node is replaced by a set of new connections, connecting the parents
// and children of it directly. In the recursive call below we'll follow
// them further.
parentConnections = graph.removeNode(current, { reconnectConnections: true });
parentConnections = graph.removeNode(current, {
reconnectConnections: true,
// If the node has non-Main connections we don't want to rewire those.
// Otherwise we'd end up connecting AI utilities to nodes that don't
// support them.
skipConnectionFn: (c) => c.type !== NodeConnectionType.Main,
});
}
// Recurse on each parent.
for (const parentConnection of parentConnections) {
// Skip parents that are connected via non-Main connection types. They are
// only utility nodes for AI and are not part of the data or control flow
// and can never lead too the trigger.
if (parentConnection.type !== NodeConnectionType.Main) {
continue;
}
findSubgraphRecursive(graph, destinationNode, parentConnection.from, trigger, newGraph, [
...currentBranch,
parentConnection,
@ -87,15 +100,38 @@ function findSubgraphRecursive(
* - take every incoming connection and connect it to every node that is
* connected to the current nodes first output
* 6. Recurse on each parent
* 7. Re-add all connections that don't use the `Main` connections type.
* Theses are used by nodes called root nodes and they are not part of the
* dataflow in the graph they are utility nodes, like the AI model used in a
* lang chain node.
*/
export function findSubgraph(
graph: DirectedGraph,
destinationNode: INode,
trigger: INode,
): DirectedGraph {
const newGraph = new DirectedGraph();
const subgraph = new DirectedGraph();
findSubgraphRecursive(graph, destinationNode, destinationNode, trigger, newGraph, []);
findSubgraphRecursive(graph, destinationNode, destinationNode, trigger, subgraph, []);
return newGraph;
// For each node in the subgraph, if it has parent connections of a type that
// is not `Main` in the input graph, add the connections and the nodes
// connected to it to the subgraph
//
// Without this all AI related workflows would not work when executed
// partially, because all utility nodes would be missing.
for (const node of subgraph.getNodes().values()) {
const parentConnections = graph.getParentConnections(node);
for (const connection of parentConnections) {
if (connection.type === NodeConnectionType.Main) {
continue;
}
subgraph.addNodes(connection.from, connection.to);
subgraph.addConnection(connection);
}
}
return subgraph;
}

View file

@ -64,7 +64,7 @@ export function recreateNodeExecutionStack(
for (const startNode of startNodes) {
const incomingStartNodeConnections = graph
.getDirectParents(startNode)
.getDirectParentConnections(startNode)
.filter((c) => c.type === NodeConnectionType.Main);
let incomingData: INodeExecutionData[][] = [];
@ -135,7 +135,7 @@ export function recreateNodeExecutionStack(
// Check if the destinationNode has to be added as waiting
// because some input data is already fully available
const incomingDestinationNodeConnections = graph
.getDirectParents(destinationNode)
.getDirectParentConnections(destinationNode)
.filter((c) => c.type === NodeConnectionType.Main);
if (incomingDestinationNodeConnections !== undefined) {
for (const connection of incomingDestinationNodeConnections) {