mirror of
https://github.com/n8n-io/n8n.git
synced 2025-03-05 20:50:17 -08:00
fix(core): Fix AI nodes not working with new partial execution flow (#11055)
This commit is contained in:
parent
b559352036
commit
0eee5dfd59
|
@ -12,6 +12,11 @@ export type GraphConnection = {
|
||||||
// fromName-outputType-outputIndex-inputIndex-toName
|
// fromName-outputType-outputIndex-inputIndex-toName
|
||||||
type DirectedGraphKey = `${string}-${NodeConnectionType}-${number}-${number}-${string}`;
|
type DirectedGraphKey = `${string}-${NodeConnectionType}-${number}-${number}-${string}`;
|
||||||
|
|
||||||
|
type RemoveNodeBaseOptions = {
|
||||||
|
reconnectConnections: boolean;
|
||||||
|
skipConnectionFn?: (connection: GraphConnection) => boolean;
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents a directed graph as an adjacency list, e.g. one list for the
|
* Represents a directed graph as an adjacency list, e.g. one list for the
|
||||||
* vertices and one list for the edges.
|
* vertices and one list for the edges.
|
||||||
|
@ -77,17 +82,34 @@ export class DirectedGraph {
|
||||||
* connections making sure all parent nodes are connected to all child nodes
|
* connections making sure all parent nodes are connected to all child nodes
|
||||||
* and return the new connections.
|
* and return the new connections.
|
||||||
*/
|
*/
|
||||||
removeNode(node: INode, options?: { reconnectConnections: true }): GraphConnection[];
|
removeNode(
|
||||||
removeNode(node: INode, options?: { reconnectConnections: false }): undefined;
|
node: INode,
|
||||||
removeNode(node: INode, { reconnectConnections = false } = {}): undefined | GraphConnection[] {
|
options?: { reconnectConnections: true } & RemoveNodeBaseOptions,
|
||||||
if (reconnectConnections) {
|
): GraphConnection[];
|
||||||
const incomingConnections = this.getDirectParents(node);
|
removeNode(
|
||||||
const outgoingConnections = this.getDirectChildren(node);
|
node: INode,
|
||||||
|
options?: { reconnectConnections: false } & RemoveNodeBaseOptions,
|
||||||
|
): undefined;
|
||||||
|
removeNode(
|
||||||
|
node: INode,
|
||||||
|
options: RemoveNodeBaseOptions = { reconnectConnections: false },
|
||||||
|
): undefined | GraphConnection[] {
|
||||||
|
if (options.reconnectConnections) {
|
||||||
|
const incomingConnections = this.getDirectParentConnections(node);
|
||||||
|
const outgoingConnections = this.getDirectChildConnections(node);
|
||||||
|
|
||||||
const newConnections: GraphConnection[] = [];
|
const newConnections: GraphConnection[] = [];
|
||||||
|
|
||||||
for (const incomingConnection of incomingConnections) {
|
for (const incomingConnection of incomingConnections) {
|
||||||
|
if (options.skipConnectionFn && options.skipConnectionFn(incomingConnection)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
for (const outgoingConnection of outgoingConnections) {
|
for (const outgoingConnection of outgoingConnections) {
|
||||||
|
if (options.skipConnectionFn && options.skipConnectionFn(outgoingConnection)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
const newConnection = {
|
const newConnection = {
|
||||||
...incomingConnection,
|
...incomingConnection,
|
||||||
to: outgoingConnection.to,
|
to: outgoingConnection.to,
|
||||||
|
@ -165,7 +187,7 @@ export class DirectedGraph {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
getDirectChildren(node: INode) {
|
getDirectChildConnections(node: INode) {
|
||||||
const nodeExists = this.nodes.get(node.name) === node;
|
const nodeExists = this.nodes.get(node.name) === node;
|
||||||
a.ok(nodeExists);
|
a.ok(nodeExists);
|
||||||
|
|
||||||
|
@ -183,7 +205,7 @@ export class DirectedGraph {
|
||||||
}
|
}
|
||||||
|
|
||||||
private getChildrenRecursive(node: INode, children: Set<INode>) {
|
private getChildrenRecursive(node: INode, children: Set<INode>) {
|
||||||
const directChildren = this.getDirectChildren(node);
|
const directChildren = this.getDirectChildConnections(node);
|
||||||
|
|
||||||
for (const directChild of directChildren) {
|
for (const directChild of directChildren) {
|
||||||
// Break out if we found a cycle.
|
// Break out if we found a cycle.
|
||||||
|
@ -202,13 +224,13 @@ export class DirectedGraph {
|
||||||
* argument.
|
* argument.
|
||||||
*
|
*
|
||||||
* If the node being passed in is a child of itself (e.g. is part of a
|
* If the node being passed in is a child of itself (e.g. is part of a
|
||||||
* cylce), the return set will contain it as well.
|
* cycle), the return set will contain it as well.
|
||||||
*/
|
*/
|
||||||
getChildren(node: INode) {
|
getChildren(node: INode) {
|
||||||
return this.getChildrenRecursive(node, new Set());
|
return this.getChildrenRecursive(node, new Set());
|
||||||
}
|
}
|
||||||
|
|
||||||
getDirectParents(node: INode) {
|
getDirectParentConnections(node: INode) {
|
||||||
const nodeExists = this.nodes.get(node.name) === node;
|
const nodeExists = this.nodes.get(node.name) === node;
|
||||||
a.ok(nodeExists);
|
a.ok(nodeExists);
|
||||||
|
|
||||||
|
@ -225,6 +247,27 @@ export class DirectedGraph {
|
||||||
return directParents;
|
return directParents;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private getParentConnectionsRecursive(node: INode, connections: Set<GraphConnection>) {
|
||||||
|
const parentConnections = this.getDirectParentConnections(node);
|
||||||
|
|
||||||
|
for (const connection of parentConnections) {
|
||||||
|
// break out of cycles
|
||||||
|
if (connections.has(connection)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
connections.add(connection);
|
||||||
|
|
||||||
|
this.getParentConnectionsRecursive(connection.from, connections);
|
||||||
|
}
|
||||||
|
|
||||||
|
return connections;
|
||||||
|
}
|
||||||
|
|
||||||
|
getParentConnections(node: INode) {
|
||||||
|
return this.getParentConnectionsRecursive(node, new Set());
|
||||||
|
}
|
||||||
|
|
||||||
getConnection(
|
getConnection(
|
||||||
from: INode,
|
from: INode,
|
||||||
outputIndex: number,
|
outputIndex: number,
|
||||||
|
|
|
@ -89,6 +89,60 @@ describe('DirectedGraph', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('getParentConnections', () => {
|
||||||
|
// ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
|
||||||
|
// │node1├──►│node2├──►│node3│──►│node4│
|
||||||
|
// └─────┘ └─────┘ └─────┘ └─────┘
|
||||||
|
test('returns all parent connections', () => {
|
||||||
|
// ARRANGE
|
||||||
|
const node1 = createNodeData({ name: 'Node1' });
|
||||||
|
const node2 = createNodeData({ name: 'Node2' });
|
||||||
|
const node3 = createNodeData({ name: 'Node3' });
|
||||||
|
const node4 = createNodeData({ name: 'Node4' });
|
||||||
|
const graph = new DirectedGraph()
|
||||||
|
.addNodes(node1, node2, node3, node4)
|
||||||
|
.addConnections(
|
||||||
|
{ from: node1, to: node2 },
|
||||||
|
{ from: node2, to: node3 },
|
||||||
|
{ from: node3, to: node4 },
|
||||||
|
);
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
const connections = graph.getParentConnections(node3);
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
const expectedConnections = graph.getConnections().filter((c) => c.to !== node4);
|
||||||
|
expect(connections.size).toBe(2);
|
||||||
|
expect(connections).toEqual(new Set(expectedConnections));
|
||||||
|
});
|
||||||
|
|
||||||
|
// ┌─────┐ ┌─────┐ ┌─────┐
|
||||||
|
// ┌─►│node1├───►│node2├──►│node3├─┐
|
||||||
|
// │ └─────┘ └─────┘ └─────┘ │
|
||||||
|
// │ │
|
||||||
|
// └───────────────────────────────┘
|
||||||
|
test('terminates when finding a cycle', () => {
|
||||||
|
// ARRANGE
|
||||||
|
const node1 = createNodeData({ name: 'Node1' });
|
||||||
|
const node2 = createNodeData({ name: 'Node2' });
|
||||||
|
const node3 = createNodeData({ name: 'Node3' });
|
||||||
|
const graph = new DirectedGraph()
|
||||||
|
.addNodes(node1, node2, node3)
|
||||||
|
.addConnections(
|
||||||
|
{ from: node1, to: node2 },
|
||||||
|
{ from: node2, to: node3 },
|
||||||
|
{ from: node3, to: node1 },
|
||||||
|
);
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
const connections = graph.getParentConnections(node3);
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
expect(connections.size).toBe(3);
|
||||||
|
expect(connections).toEqual(new Set(graph.getConnections()));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('removeNode', () => {
|
describe('removeNode', () => {
|
||||||
// XX
|
// XX
|
||||||
// ┌─────┐ ┌─────┐ ┌─────┐
|
// ┌─────┐ ┌─────┐ ┌─────┐
|
||||||
|
|
|
@ -9,6 +9,8 @@
|
||||||
// XX denotes that the node is disabled
|
// XX denotes that the node is disabled
|
||||||
// PD denotes that the node has pinned data
|
// PD denotes that the node has pinned data
|
||||||
|
|
||||||
|
import { NodeConnectionType } from 'n8n-workflow';
|
||||||
|
|
||||||
import { createNodeData } from './helpers';
|
import { createNodeData } from './helpers';
|
||||||
import { DirectedGraph } from '../DirectedGraph';
|
import { DirectedGraph } from '../DirectedGraph';
|
||||||
import { findSubgraph } from '../findSubgraph';
|
import { findSubgraph } from '../findSubgraph';
|
||||||
|
@ -222,4 +224,110 @@ describe('findSubgraph', () => {
|
||||||
.addConnections({ from: trigger, to: destination }),
|
.addConnections({ from: trigger, to: destination }),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('root nodes', () => {
|
||||||
|
// ►►
|
||||||
|
// ┌───────┐ ┌───────────┐
|
||||||
|
// │trigger├─────►│destination│
|
||||||
|
// └───────┘ └──▲────────┘
|
||||||
|
// │AiLanguageModel
|
||||||
|
// ┌┴──────┐
|
||||||
|
// │aiModel│
|
||||||
|
// └───────┘
|
||||||
|
test('always retain connections that have a different type than `NodeConnectionType.Main`', () => {
|
||||||
|
// ARRANGE
|
||||||
|
const trigger = createNodeData({ name: 'trigger' });
|
||||||
|
const destination = createNodeData({ name: 'destination' });
|
||||||
|
const aiModel = createNodeData({ name: 'ai_model' });
|
||||||
|
|
||||||
|
const graph = new DirectedGraph()
|
||||||
|
.addNodes(trigger, destination, aiModel)
|
||||||
|
.addConnections(
|
||||||
|
{ from: trigger, to: destination },
|
||||||
|
{ from: aiModel, type: NodeConnectionType.AiLanguageModel, to: destination },
|
||||||
|
);
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
const subgraph = findSubgraph(graph, destination, trigger);
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
expect(subgraph).toEqual(graph);
|
||||||
|
});
|
||||||
|
|
||||||
|
// This graph is not possible, it's only here to make sure `findSubgraph`
|
||||||
|
// does not follow non-Main connections.
|
||||||
|
//
|
||||||
|
// ┌────┐ ┌───────────┐
|
||||||
|
// │root┼───►destination│
|
||||||
|
// └──▲─┘ └───────────┘
|
||||||
|
// │AiLanguageModel
|
||||||
|
// ┌┴──────┐
|
||||||
|
// │aiModel│
|
||||||
|
// └▲──────┘
|
||||||
|
// ┌┴──────┐
|
||||||
|
// │trigger│
|
||||||
|
// └───────┘
|
||||||
|
// turns into an empty graph, because there is no `Main` typed connection
|
||||||
|
// connecting destination and trigger.
|
||||||
|
test('skip non-Main connection types', () => {
|
||||||
|
// ARRANGE
|
||||||
|
const trigger = createNodeData({ name: 'trigger' });
|
||||||
|
const root = createNodeData({ name: 'root' });
|
||||||
|
const aiModel = createNodeData({ name: 'aiModel' });
|
||||||
|
const destination = createNodeData({ name: 'destination' });
|
||||||
|
const graph = new DirectedGraph()
|
||||||
|
.addNodes(trigger, root, aiModel, destination)
|
||||||
|
.addConnections(
|
||||||
|
{ from: trigger, to: aiModel },
|
||||||
|
{ from: aiModel, type: NodeConnectionType.AiLanguageModel, to: root },
|
||||||
|
{ from: root, to: destination },
|
||||||
|
);
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
const subgraph = findSubgraph(graph, destination, trigger);
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
expect(subgraph.getConnections()).toHaveLength(0);
|
||||||
|
expect(subgraph.getNodes().size).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
//
|
||||||
|
// XX
|
||||||
|
// ┌───────┐ ┌────┐ ┌───────────┐
|
||||||
|
// │trigger├───►root├───►destination│
|
||||||
|
// └───────┘ └──▲─┘ └───────────┘
|
||||||
|
// │AiLanguageModel
|
||||||
|
// ┌┴──────┐
|
||||||
|
// │aiModel│
|
||||||
|
// └───────┘
|
||||||
|
// turns into
|
||||||
|
// ┌───────┐ ┌───────────┐
|
||||||
|
// │trigger├────────────►destination│
|
||||||
|
// └───────┘ └───────────┘
|
||||||
|
test('skip disabled root nodes', () => {
|
||||||
|
// ARRANGE
|
||||||
|
const trigger = createNodeData({ name: 'trigger' });
|
||||||
|
const root = createNodeData({ name: 'root', disabled: true });
|
||||||
|
const aiModel = createNodeData({ name: 'ai_model' });
|
||||||
|
const destination = createNodeData({ name: 'destination' });
|
||||||
|
|
||||||
|
const graph = new DirectedGraph()
|
||||||
|
.addNodes(trigger, root, aiModel, destination)
|
||||||
|
.addConnections(
|
||||||
|
{ from: trigger, to: root },
|
||||||
|
{ from: aiModel, type: NodeConnectionType.AiLanguageModel, to: root },
|
||||||
|
{ from: root, to: destination },
|
||||||
|
);
|
||||||
|
|
||||||
|
// ACT
|
||||||
|
const subgraph = findSubgraph(graph, root, trigger);
|
||||||
|
|
||||||
|
// ASSERT
|
||||||
|
expect(subgraph).toEqual(
|
||||||
|
new DirectedGraph()
|
||||||
|
.addNodes(trigger, destination)
|
||||||
|
.addConnections({ from: trigger, to: destination }),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -80,7 +80,7 @@ function findStartNodesRecursive(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recurse with every direct child that is part of the sub graph.
|
// Recurse with every direct child that is part of the sub graph.
|
||||||
const outGoingConnections = graph.getDirectChildren(current);
|
const outGoingConnections = graph.getDirectChildConnections(current);
|
||||||
for (const outGoingConnection of outGoingConnections) {
|
for (const outGoingConnection of outGoingConnections) {
|
||||||
const nodeRunData = getIncomingData(
|
const nodeRunData = getIncomingData(
|
||||||
runData,
|
runData,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import type { INode } from 'n8n-workflow';
|
import { NodeConnectionType, type INode } from 'n8n-workflow';
|
||||||
|
|
||||||
import type { GraphConnection } from './DirectedGraph';
|
import type { GraphConnection } from './DirectedGraph';
|
||||||
import { DirectedGraph } from './DirectedGraph';
|
import { DirectedGraph } from './DirectedGraph';
|
||||||
|
@ -21,7 +21,7 @@ function findSubgraphRecursive(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let parentConnections = graph.getDirectParents(current);
|
let parentConnections = graph.getDirectParentConnections(current);
|
||||||
|
|
||||||
// If the current node has no parents, don’t keep this branch.
|
// If the current node has no parents, don’t keep this branch.
|
||||||
if (parentConnections.length === 0) {
|
if (parentConnections.length === 0) {
|
||||||
|
@ -58,11 +58,24 @@ function findSubgraphRecursive(
|
||||||
// The node is replaced by a set of new connections, connecting the parents
|
// The node is replaced by a set of new connections, connecting the parents
|
||||||
// and children of it directly. In the recursive call below we'll follow
|
// and children of it directly. In the recursive call below we'll follow
|
||||||
// them further.
|
// them further.
|
||||||
parentConnections = graph.removeNode(current, { reconnectConnections: true });
|
parentConnections = graph.removeNode(current, {
|
||||||
|
reconnectConnections: true,
|
||||||
|
// If the node has non-Main connections we don't want to rewire those.
|
||||||
|
// Otherwise we'd end up connecting AI utilities to nodes that don't
|
||||||
|
// support them.
|
||||||
|
skipConnectionFn: (c) => c.type !== NodeConnectionType.Main,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recurse on each parent.
|
// Recurse on each parent.
|
||||||
for (const parentConnection of parentConnections) {
|
for (const parentConnection of parentConnections) {
|
||||||
|
// Skip parents that are connected via non-Main connection types. They are
|
||||||
|
// only utility nodes for AI and are not part of the data or control flow
|
||||||
|
// and can never lead too the trigger.
|
||||||
|
if (parentConnection.type !== NodeConnectionType.Main) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
findSubgraphRecursive(graph, destinationNode, parentConnection.from, trigger, newGraph, [
|
findSubgraphRecursive(graph, destinationNode, parentConnection.from, trigger, newGraph, [
|
||||||
...currentBranch,
|
...currentBranch,
|
||||||
parentConnection,
|
parentConnection,
|
||||||
|
@ -87,15 +100,38 @@ function findSubgraphRecursive(
|
||||||
* - take every incoming connection and connect it to every node that is
|
* - take every incoming connection and connect it to every node that is
|
||||||
* connected to the current node’s first output
|
* connected to the current node’s first output
|
||||||
* 6. Recurse on each parent
|
* 6. Recurse on each parent
|
||||||
|
* 7. Re-add all connections that don't use the `Main` connections type.
|
||||||
|
* Theses are used by nodes called root nodes and they are not part of the
|
||||||
|
* dataflow in the graph they are utility nodes, like the AI model used in a
|
||||||
|
* lang chain node.
|
||||||
*/
|
*/
|
||||||
export function findSubgraph(
|
export function findSubgraph(
|
||||||
graph: DirectedGraph,
|
graph: DirectedGraph,
|
||||||
destinationNode: INode,
|
destinationNode: INode,
|
||||||
trigger: INode,
|
trigger: INode,
|
||||||
): DirectedGraph {
|
): DirectedGraph {
|
||||||
const newGraph = new DirectedGraph();
|
const subgraph = new DirectedGraph();
|
||||||
|
|
||||||
findSubgraphRecursive(graph, destinationNode, destinationNode, trigger, newGraph, []);
|
findSubgraphRecursive(graph, destinationNode, destinationNode, trigger, subgraph, []);
|
||||||
|
|
||||||
return newGraph;
|
// For each node in the subgraph, if it has parent connections of a type that
|
||||||
|
// is not `Main` in the input graph, add the connections and the nodes
|
||||||
|
// connected to it to the subgraph
|
||||||
|
//
|
||||||
|
// Without this all AI related workflows would not work when executed
|
||||||
|
// partially, because all utility nodes would be missing.
|
||||||
|
for (const node of subgraph.getNodes().values()) {
|
||||||
|
const parentConnections = graph.getParentConnections(node);
|
||||||
|
|
||||||
|
for (const connection of parentConnections) {
|
||||||
|
if (connection.type === NodeConnectionType.Main) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
subgraph.addNodes(connection.from, connection.to);
|
||||||
|
subgraph.addConnection(connection);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return subgraph;
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,7 +64,7 @@ export function recreateNodeExecutionStack(
|
||||||
|
|
||||||
for (const startNode of startNodes) {
|
for (const startNode of startNodes) {
|
||||||
const incomingStartNodeConnections = graph
|
const incomingStartNodeConnections = graph
|
||||||
.getDirectParents(startNode)
|
.getDirectParentConnections(startNode)
|
||||||
.filter((c) => c.type === NodeConnectionType.Main);
|
.filter((c) => c.type === NodeConnectionType.Main);
|
||||||
|
|
||||||
let incomingData: INodeExecutionData[][] = [];
|
let incomingData: INodeExecutionData[][] = [];
|
||||||
|
@ -135,7 +135,7 @@ export function recreateNodeExecutionStack(
|
||||||
// Check if the destinationNode has to be added as waiting
|
// Check if the destinationNode has to be added as waiting
|
||||||
// because some input data is already fully available
|
// because some input data is already fully available
|
||||||
const incomingDestinationNodeConnections = graph
|
const incomingDestinationNodeConnections = graph
|
||||||
.getDirectParents(destinationNode)
|
.getDirectParentConnections(destinationNode)
|
||||||
.filter((c) => c.type === NodeConnectionType.Main);
|
.filter((c) => c.type === NodeConnectionType.Main);
|
||||||
if (incomingDestinationNodeConnections !== undefined) {
|
if (incomingDestinationNodeConnections !== undefined) {
|
||||||
for (const connection of incomingDestinationNodeConnections) {
|
for (const connection of incomingDestinationNodeConnections) {
|
||||||
|
|
Loading…
Reference in a new issue