mirror of
https://github.com/n8n-io/n8n.git
synced 2025-03-05 20:50:17 -08:00
fix(core): Fix AI nodes not working with new partial execution flow (#11055)
This commit is contained in:
parent
b559352036
commit
0eee5dfd59
|
@ -12,6 +12,11 @@ export type GraphConnection = {
|
|||
// fromName-outputType-outputIndex-inputIndex-toName
|
||||
type DirectedGraphKey = `${string}-${NodeConnectionType}-${number}-${number}-${string}`;
|
||||
|
||||
type RemoveNodeBaseOptions = {
|
||||
reconnectConnections: boolean;
|
||||
skipConnectionFn?: (connection: GraphConnection) => boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Represents a directed graph as an adjacency list, e.g. one list for the
|
||||
* vertices and one list for the edges.
|
||||
|
@ -77,17 +82,34 @@ export class DirectedGraph {
|
|||
* connections making sure all parent nodes are connected to all child nodes
|
||||
* and return the new connections.
|
||||
*/
|
||||
removeNode(node: INode, options?: { reconnectConnections: true }): GraphConnection[];
|
||||
removeNode(node: INode, options?: { reconnectConnections: false }): undefined;
|
||||
removeNode(node: INode, { reconnectConnections = false } = {}): undefined | GraphConnection[] {
|
||||
if (reconnectConnections) {
|
||||
const incomingConnections = this.getDirectParents(node);
|
||||
const outgoingConnections = this.getDirectChildren(node);
|
||||
removeNode(
|
||||
node: INode,
|
||||
options?: { reconnectConnections: true } & RemoveNodeBaseOptions,
|
||||
): GraphConnection[];
|
||||
removeNode(
|
||||
node: INode,
|
||||
options?: { reconnectConnections: false } & RemoveNodeBaseOptions,
|
||||
): undefined;
|
||||
removeNode(
|
||||
node: INode,
|
||||
options: RemoveNodeBaseOptions = { reconnectConnections: false },
|
||||
): undefined | GraphConnection[] {
|
||||
if (options.reconnectConnections) {
|
||||
const incomingConnections = this.getDirectParentConnections(node);
|
||||
const outgoingConnections = this.getDirectChildConnections(node);
|
||||
|
||||
const newConnections: GraphConnection[] = [];
|
||||
|
||||
for (const incomingConnection of incomingConnections) {
|
||||
if (options.skipConnectionFn && options.skipConnectionFn(incomingConnection)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const outgoingConnection of outgoingConnections) {
|
||||
if (options.skipConnectionFn && options.skipConnectionFn(outgoingConnection)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const newConnection = {
|
||||
...incomingConnection,
|
||||
to: outgoingConnection.to,
|
||||
|
@ -165,7 +187,7 @@ export class DirectedGraph {
|
|||
return this;
|
||||
}
|
||||
|
||||
getDirectChildren(node: INode) {
|
||||
getDirectChildConnections(node: INode) {
|
||||
const nodeExists = this.nodes.get(node.name) === node;
|
||||
a.ok(nodeExists);
|
||||
|
||||
|
@ -183,7 +205,7 @@ export class DirectedGraph {
|
|||
}
|
||||
|
||||
private getChildrenRecursive(node: INode, children: Set<INode>) {
|
||||
const directChildren = this.getDirectChildren(node);
|
||||
const directChildren = this.getDirectChildConnections(node);
|
||||
|
||||
for (const directChild of directChildren) {
|
||||
// Break out if we found a cycle.
|
||||
|
@ -202,13 +224,13 @@ export class DirectedGraph {
|
|||
* argument.
|
||||
*
|
||||
* If the node being passed in is a child of itself (e.g. is part of a
|
||||
* cylce), the return set will contain it as well.
|
||||
* cycle), the return set will contain it as well.
|
||||
*/
|
||||
getChildren(node: INode) {
|
||||
return this.getChildrenRecursive(node, new Set());
|
||||
}
|
||||
|
||||
getDirectParents(node: INode) {
|
||||
getDirectParentConnections(node: INode) {
|
||||
const nodeExists = this.nodes.get(node.name) === node;
|
||||
a.ok(nodeExists);
|
||||
|
||||
|
@ -225,6 +247,27 @@ export class DirectedGraph {
|
|||
return directParents;
|
||||
}
|
||||
|
||||
private getParentConnectionsRecursive(node: INode, connections: Set<GraphConnection>) {
|
||||
const parentConnections = this.getDirectParentConnections(node);
|
||||
|
||||
for (const connection of parentConnections) {
|
||||
// break out of cycles
|
||||
if (connections.has(connection)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
connections.add(connection);
|
||||
|
||||
this.getParentConnectionsRecursive(connection.from, connections);
|
||||
}
|
||||
|
||||
return connections;
|
||||
}
|
||||
|
||||
getParentConnections(node: INode) {
|
||||
return this.getParentConnectionsRecursive(node, new Set());
|
||||
}
|
||||
|
||||
getConnection(
|
||||
from: INode,
|
||||
outputIndex: number,
|
||||
|
|
|
@ -89,6 +89,60 @@ describe('DirectedGraph', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('getParentConnections', () => {
|
||||
// ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
|
||||
// │node1├──►│node2├──►│node3│──►│node4│
|
||||
// └─────┘ └─────┘ └─────┘ └─────┘
|
||||
test('returns all parent connections', () => {
|
||||
// ARRANGE
|
||||
const node1 = createNodeData({ name: 'Node1' });
|
||||
const node2 = createNodeData({ name: 'Node2' });
|
||||
const node3 = createNodeData({ name: 'Node3' });
|
||||
const node4 = createNodeData({ name: 'Node4' });
|
||||
const graph = new DirectedGraph()
|
||||
.addNodes(node1, node2, node3, node4)
|
||||
.addConnections(
|
||||
{ from: node1, to: node2 },
|
||||
{ from: node2, to: node3 },
|
||||
{ from: node3, to: node4 },
|
||||
);
|
||||
|
||||
// ACT
|
||||
const connections = graph.getParentConnections(node3);
|
||||
|
||||
// ASSERT
|
||||
const expectedConnections = graph.getConnections().filter((c) => c.to !== node4);
|
||||
expect(connections.size).toBe(2);
|
||||
expect(connections).toEqual(new Set(expectedConnections));
|
||||
});
|
||||
|
||||
// ┌─────┐ ┌─────┐ ┌─────┐
|
||||
// ┌─►│node1├───►│node2├──►│node3├─┐
|
||||
// │ └─────┘ └─────┘ └─────┘ │
|
||||
// │ │
|
||||
// └───────────────────────────────┘
|
||||
test('terminates when finding a cycle', () => {
|
||||
// ARRANGE
|
||||
const node1 = createNodeData({ name: 'Node1' });
|
||||
const node2 = createNodeData({ name: 'Node2' });
|
||||
const node3 = createNodeData({ name: 'Node3' });
|
||||
const graph = new DirectedGraph()
|
||||
.addNodes(node1, node2, node3)
|
||||
.addConnections(
|
||||
{ from: node1, to: node2 },
|
||||
{ from: node2, to: node3 },
|
||||
{ from: node3, to: node1 },
|
||||
);
|
||||
|
||||
// ACT
|
||||
const connections = graph.getParentConnections(node3);
|
||||
|
||||
// ASSERT
|
||||
expect(connections.size).toBe(3);
|
||||
expect(connections).toEqual(new Set(graph.getConnections()));
|
||||
});
|
||||
});
|
||||
|
||||
describe('removeNode', () => {
|
||||
// XX
|
||||
// ┌─────┐ ┌─────┐ ┌─────┐
|
||||
|
|
|
@ -9,6 +9,8 @@
|
|||
// XX denotes that the node is disabled
|
||||
// PD denotes that the node has pinned data
|
||||
|
||||
import { NodeConnectionType } from 'n8n-workflow';
|
||||
|
||||
import { createNodeData } from './helpers';
|
||||
import { DirectedGraph } from '../DirectedGraph';
|
||||
import { findSubgraph } from '../findSubgraph';
|
||||
|
@ -222,4 +224,110 @@ describe('findSubgraph', () => {
|
|||
.addConnections({ from: trigger, to: destination }),
|
||||
);
|
||||
});
|
||||
|
||||
describe('root nodes', () => {
|
||||
// ►►
|
||||
// ┌───────┐ ┌───────────┐
|
||||
// │trigger├─────►│destination│
|
||||
// └───────┘ └──▲────────┘
|
||||
// │AiLanguageModel
|
||||
// ┌┴──────┐
|
||||
// │aiModel│
|
||||
// └───────┘
|
||||
test('always retain connections that have a different type than `NodeConnectionType.Main`', () => {
|
||||
// ARRANGE
|
||||
const trigger = createNodeData({ name: 'trigger' });
|
||||
const destination = createNodeData({ name: 'destination' });
|
||||
const aiModel = createNodeData({ name: 'ai_model' });
|
||||
|
||||
const graph = new DirectedGraph()
|
||||
.addNodes(trigger, destination, aiModel)
|
||||
.addConnections(
|
||||
{ from: trigger, to: destination },
|
||||
{ from: aiModel, type: NodeConnectionType.AiLanguageModel, to: destination },
|
||||
);
|
||||
|
||||
// ACT
|
||||
const subgraph = findSubgraph(graph, destination, trigger);
|
||||
|
||||
// ASSERT
|
||||
expect(subgraph).toEqual(graph);
|
||||
});
|
||||
|
||||
// This graph is not possible, it's only here to make sure `findSubgraph`
|
||||
// does not follow non-Main connections.
|
||||
//
|
||||
// ┌────┐ ┌───────────┐
|
||||
// │root┼───►destination│
|
||||
// └──▲─┘ └───────────┘
|
||||
// │AiLanguageModel
|
||||
// ┌┴──────┐
|
||||
// │aiModel│
|
||||
// └▲──────┘
|
||||
// ┌┴──────┐
|
||||
// │trigger│
|
||||
// └───────┘
|
||||
// turns into an empty graph, because there is no `Main` typed connection
|
||||
// connecting destination and trigger.
|
||||
test('skip non-Main connection types', () => {
|
||||
// ARRANGE
|
||||
const trigger = createNodeData({ name: 'trigger' });
|
||||
const root = createNodeData({ name: 'root' });
|
||||
const aiModel = createNodeData({ name: 'aiModel' });
|
||||
const destination = createNodeData({ name: 'destination' });
|
||||
const graph = new DirectedGraph()
|
||||
.addNodes(trigger, root, aiModel, destination)
|
||||
.addConnections(
|
||||
{ from: trigger, to: aiModel },
|
||||
{ from: aiModel, type: NodeConnectionType.AiLanguageModel, to: root },
|
||||
{ from: root, to: destination },
|
||||
);
|
||||
|
||||
// ACT
|
||||
const subgraph = findSubgraph(graph, destination, trigger);
|
||||
|
||||
// ASSERT
|
||||
expect(subgraph.getConnections()).toHaveLength(0);
|
||||
expect(subgraph.getNodes().size).toBe(0);
|
||||
});
|
||||
|
||||
//
|
||||
// XX
|
||||
// ┌───────┐ ┌────┐ ┌───────────┐
|
||||
// │trigger├───►root├───►destination│
|
||||
// └───────┘ └──▲─┘ └───────────┘
|
||||
// │AiLanguageModel
|
||||
// ┌┴──────┐
|
||||
// │aiModel│
|
||||
// └───────┘
|
||||
// turns into
|
||||
// ┌───────┐ ┌───────────┐
|
||||
// │trigger├────────────►destination│
|
||||
// └───────┘ └───────────┘
|
||||
test('skip disabled root nodes', () => {
|
||||
// ARRANGE
|
||||
const trigger = createNodeData({ name: 'trigger' });
|
||||
const root = createNodeData({ name: 'root', disabled: true });
|
||||
const aiModel = createNodeData({ name: 'ai_model' });
|
||||
const destination = createNodeData({ name: 'destination' });
|
||||
|
||||
const graph = new DirectedGraph()
|
||||
.addNodes(trigger, root, aiModel, destination)
|
||||
.addConnections(
|
||||
{ from: trigger, to: root },
|
||||
{ from: aiModel, type: NodeConnectionType.AiLanguageModel, to: root },
|
||||
{ from: root, to: destination },
|
||||
);
|
||||
|
||||
// ACT
|
||||
const subgraph = findSubgraph(graph, root, trigger);
|
||||
|
||||
// ASSERT
|
||||
expect(subgraph).toEqual(
|
||||
new DirectedGraph()
|
||||
.addNodes(trigger, destination)
|
||||
.addConnections({ from: trigger, to: destination }),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -80,7 +80,7 @@ function findStartNodesRecursive(
|
|||
}
|
||||
|
||||
// Recurse with every direct child that is part of the sub graph.
|
||||
const outGoingConnections = graph.getDirectChildren(current);
|
||||
const outGoingConnections = graph.getDirectChildConnections(current);
|
||||
for (const outGoingConnection of outGoingConnections) {
|
||||
const nodeRunData = getIncomingData(
|
||||
runData,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import type { INode } from 'n8n-workflow';
|
||||
import { NodeConnectionType, type INode } from 'n8n-workflow';
|
||||
|
||||
import type { GraphConnection } from './DirectedGraph';
|
||||
import { DirectedGraph } from './DirectedGraph';
|
||||
|
@ -21,7 +21,7 @@ function findSubgraphRecursive(
|
|||
return;
|
||||
}
|
||||
|
||||
let parentConnections = graph.getDirectParents(current);
|
||||
let parentConnections = graph.getDirectParentConnections(current);
|
||||
|
||||
// If the current node has no parents, don’t keep this branch.
|
||||
if (parentConnections.length === 0) {
|
||||
|
@ -58,11 +58,24 @@ function findSubgraphRecursive(
|
|||
// The node is replaced by a set of new connections, connecting the parents
|
||||
// and children of it directly. In the recursive call below we'll follow
|
||||
// them further.
|
||||
parentConnections = graph.removeNode(current, { reconnectConnections: true });
|
||||
parentConnections = graph.removeNode(current, {
|
||||
reconnectConnections: true,
|
||||
// If the node has non-Main connections we don't want to rewire those.
|
||||
// Otherwise we'd end up connecting AI utilities to nodes that don't
|
||||
// support them.
|
||||
skipConnectionFn: (c) => c.type !== NodeConnectionType.Main,
|
||||
});
|
||||
}
|
||||
|
||||
// Recurse on each parent.
|
||||
for (const parentConnection of parentConnections) {
|
||||
// Skip parents that are connected via non-Main connection types. They are
|
||||
// only utility nodes for AI and are not part of the data or control flow
|
||||
// and can never lead too the trigger.
|
||||
if (parentConnection.type !== NodeConnectionType.Main) {
|
||||
continue;
|
||||
}
|
||||
|
||||
findSubgraphRecursive(graph, destinationNode, parentConnection.from, trigger, newGraph, [
|
||||
...currentBranch,
|
||||
parentConnection,
|
||||
|
@ -87,15 +100,38 @@ function findSubgraphRecursive(
|
|||
* - take every incoming connection and connect it to every node that is
|
||||
* connected to the current node’s first output
|
||||
* 6. Recurse on each parent
|
||||
* 7. Re-add all connections that don't use the `Main` connections type.
|
||||
* Theses are used by nodes called root nodes and they are not part of the
|
||||
* dataflow in the graph they are utility nodes, like the AI model used in a
|
||||
* lang chain node.
|
||||
*/
|
||||
export function findSubgraph(
|
||||
graph: DirectedGraph,
|
||||
destinationNode: INode,
|
||||
trigger: INode,
|
||||
): DirectedGraph {
|
||||
const newGraph = new DirectedGraph();
|
||||
const subgraph = new DirectedGraph();
|
||||
|
||||
findSubgraphRecursive(graph, destinationNode, destinationNode, trigger, newGraph, []);
|
||||
findSubgraphRecursive(graph, destinationNode, destinationNode, trigger, subgraph, []);
|
||||
|
||||
return newGraph;
|
||||
// For each node in the subgraph, if it has parent connections of a type that
|
||||
// is not `Main` in the input graph, add the connections and the nodes
|
||||
// connected to it to the subgraph
|
||||
//
|
||||
// Without this all AI related workflows would not work when executed
|
||||
// partially, because all utility nodes would be missing.
|
||||
for (const node of subgraph.getNodes().values()) {
|
||||
const parentConnections = graph.getParentConnections(node);
|
||||
|
||||
for (const connection of parentConnections) {
|
||||
if (connection.type === NodeConnectionType.Main) {
|
||||
continue;
|
||||
}
|
||||
|
||||
subgraph.addNodes(connection.from, connection.to);
|
||||
subgraph.addConnection(connection);
|
||||
}
|
||||
}
|
||||
|
||||
return subgraph;
|
||||
}
|
||||
|
|
|
@ -64,7 +64,7 @@ export function recreateNodeExecutionStack(
|
|||
|
||||
for (const startNode of startNodes) {
|
||||
const incomingStartNodeConnections = graph
|
||||
.getDirectParents(startNode)
|
||||
.getDirectParentConnections(startNode)
|
||||
.filter((c) => c.type === NodeConnectionType.Main);
|
||||
|
||||
let incomingData: INodeExecutionData[][] = [];
|
||||
|
@ -135,7 +135,7 @@ export function recreateNodeExecutionStack(
|
|||
// Check if the destinationNode has to be added as waiting
|
||||
// because some input data is already fully available
|
||||
const incomingDestinationNodeConnections = graph
|
||||
.getDirectParents(destinationNode)
|
||||
.getDirectParentConnections(destinationNode)
|
||||
.filter((c) => c.type === NodeConnectionType.Main);
|
||||
if (incomingDestinationNodeConnections !== undefined) {
|
||||
for (const connection of incomingDestinationNodeConnections) {
|
||||
|
|
Loading…
Reference in a new issue