Add working version

This commit is contained in:
Charlie Kolb 2024-11-13 09:36:08 +01:00
parent 2f33d6f352
commit a1cf28dce9
No known key found for this signature in database
4 changed files with 133 additions and 72 deletions

View file

@ -43,7 +43,7 @@ import { useExternalHooks } from '@/composables/useExternalHooks';
import { sortNodeCreateElements, transformNodeType } from '../utils'; import { sortNodeCreateElements, transformNodeType } from '../utils';
import { useI18n } from '@/composables/useI18n'; import { useI18n } from '@/composables/useI18n';
import { useCanvasStore } from '@/stores/canvas.store'; import { useCanvasStore } from '@/stores/canvas.store';
import { adjustNewlyConnectedNodes } from '@/utils/connectionNodeUtils'; import { adjustNewNodes } from '@/utils/connectionNodeUtils';
export const useActions = () => { export const useActions = () => {
const nodeCreatorStore = useNodeCreatorStore(); const nodeCreatorStore = useNodeCreatorStore();
@ -287,7 +287,7 @@ export const useActions = () => {
} }
if (addedNodes.length === 2) { if (addedNodes.length === 2) {
adjustNewlyConnectedNodes(addedNodes[0], addedNodes[1]); adjustNewNodes(addedNodes[0], addedNodes[1]);
} }
addedNodes.forEach((node, index) => { addedNodes.forEach((node, index) => {

View file

@ -96,7 +96,7 @@ import type { useRouter } from 'vue-router';
import { useClipboard } from '@/composables/useClipboard'; import { useClipboard } from '@/composables/useClipboard';
import { useUniqueNodeName } from '@/composables/useUniqueNodeName'; import { useUniqueNodeName } from '@/composables/useUniqueNodeName';
import { isPresent } from '../utils/typesUtils'; import { isPresent } from '../utils/typesUtils';
import { adjustNewlyConnectedNodes } from '@/utils/connectionNodeUtils'; import { adjustNewNodes } from '@/utils/connectionNodeUtils';
type AddNodeData = Partial<INodeUi> & { type AddNodeData = Partial<INodeUi> & {
type: string; type: string;
@ -673,37 +673,46 @@ export function useCanvasOperations({ router }: { router: ReturnType<typeof useR
}); });
if (mode === CanvasConnectionMode.Input) { if (mode === CanvasConnectionMode.Input) {
createConnection({ createConnection(
source: nodeId, {
sourceHandle: nodeHandle, source: nodeId,
target: lastInteractedWithNodeId, sourceHandle: nodeHandle,
targetHandle: lastInteractedWithNodeHandle, target: lastInteractedWithNodeId,
}); targetHandle: lastInteractedWithNodeHandle,
},
{ parentIsNew: true },
);
} else { } else {
createConnection({ createConnection(
source: lastInteractedWithNodeId, {
sourceHandle: lastInteractedWithNodeHandle, source: lastInteractedWithNodeId,
target: nodeId, sourceHandle: lastInteractedWithNodeHandle,
targetHandle: nodeHandle, target: nodeId,
}); targetHandle: nodeHandle,
},
{ childIsNew: true },
);
} }
} else { } else {
// If a node is last selected then connect between the active and its child ones // If a node is last selected then connect between the active and its child ones
// Connect active node to the newly created one // Connect active node to the newly created one
createConnection({ createConnection(
source: lastInteractedWithNodeId, {
sourceHandle: createCanvasConnectionHandleString({ source: lastInteractedWithNodeId,
mode: CanvasConnectionMode.Output, sourceHandle: createCanvasConnectionHandleString({
type: NodeConnectionType.Main, mode: CanvasConnectionMode.Output,
index: 0, type: NodeConnectionType.Main,
}), index: 0,
target: node.id, }),
targetHandle: createCanvasConnectionHandleString({ target: node.id,
mode: CanvasConnectionMode.Input, targetHandle: createCanvasConnectionHandleString({
type: NodeConnectionType.Main, mode: CanvasConnectionMode.Input,
index: 0, type: NodeConnectionType.Main,
}), index: 0,
}); }),
},
{ childIsNew: true },
);
} }
if (lastInteractedWithNodeConnection) { if (lastInteractedWithNodeConnection) {
@ -711,16 +720,19 @@ export function useCanvasOperations({ router }: { router: ReturnType<typeof useR
const targetNode = workflowsStore.getNodeById(lastInteractedWithNodeConnection.target); const targetNode = workflowsStore.getNodeById(lastInteractedWithNodeConnection.target);
if (targetNode) { if (targetNode) {
createConnection({ createConnection(
source: node.id, {
sourceHandle: createCanvasConnectionHandleString({ source: node.id,
mode: CanvasConnectionMode.Input, sourceHandle: createCanvasConnectionHandleString({
type: NodeConnectionType.Main, mode: CanvasConnectionMode.Input,
index: 0, type: NodeConnectionType.Main,
}), index: 0,
target: lastInteractedWithNodeConnection.target, }),
targetHandle: lastInteractedWithNodeConnection.targetHandle, target: lastInteractedWithNodeConnection.target,
}); targetHandle: lastInteractedWithNodeConnection.targetHandle,
},
{ parentIsNew: true },
);
} }
} }
} }
@ -1079,7 +1091,7 @@ export function useCanvasOperations({ router }: { router: ReturnType<typeof useR
function createConnection( function createConnection(
connection: Connection, connection: Connection,
{ trackHistory = false, keepPristine = false } = {}, { trackHistory = false, keepPristine = false, parentIsNew = false, childIsNew = false } = {},
) { ) {
const sourceNode = workflowsStore.getNodeById(connection.source); const sourceNode = workflowsStore.getNodeById(connection.source);
const targetNode = workflowsStore.getNodeById(connection.target); const targetNode = workflowsStore.getNodeById(connection.target);
@ -1105,7 +1117,7 @@ export function useCanvasOperations({ router }: { router: ReturnType<typeof useR
return; return;
} }
adjustNewlyConnectedNodes(sourceNode, targetNode); adjustNewNodes(sourceNode, targetNode, { parentIsNew, childIsNew });
workflowsStore.addConnection({ workflowsStore.addConnection({
connection: mappedConnection, connection: mappedConnection,

View file

@ -1,39 +1,76 @@
import { useWorkflowsStore } from '@/stores/workflows.store';
import { AGENT_NODE_TYPE, CHAT_TRIGGER_NODE_TYPE, MANUAL_TRIGGER_NODE_TYPE } from '@/constants'; import { AGENT_NODE_TYPE, CHAT_TRIGGER_NODE_TYPE, MANUAL_TRIGGER_NODE_TYPE } from '@/constants';
import { adjustNewlyConnectedNodes } from './connectionNodeUtils'; import { adjustNewNodes } from '@/utils/connectionNodeUtils';
import { createPinia, setActivePinia } from 'pinia';
const getParentNodesByDepth = vi.fn();
const getNode = vi.fn();
vi.mock('@/stores/workflow.store', () => ({
useWorkflowsStore: vi.fn(() => ({
getParentNodesByDepth,
getNode,
})),
}));
describe('adjustNewlyConnectedNodes', () => { describe('adjustNewlyConnectedNodes', () => {
it('modifies promptType with ChatTrigger->Agent', () => { beforeEach(() => {
setActivePinia(createPinia());
});
afterEach(() => {
vi.clearAllMocks();
});
it('modifies promptType with ChatTrigger->new Agent', () => {
const parent = { type: CHAT_TRIGGER_NODE_TYPE }; const parent = { type: CHAT_TRIGGER_NODE_TYPE };
const child = { type: AGENT_NODE_TYPE }; const child = { type: AGENT_NODE_TYPE };
adjustNewlyConnectedNodes(parent, child); adjustNewNodes(parent, child, { parentIsNew: false });
expect(child).toEqual({ expect(child).toEqual({
type: AGENT_NODE_TYPE, type: AGENT_NODE_TYPE,
}); });
}); });
it('does not modify promptType with ManualTrigger->Agent', () => {
it('modifies promptType with new ChatTrigger->new Agent', () => {
const parent = { type: CHAT_TRIGGER_NODE_TYPE };
const child = { type: AGENT_NODE_TYPE };
adjustNewNodes(parent, child);
expect(child).toEqual({
type: AGENT_NODE_TYPE,
});
});
it('does not modify promptType with ManualTrigger->new Agent', () => {
const parent = { type: MANUAL_TRIGGER_NODE_TYPE }; const parent = { type: MANUAL_TRIGGER_NODE_TYPE };
const child = { type: AGENT_NODE_TYPE }; const child = { type: AGENT_NODE_TYPE };
adjustNewlyConnectedNodes(parent, child); adjustNewNodes(parent, child, { parentIsNew: false });
expect(child).toEqual({ expect(child).toEqual({
type: AGENT_NODE_TYPE, type: AGENT_NODE_TYPE,
parameters: { promptType: 'define' }, parameters: { promptType: 'define' },
}); });
}); });
it('modifies sessionId with ChatTrigger->Memory', () => { it('modifies sessionId with ChatTrigger->(new Memory->Agent)', () => {
const parent = { type: CHAT_TRIGGER_NODE_TYPE }; const trigger = { type: CHAT_TRIGGER_NODE_TYPE, name: 'trigger' };
const child = { type: '@n8n/n8n-nodes-langchain.memoryBufferWindow' }; getParentNodesByDepth.mockReturnValue([{ name: trigger.name }]);
adjustNewlyConnectedNodes(parent, child); getNode.mockReturnValue({ type: trigger.type });
expect(child).toEqual({
const child = { type: AGENT_NODE_TYPE };
const parent = { type: '@n8n/n8n-nodes-langchain.memoryBufferWindow' };
adjustNewNodes(parent, child, { childIsNew: false });
expect(parent).toEqual({
type: '@n8n/n8n-nodes-langchain.memoryBufferWindow', type: '@n8n/n8n-nodes-langchain.memoryBufferWindow',
}); });
}); });
it('does not modify sessionId with ManualTrigger->Memory', () => { it('does not modify sessionId with ManualTrigger->(new Memory->Agent)', () => {
const parent = { type: MANUAL_TRIGGER_NODE_TYPE }; const trigger = { type: MANUAL_TRIGGER_NODE_TYPE, name: 'trigger' };
const child = { type: '@n8n/n8n-nodes-langchain.memoryBufferWindow' }; getParentNodesByDepth.mockReturnValue([{ name: trigger.name }]);
adjustNewlyConnectedNodes(parent, child); getNode.mockReturnValue({ type: trigger.type });
expect(child).toEqual({
const child = { type: AGENT_NODE_TYPE, name: 'myAgent' };
const parent = { type: '@n8n/n8n-nodes-langchain.memoryBufferWindow' };
adjustNewNodes(parent, child, { childIsNew: false });
expect(parent).toEqual({
type: '@n8n/n8n-nodes-langchain.memoryBufferWindow', type: '@n8n/n8n-nodes-langchain.memoryBufferWindow',
parameters: { sessionIdType: 'customKey' }, parameters: { sessionIdType: 'customKey' },
}); });

View file

@ -7,8 +7,8 @@ import {
OPEN_AI_NODE_MESSAGE_ASSISTANT_TYPE, OPEN_AI_NODE_MESSAGE_ASSISTANT_TYPE,
QA_CHAIN_NODE_TYPE, QA_CHAIN_NODE_TYPE,
} from '@/constants'; } from '@/constants';
import { getParentNodes } from '@/components/ButtonParameter/utils';
import { useWorkflowsStore } from '@/stores/workflows.store'; import { useWorkflowsStore } from '@/stores/workflows.store';
import type { AddedNode } from '@/Interface';
const AI_NODES = [ const AI_NODES = [
QA_CHAIN_NODE_TYPE, QA_CHAIN_NODE_TYPE,
@ -29,25 +29,37 @@ const MEMORY_NODE_NAMES = [
const PROMPT_PROVIDER_NODE_NAMES = [CHAT_TRIGGER_NODE_TYPE]; const PROMPT_PROVIDER_NODE_NAMES = [CHAT_TRIGGER_NODE_TYPE];
type NodeWithType = Pick<INode, 'type'>; export function adjustNewNodes(
parent: AddedNode,
const { getCurrentWorkflow, getNodeByName } = useWorkflowsStore(); child: AddedNode,
{ parentIsNew = true, childIsNew = true } = {},
export function adjustNewlyConnectedNodes(parent: INode, child: INode) { ) {
const workflow = getCurrentWorkflow(); if (childIsNew) adjustNewChild(parent, child);
if (parentIsNew) adjustNewParent(parent, child);
if (workflow.getParentNodesByDepth(child.name, 1).length > 0) { }
return;
}
function adjustNewChild(parent: AddedNode, child: AddedNode) {
if (!PROMPT_PROVIDER_NODE_NAMES.includes(parent.type) && AI_NODES.includes(child.type)) { if (!PROMPT_PROVIDER_NODE_NAMES.includes(parent.type) && AI_NODES.includes(child.type)) {
Object.assign<INode, Partial<INode>>(child, { Object.assign<AddedNode, Partial<INode>>(child, {
parameters: { promptType: 'define' }, parameters: { promptType: 'define' },
}); });
} }
if (!PROMPT_PROVIDER_NODE_NAMES.includes(parent.type) && MEMORY_NODE_NAMES.includes(child.type)) { }
Object.assign<INode, Partial<INode>>(child, {
parameters: { sessionIdType: 'customKey' }, function adjustNewParent(parent: AddedNode, child: AddedNode) {
}); if (MEMORY_NODE_NAMES.includes(parent.type) && child.name) {
const { getCurrentWorkflow } = useWorkflowsStore();
const workflow = getCurrentWorkflow();
// If a memory node is added to an Agent, the memory node is actually the parent since it provides input
// So we need to look for the Agent's parents to determine if it's a prompt provider
const ps = workflow.getParentNodesByDepth(child.name, 1);
if (
!ps.some((x) => PROMPT_PROVIDER_NODE_NAMES.includes(workflow.getNode(x.name)?.type ?? ''))
) {
Object.assign<AddedNode, Partial<INode>>(parent, {
parameters: { sessionIdType: 'customKey' },
});
}
} }
} }