From a1cf28dce9c280b6f815905b0becd4bc2a9df3a0 Mon Sep 17 00:00:00 2001 From: Charlie Kolb Date: Wed, 13 Nov 2024 09:36:08 +0100 Subject: [PATCH] Add working version --- .../NodeCreator/composables/useActions.ts | 4 +- .../src/composables/useCanvasOperations.ts | 90 +++++++++++-------- .../src/utils/connectionNodeUtils.test.ts | 67 ++++++++++---- .../src/utils/connectionNodeUtils.ts | 44 +++++---- 4 files changed, 133 insertions(+), 72 deletions(-) diff --git a/packages/editor-ui/src/components/Node/NodeCreator/composables/useActions.ts b/packages/editor-ui/src/components/Node/NodeCreator/composables/useActions.ts index 65a53f76aa..3dde2d9376 100644 --- a/packages/editor-ui/src/components/Node/NodeCreator/composables/useActions.ts +++ b/packages/editor-ui/src/components/Node/NodeCreator/composables/useActions.ts @@ -43,7 +43,7 @@ import { useExternalHooks } from '@/composables/useExternalHooks'; import { sortNodeCreateElements, transformNodeType } from '../utils'; import { useI18n } from '@/composables/useI18n'; import { useCanvasStore } from '@/stores/canvas.store'; -import { adjustNewlyConnectedNodes } from '@/utils/connectionNodeUtils'; +import { adjustNewNodes } from '@/utils/connectionNodeUtils'; export const useActions = () => { const nodeCreatorStore = useNodeCreatorStore(); @@ -287,7 +287,7 @@ export const useActions = () => { } if (addedNodes.length === 2) { - adjustNewlyConnectedNodes(addedNodes[0], addedNodes[1]); + adjustNewNodes(addedNodes[0], addedNodes[1]); } addedNodes.forEach((node, index) => { diff --git a/packages/editor-ui/src/composables/useCanvasOperations.ts b/packages/editor-ui/src/composables/useCanvasOperations.ts index 32fa5b0c80..421c92f3a3 100644 --- a/packages/editor-ui/src/composables/useCanvasOperations.ts +++ b/packages/editor-ui/src/composables/useCanvasOperations.ts @@ -96,7 +96,7 @@ import type { useRouter } from 'vue-router'; import { useClipboard } from '@/composables/useClipboard'; import { useUniqueNodeName } from '@/composables/useUniqueNodeName'; import { isPresent } from '../utils/typesUtils'; -import { adjustNewlyConnectedNodes } from '@/utils/connectionNodeUtils'; +import { adjustNewNodes } from '@/utils/connectionNodeUtils'; type AddNodeData = Partial & { type: string; @@ -673,37 +673,46 @@ export function useCanvasOperations({ router }: { router: ReturnType ({ + useWorkflowsStore: vi.fn(() => ({ + getParentNodesByDepth, + getNode, + })), +})); describe('adjustNewlyConnectedNodes', () => { - it('modifies promptType with ChatTrigger->Agent', () => { + beforeEach(() => { + setActivePinia(createPinia()); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('modifies promptType with ChatTrigger->new Agent', () => { const parent = { type: CHAT_TRIGGER_NODE_TYPE }; const child = { type: AGENT_NODE_TYPE }; - adjustNewlyConnectedNodes(parent, child); + adjustNewNodes(parent, child, { parentIsNew: false }); expect(child).toEqual({ type: AGENT_NODE_TYPE, }); }); - it('does not modify promptType with ManualTrigger->Agent', () => { + + it('modifies promptType with new ChatTrigger->new Agent', () => { + const parent = { type: CHAT_TRIGGER_NODE_TYPE }; + const child = { type: AGENT_NODE_TYPE }; + adjustNewNodes(parent, child); + expect(child).toEqual({ + type: AGENT_NODE_TYPE, + }); + }); + + it('does not modify promptType with ManualTrigger->new Agent', () => { const parent = { type: MANUAL_TRIGGER_NODE_TYPE }; const child = { type: AGENT_NODE_TYPE }; - adjustNewlyConnectedNodes(parent, child); + adjustNewNodes(parent, child, { parentIsNew: false }); expect(child).toEqual({ type: AGENT_NODE_TYPE, parameters: { promptType: 'define' }, }); }); - it('modifies sessionId with ChatTrigger->Memory', () => { - const parent = { type: CHAT_TRIGGER_NODE_TYPE }; - const child = { type: '@n8n/n8n-nodes-langchain.memoryBufferWindow' }; - adjustNewlyConnectedNodes(parent, child); - expect(child).toEqual({ + it('modifies sessionId with ChatTrigger->(new Memory->Agent)', () => { + const trigger = { type: CHAT_TRIGGER_NODE_TYPE, name: 'trigger' }; + getParentNodesByDepth.mockReturnValue([{ name: trigger.name }]); + getNode.mockReturnValue({ type: trigger.type }); + + const child = { type: AGENT_NODE_TYPE }; + const parent = { type: '@n8n/n8n-nodes-langchain.memoryBufferWindow' }; + adjustNewNodes(parent, child, { childIsNew: false }); + expect(parent).toEqual({ type: '@n8n/n8n-nodes-langchain.memoryBufferWindow', }); }); - it('does not modify sessionId with ManualTrigger->Memory', () => { - const parent = { type: MANUAL_TRIGGER_NODE_TYPE }; - const child = { type: '@n8n/n8n-nodes-langchain.memoryBufferWindow' }; - adjustNewlyConnectedNodes(parent, child); - expect(child).toEqual({ + it('does not modify sessionId with ManualTrigger->(new Memory->Agent)', () => { + const trigger = { type: MANUAL_TRIGGER_NODE_TYPE, name: 'trigger' }; + getParentNodesByDepth.mockReturnValue([{ name: trigger.name }]); + getNode.mockReturnValue({ type: trigger.type }); + + const child = { type: AGENT_NODE_TYPE, name: 'myAgent' }; + const parent = { type: '@n8n/n8n-nodes-langchain.memoryBufferWindow' }; + adjustNewNodes(parent, child, { childIsNew: false }); + expect(parent).toEqual({ type: '@n8n/n8n-nodes-langchain.memoryBufferWindow', parameters: { sessionIdType: 'customKey' }, }); diff --git a/packages/editor-ui/src/utils/connectionNodeUtils.ts b/packages/editor-ui/src/utils/connectionNodeUtils.ts index 5ea81f91d6..bbb0e1e5e0 100644 --- a/packages/editor-ui/src/utils/connectionNodeUtils.ts +++ b/packages/editor-ui/src/utils/connectionNodeUtils.ts @@ -7,8 +7,8 @@ import { OPEN_AI_NODE_MESSAGE_ASSISTANT_TYPE, QA_CHAIN_NODE_TYPE, } from '@/constants'; -import { getParentNodes } from '@/components/ButtonParameter/utils'; import { useWorkflowsStore } from '@/stores/workflows.store'; +import type { AddedNode } from '@/Interface'; const AI_NODES = [ QA_CHAIN_NODE_TYPE, @@ -29,25 +29,37 @@ const MEMORY_NODE_NAMES = [ const PROMPT_PROVIDER_NODE_NAMES = [CHAT_TRIGGER_NODE_TYPE]; -type NodeWithType = Pick; - -const { getCurrentWorkflow, getNodeByName } = useWorkflowsStore(); - -export function adjustNewlyConnectedNodes(parent: INode, child: INode) { - const workflow = getCurrentWorkflow(); - - if (workflow.getParentNodesByDepth(child.name, 1).length > 0) { - return; - } +export function adjustNewNodes( + parent: AddedNode, + child: AddedNode, + { parentIsNew = true, childIsNew = true } = {}, +) { + if (childIsNew) adjustNewChild(parent, child); + if (parentIsNew) adjustNewParent(parent, child); +} +function adjustNewChild(parent: AddedNode, child: AddedNode) { if (!PROMPT_PROVIDER_NODE_NAMES.includes(parent.type) && AI_NODES.includes(child.type)) { - Object.assign>(child, { + Object.assign>(child, { parameters: { promptType: 'define' }, }); } - if (!PROMPT_PROVIDER_NODE_NAMES.includes(parent.type) && MEMORY_NODE_NAMES.includes(child.type)) { - Object.assign>(child, { - parameters: { sessionIdType: 'customKey' }, - }); +} + +function adjustNewParent(parent: AddedNode, child: AddedNode) { + if (MEMORY_NODE_NAMES.includes(parent.type) && child.name) { + const { getCurrentWorkflow } = useWorkflowsStore(); + const workflow = getCurrentWorkflow(); + + // If a memory node is added to an Agent, the memory node is actually the parent since it provides input + // So we need to look for the Agent's parents to determine if it's a prompt provider + const ps = workflow.getParentNodesByDepth(child.name, 1); + if ( + !ps.some((x) => PROMPT_PROVIDER_NODE_NAMES.includes(workflow.getNode(x.name)?.type ?? '')) + ) { + Object.assign>(parent, { + parameters: { sessionIdType: 'customKey' }, + }); + } } }