fix(AI Agent Node): Move model retrieval into try/catch to fix continueOnFail handling (#13165)

This commit is contained in:
oleg 2025-02-13 15:47:41 +01:00 committed by GitHub
parent ba95f97d10
commit 47c5688618
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 602 additions and 172 deletions

View file

@ -1,4 +1,5 @@
import type { BaseChatMemory } from '@langchain/community/memory/chat_memory'; import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { HumanMessage } from '@langchain/core/messages'; import { HumanMessage } from '@langchain/core/messages';
import type { BaseMessage } from '@langchain/core/messages'; import type { BaseMessage } from '@langchain/core/messages';
import type { BaseMessagePromptTemplateLike } from '@langchain/core/prompts'; import type { BaseMessagePromptTemplateLike } from '@langchain/core/prompts';
@ -8,6 +9,7 @@ import type { Tool } from '@langchain/core/tools';
import { DynamicStructuredTool } from '@langchain/core/tools'; import { DynamicStructuredTool } from '@langchain/core/tools';
import type { AgentAction, AgentFinish } from 'langchain/agents'; import type { AgentAction, AgentFinish } from 'langchain/agents';
import { AgentExecutor, createToolCallingAgent } from 'langchain/agents'; import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
import type { ToolsAgentAction } from 'langchain/dist/agents/tool_calling/output_parser';
import { omit } from 'lodash'; import { omit } from 'lodash';
import { BINARY_ENCODING, jsonParse, NodeConnectionType, NodeOperationError } from 'n8n-workflow'; import { BINARY_ENCODING, jsonParse, NodeConnectionType, NodeOperationError } from 'n8n-workflow';
import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow'; import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
@ -22,28 +24,53 @@ import {
import { SYSTEM_MESSAGE } from './prompt'; import { SYSTEM_MESSAGE } from './prompt';
function getOutputParserSchema(outputParser: N8nOutputParser): ZodObject<any, any, any, any> { /* -----------------------------------------------------------
Output Parser Helper
----------------------------------------------------------- */
/**
* Retrieve the output parser schema.
* If the parser does not return a valid schema, default to a schema with a single text field.
*/
export function getOutputParserSchema(
outputParser: N8nOutputParser,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
): ZodObject<any, any, any, any> {
const schema = const schema =
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(outputParser.getSchema() as ZodObject<any, any, any, any>) ?? z.object({ text: z.string() }); (outputParser.getSchema() as ZodObject<any, any, any, any>) ?? z.object({ text: z.string() });
return schema; return schema;
} }
async function extractBinaryMessages(ctx: IExecuteFunctions) { /* -----------------------------------------------------------
const binaryData = ctx.getInputData()?.[0]?.binary ?? {}; Binary Data Helpers
----------------------------------------------------------- */
/**
* Extracts binary image messages from the input data.
* When operating in filesystem mode, the binary stream is first converted to a buffer.
*
* @param ctx - The execution context
* @param itemIndex - The current item index
* @returns A HumanMessage containing the binary image messages.
*/
export async function extractBinaryMessages(
ctx: IExecuteFunctions,
itemIndex: number,
): Promise<HumanMessage> {
const binaryData = ctx.getInputData()?.[itemIndex]?.binary ?? {};
const binaryMessages = await Promise.all( const binaryMessages = await Promise.all(
Object.values(binaryData) Object.values(binaryData)
.filter((data) => data.mimeType.startsWith('image/')) .filter((data) => data.mimeType.startsWith('image/'))
.map(async (data) => { .map(async (data) => {
let binaryUrlString; let binaryUrlString: string;
// In filesystem mode we need to get binary stream by id before converting it to buffer // In filesystem mode we need to get binary stream by id before converting it to buffer
if (data.id) { if (data.id) {
const binaryBuffer = await ctx.helpers.binaryToBuffer( const binaryBuffer = await ctx.helpers.binaryToBuffer(
await ctx.helpers.getBinaryStream(data.id), await ctx.helpers.getBinaryStream(data.id),
); );
binaryUrlString = `data:${data.mimeType};base64,${Buffer.from(binaryBuffer).toString(
binaryUrlString = `data:${data.mimeType};base64,${Buffer.from(binaryBuffer).toString(BINARY_ENCODING)}`; BINARY_ENCODING,
)}`;
} else { } else {
binaryUrlString = data.data.includes('base64') binaryUrlString = data.data.includes('base64')
? data.data ? data.data
@ -62,6 +89,10 @@ async function extractBinaryMessages(ctx: IExecuteFunctions) {
content: [...binaryMessages], content: [...binaryMessages],
}); });
} }
/* -----------------------------------------------------------
Agent Output Format Helpers
----------------------------------------------------------- */
/** /**
* Fixes empty content messages in agent steps. * Fixes empty content messages in agent steps.
* *
@ -73,7 +104,9 @@ async function extractBinaryMessages(ctx: IExecuteFunctions) {
* @param steps - The agent steps to fix * @param steps - The agent steps to fix
* @returns The fixed agent steps * @returns The fixed agent steps
*/ */
function fixEmptyContentMessage(steps: AgentFinish | AgentAction[]) { export function fixEmptyContentMessage(
steps: AgentFinish | ToolsAgentAction[],
): AgentFinish | ToolsAgentAction[] {
if (!Array.isArray(steps)) return steps; if (!Array.isArray(steps)) return steps;
steps.forEach((step) => { steps.forEach((step) => {
@ -96,24 +129,6 @@ function fixEmptyContentMessage(steps: AgentFinish | AgentAction[]) {
return steps; return steps;
} }
export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
this.logger.debug('Executing Tools Agent');
const model = await this.getInputConnectionData(NodeConnectionType.AiLanguageModel, 0);
if (!isChatInstance(model) || !model.bindTools) {
throw new NodeOperationError(
this.getNode(),
'Tools Agent requires Chat Model which supports Tools calling',
);
}
const memory = (await this.getInputConnectionData(NodeConnectionType.AiMemory, 0)) as
| BaseChatMemory
| undefined;
const tools = (await getConnectedTools(this, true, false)) as Array<DynamicStructuredTool | Tool>;
const outputParser = (await getOptionalOutputParsers(this))?.[0];
let structuredOutputParserTool: DynamicStructuredTool | undefined;
/** /**
* Ensures consistent handling of outputs regardless of the model used, * Ensures consistent handling of outputs regardless of the model used,
* providing a unified output format for further processing. * providing a unified output format for further processing.
@ -145,8 +160,9 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
* @param steps - The agent finish or agent action steps. * @param steps - The agent finish or agent action steps.
* @returns The modified agent finish steps or the original steps. * @returns The modified agent finish steps or the original steps.
*/ */
function handleAgentFinishOutput(steps: AgentFinish | AgentAction[]) { export function handleAgentFinishOutput(
// Check if the steps contain multiple outputs steps: AgentFinish | AgentAction[],
): AgentFinish | AgentAction[] {
type AgentMultiOutputFinish = AgentFinish & { type AgentMultiOutputFinish = AgentFinish & {
returnValues: { output: Array<{ text: string; type: string; index: number }> }; returnValues: { output: Array<{ text: string; type: string; index: number }> };
}; };
@ -154,17 +170,15 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
if (agentFinishSteps.returnValues) { if (agentFinishSteps.returnValues) {
const isMultiOutput = Array.isArray(agentFinishSteps.returnValues?.output); const isMultiOutput = Array.isArray(agentFinishSteps.returnValues?.output);
if (isMultiOutput) { if (isMultiOutput) {
// Define the type for each item in the multi-output array // If all items in the multi-output array are of type 'text', merge them into a single string
type MultiOutputItem = { index: number; type: string; text: string }; const multiOutputSteps = agentFinishSteps.returnValues.output as Array<{
const multiOutputSteps = agentFinishSteps.returnValues.output as MultiOutputItem[]; index: number;
type: string;
// Check if all items in the multi-output array are of type 'text' text: string;
const isTextOnly = (multiOutputSteps ?? []).every((output) => 'text' in output); }>;
const isTextOnly = multiOutputSteps.every((output) => 'text' in output);
if (isTextOnly) { if (isTextOnly) {
// If all items are of type 'text', merge them into a single string
agentFinishSteps.returnValues.output = multiOutputSteps agentFinishSteps.returnValues.output = multiOutputSteps
.map((output) => output.text) .map((output) => output.text)
.join('\n') .join('\n')
@ -174,33 +188,52 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
} }
} }
// If the steps do not contain multiple outputs, return them as is
return agentFinishSteps; return agentFinishSteps;
} }
// If memory is connected we need to stringify the returnValues so that it can be saved in the memory as a string /**
function handleParsedStepOutput(output: Record<string, unknown>) { * Wraps the parsed output so that it can be stored in memory.
* If memory is connected, the output is stringified.
*
* @param output - The parsed output object
* @param memory - The connected memory (if any)
* @returns The formatted output object
*/
export function handleParsedStepOutput(
output: Record<string, unknown>,
memory?: BaseChatMemory,
): { returnValues: Record<string, unknown>; log: string } {
return { return {
returnValues: memory ? { output: JSON.stringify(output) } : output, returnValues: memory ? { output: JSON.stringify(output) } : output,
log: 'Final response formatted', log: 'Final response formatted',
}; };
} }
async function agentStepsParser(
steps: AgentFinish | AgentAction[], /**
): Promise<AgentFinish | AgentAction[]> { * Parses agent steps using the provided output parser.
* If the agent used the 'format_final_response' tool, the output is parsed accordingly.
*
* @param steps - The agent finish or action steps
* @param outputParser - The output parser (if defined)
* @param memory - The connected memory (if any)
* @returns The parsed steps with the final output
*/
export const getAgentStepsParser =
(outputParser?: N8nOutputParser, memory?: BaseChatMemory) =>
async (steps: AgentFinish | AgentAction[]): Promise<AgentFinish | AgentAction[]> => {
// Check if the steps contain the 'format_final_response' tool invocation.
if (Array.isArray(steps)) { if (Array.isArray(steps)) {
const responseParserTool = steps.find((step) => step.tool === 'format_final_response'); const responseParserTool = steps.find((step) => step.tool === 'format_final_response');
if (responseParserTool) { if (responseParserTool && outputParser) {
const toolInput = responseParserTool?.toolInput; const toolInput = responseParserTool.toolInput;
// Check if the tool input is a string or an object and convert it to a string // Ensure the tool input is a string
const parserInput = toolInput instanceof Object ? JSON.stringify(toolInput) : toolInput; const parserInput = toolInput instanceof Object ? JSON.stringify(toolInput) : toolInput;
const returnValues = (await outputParser.parse(parserInput)) as Record<string, unknown>; const returnValues = (await outputParser.parse(parserInput)) as Record<string, unknown>;
return handleParsedStepOutput(returnValues, memory);
return handleParsedStepOutput(returnValues);
} }
} }
// If the steps are an AgentFinish and the outputParser is defined it must mean that the LLM didn't use `format_final_response` tool so we will try to parse the output manually // Otherwise, if the steps contain a returnValues field, try to parse them manually.
if (outputParser && typeof steps === 'object' && (steps as AgentFinish).returnValues) { if (outputParser && typeof steps === 'object' && (steps as AgentFinish).returnValues) {
const finalResponse = (steps as AgentFinish).returnValues; const finalResponse = (steps as AgentFinish).returnValues;
let parserInput: string; let parserInput: string;
@ -213,7 +246,7 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
// so we try to parse the output before wrapping it and then stringify it // so we try to parse the output before wrapping it and then stringify it
parserInput = JSON.stringify({ output: jsonParse(finalResponse.output) }); parserInput = JSON.stringify({ output: jsonParse(finalResponse.output) });
} catch (error) { } catch (error) {
// If parsing of the output fails, we will use the raw output // Fallback to the raw output if parsing fails.
parserInput = finalResponse.output; parserInput = finalResponse.output;
} }
} else { } else {
@ -225,48 +258,174 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
} }
const returnValues = (await outputParser.parse(parserInput)) as Record<string, unknown>; const returnValues = (await outputParser.parse(parserInput)) as Record<string, unknown>;
return handleParsedStepOutput(returnValues); return handleParsedStepOutput(returnValues, memory);
}
return handleAgentFinishOutput(steps);
} }
return handleAgentFinishOutput(steps);
};
/* -----------------------------------------------------------
Agent Setup Helpers
----------------------------------------------------------- */
/**
* Retrieves the language model from the input connection.
* Throws an error if the model is not a valid chat instance or does not support tools.
*
* @param ctx - The execution context
* @returns The validated chat model
*/
export async function getChatModel(ctx: IExecuteFunctions): Promise<BaseChatModel> {
const model = await ctx.getInputConnectionData(NodeConnectionType.AiLanguageModel, 0);
if (!isChatInstance(model) || !model.bindTools) {
throw new NodeOperationError(
ctx.getNode(),
'Tools Agent requires Chat Model which supports Tools calling',
);
}
return model;
}
/**
* Retrieves the memory instance from the input connection if it is connected
*
* @param ctx - The execution context
* @returns The connected memory (if any)
*/
export async function getOptionalMemory(
ctx: IExecuteFunctions,
): Promise<BaseChatMemory | undefined> {
return (await ctx.getInputConnectionData(NodeConnectionType.AiMemory, 0)) as
| BaseChatMemory
| undefined;
}
/**
* Retrieves the connected tools and (if an output parser is defined)
* appends a structured output parser tool.
*
* @param ctx - The execution context
* @param outputParser - The optional output parser
* @returns The array of connected tools
*/
export async function getTools(
ctx: IExecuteFunctions,
outputParser?: N8nOutputParser,
): Promise<Array<DynamicStructuredTool | Tool>> {
const tools = (await getConnectedTools(ctx, true, false)) as Array<DynamicStructuredTool | Tool>;
// If an output parser is available, create a dynamic tool to validate the final output.
if (outputParser) { if (outputParser) {
const schema = getOutputParserSchema(outputParser); const schema = getOutputParserSchema(outputParser);
structuredOutputParserTool = new DynamicStructuredTool({ const structuredOutputParserTool = new DynamicStructuredTool({
schema, schema,
name: 'format_final_response', name: 'format_final_response',
description: description:
'Always use this tool for the final output to the user. It validates the output so only use it when you are sure the output is final.', 'Always use this tool for the final output to the user. It validates the output so only use it when you are sure the output is final.',
// We will not use the function here as we will use the parser to intercept & parse the output in the agentStepsParser // We do not use a function here because we intercept the output with the parser.
func: async () => '', func: async () => '',
}); });
tools.push(structuredOutputParserTool); tools.push(structuredOutputParserTool);
} }
return tools;
}
const options = this.getNodeParameter('options', 0, {}) as { /**
* Prepares the prompt messages for the agent.
*
* @param ctx - The execution context
* @param itemIndex - The current item index
* @param options - Options containing systemMessage and other parameters
* @returns The array of prompt messages
*/
export async function prepareMessages(
ctx: IExecuteFunctions,
itemIndex: number,
options: {
systemMessage?: string; systemMessage?: string;
maxIterations?: number; passthroughBinaryImages?: boolean;
returnIntermediateSteps?: boolean; outputParser?: N8nOutputParser;
}; },
): Promise<BaseMessagePromptTemplateLike[]> {
const passthroughBinaryImages = this.getNodeParameter('options.passthroughBinaryImages', 0, true);
const messages: BaseMessagePromptTemplateLike[] = [ const messages: BaseMessagePromptTemplateLike[] = [
['system', `{system_message}${outputParser ? '\n\n{formatting_instructions}' : ''}`], ['system', `{system_message}${options.outputParser ? '\n\n{formatting_instructions}' : ''}`],
['placeholder', '{chat_history}'], ['placeholder', '{chat_history}'],
['human', '{input}'], ['human', '{input}'],
]; ];
const hasBinaryData = this.getInputData()?.[0]?.binary !== undefined; // If there is binary data and the node option permits it, add a binary message
if (hasBinaryData && passthroughBinaryImages) { const hasBinaryData = ctx.getInputData()?.[itemIndex]?.binary !== undefined;
const binaryMessage = await extractBinaryMessages(this); if (hasBinaryData && options.passthroughBinaryImages) {
const binaryMessage = await extractBinaryMessages(ctx, itemIndex);
messages.push(binaryMessage); messages.push(binaryMessage);
} }
// We add the agent scratchpad last, so that the agent will not run in loops // We add the agent scratchpad last, so that the agent will not run in loops
// by adding binary messages between each interaction // by adding binary messages between each interaction
messages.push(['placeholder', '{agent_scratchpad}']); messages.push(['placeholder', '{agent_scratchpad}']);
const prompt = ChatPromptTemplate.fromMessages(messages); return messages;
}
/**
* Creates the chat prompt from messages.
*
* @param messages - The messages array
* @returns The ChatPromptTemplate instance
*/
export function preparePrompt(messages: BaseMessagePromptTemplateLike[]): ChatPromptTemplate {
return ChatPromptTemplate.fromMessages(messages);
}
/* -----------------------------------------------------------
Main Executor Function
----------------------------------------------------------- */
/**
* The main executor method for the Tools Agent.
*
* This function retrieves necessary components (model, memory, tools), prepares the prompt,
* creates the agent, and processes each input item. The error handling for each item is also
* managed here based on the node's continueOnFail setting.
*
* @returns The array of execution data for all processed items
*/
export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
this.logger.debug('Executing Tools Agent');
const returnData: INodeExecutionData[] = [];
const items = this.getInputData();
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
const model = await getChatModel(this);
const memory = await getOptionalMemory(this);
const outputParsers = await getOptionalOutputParsers(this);
const outputParser = outputParsers?.[0];
const tools = await getTools(this, outputParser);
const input = getPromptInputByType({
ctx: this,
i: itemIndex,
inputKey: 'text',
promptTypeKey: 'promptType',
});
if (input === undefined) {
throw new NodeOperationError(this.getNode(), 'The “text” parameter is empty.');
}
const options = this.getNodeParameter('options', itemIndex, {}) as {
systemMessage?: string;
maxIterations?: number;
returnIntermediateSteps?: boolean;
passthroughBinaryImages?: boolean;
};
// Prepare the prompt messages and prompt template.
const messages = await prepareMessages(this, itemIndex, {
systemMessage: options.systemMessage,
passthroughBinaryImages: options.passthroughBinaryImages ?? true,
outputParser,
});
const prompt = preparePrompt(messages);
// Create the base agent that calls tools.
const agent = createToolCallingAgent({ const agent = createToolCallingAgent({
llm: model, llm: model,
tools, tools,
@ -274,9 +433,12 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
streamRunnable: false, streamRunnable: false,
}); });
agent.streamRunnable = false; agent.streamRunnable = false;
// Wrap the agent with parsers and fixes.
const runnableAgent = RunnableSequence.from([agent, agentStepsParser, fixEmptyContentMessage]); const runnableAgent = RunnableSequence.from([
agent,
getAgentStepsParser(outputParser, memory),
fixEmptyContentMessage,
]);
const executor = AgentExecutor.fromAgentAndTools({ const executor = AgentExecutor.fromAgentAndTools({
agent: runnableAgent, agent: runnableAgent,
memory, memory,
@ -284,29 +446,19 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
returnIntermediateSteps: options.returnIntermediateSteps === true, returnIntermediateSteps: options.returnIntermediateSteps === true,
maxIterations: options.maxIterations ?? 10, maxIterations: options.maxIterations ?? 10,
}); });
const returnData: INodeExecutionData[] = [];
const items = this.getInputData(); // Invoke the executor with the given input and system message.
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { const response = await executor.invoke(
try { {
const input = getPromptInputByType({
ctx: this,
i: itemIndex,
inputKey: 'text',
promptTypeKey: 'promptType',
});
if (input === undefined) {
throw new NodeOperationError(this.getNode(), 'The text parameter is empty.');
}
const response = await executor.invoke({
input, input,
system_message: options.systemMessage ?? SYSTEM_MESSAGE, system_message: options.systemMessage ?? SYSTEM_MESSAGE,
formatting_instructions: formatting_instructions:
'IMPORTANT: Always call `format_final_response` to format your final response!', 'IMPORTANT: Always call `format_final_response` to format your final response!',
}); },
{ signal: this.getExecutionCancelSignal() },
);
// If memory and outputParser are connected, parse the output.
if (memory && outputParser) { if (memory && outputParser) {
const parsedOutput = jsonParse<{ output: Record<string, unknown> }>( const parsedOutput = jsonParse<{ output: Record<string, unknown> }>(
response.output as string, response.output as string,
@ -314,7 +466,8 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
response.output = parsedOutput?.output ?? parsedOutput; response.output = parsedOutput?.output ?? parsedOutput;
} }
returnData.push({ // Omit internal keys before returning the result.
const itemResult = {
json: omit( json: omit(
response, response,
'system_message', 'system_message',
@ -323,7 +476,9 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
'chat_history', 'chat_history',
'agent_scratchpad', 'agent_scratchpad',
), ),
}); };
returnData.push(itemResult);
} catch (error) { } catch (error) {
if (this.continueOnFail()) { if (this.continueOnFail()) {
returnData.push({ returnData.push({
@ -332,7 +487,6 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
}); });
continue; continue;
} }
throw error; throw error;
} }
} }

View file

@ -0,0 +1,273 @@
// ToolsAgent.test.ts
import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { HumanMessage } from '@langchain/core/messages';
import type { BaseMessagePromptTemplateLike } from '@langchain/core/prompts';
import { FakeTool } from '@langchain/core/utils/testing';
import { Buffer } from 'buffer';
import { mock } from 'jest-mock-extended';
import type { ToolsAgentAction } from 'langchain/dist/agents/tool_calling/output_parser';
import type { Tool } from 'langchain/tools';
import type { IExecuteFunctions } from 'n8n-workflow';
import { NodeOperationError, BINARY_ENCODING } from 'n8n-workflow';
import type { ZodType } from 'zod';
import { z } from 'zod';
import * as helpersModule from '@utils/helpers';
import type { N8nOutputParser } from '@utils/output_parsers/N8nOutputParser';
import {
getOutputParserSchema,
extractBinaryMessages,
fixEmptyContentMessage,
handleParsedStepOutput,
getChatModel,
getOptionalMemory,
prepareMessages,
preparePrompt,
getTools,
} from '../agents/ToolsAgent/execute';
// We need to override the imported getConnectedTools so that we control its output.
jest.spyOn(helpersModule, 'getConnectedTools').mockResolvedValue([FakeTool as unknown as Tool]);
function getFakeOutputParser(returnSchema?: ZodType): N8nOutputParser {
const fakeOutputParser = mock<N8nOutputParser>();
(fakeOutputParser.getSchema as jest.Mock).mockReturnValue(returnSchema);
return fakeOutputParser;
}
function createFakeExecuteFunctions(overrides: Partial<IExecuteFunctions> = {}): IExecuteFunctions {
return {
getNodeParameter: jest
.fn()
.mockImplementation((_arg1: string, _arg2: number, defaultValue?: unknown) => {
return defaultValue;
}),
getNode: jest.fn().mockReturnValue({}),
getInputConnectionData: jest.fn().mockResolvedValue({}),
getInputData: jest.fn().mockReturnValue([]),
continueOnFail: jest.fn().mockReturnValue(false),
logger: { debug: jest.fn() },
helpers: {},
...overrides,
} as unknown as IExecuteFunctions;
}
describe('getOutputParserSchema', () => {
it('should return a default schema if getSchema returns undefined', () => {
const schema = getOutputParserSchema(getFakeOutputParser(undefined));
// The default schema requires a "text" field.
expect(() => schema.parse({})).toThrow();
expect(schema.parse({ text: 'hello' })).toEqual({ text: 'hello' });
});
it('should return the custom schema if provided', () => {
const customSchema = z.object({ custom: z.number() });
const schema = getOutputParserSchema(getFakeOutputParser(customSchema));
expect(() => schema.parse({ custom: 'not a number' })).toThrow();
expect(schema.parse({ custom: 123 })).toEqual({ custom: 123 });
});
});
describe('extractBinaryMessages', () => {
it('should extract a binary message from the input data when no id is provided', async () => {
const fakeItem = {
binary: {
img1: {
mimeType: 'image/png',
// simulate that data already includes 'base64'
data: 'data:image/png;base64,sampledata',
},
},
};
const ctx = createFakeExecuteFunctions({
getInputData: jest.fn().mockReturnValue([fakeItem]),
});
const humanMsg: HumanMessage = await extractBinaryMessages(ctx, 0);
// Expect the HumanMessage's content to be an array containing one binary message.
expect(Array.isArray(humanMsg.content)).toBe(true);
expect(humanMsg.content[0]).toEqual({
type: 'image_url',
image_url: { url: 'data:image/png;base64,sampledata' },
});
});
it('should extract a binary message using binary stream if id is provided', async () => {
const fakeItem = {
binary: {
img2: {
mimeType: 'image/jpeg',
id: '1234',
data: 'nonsense',
},
},
};
// Cast fakeHelpers as any to satisfy type requirements.
const fakeHelpers = {
getBinaryStream: jest.fn().mockResolvedValue('stream'),
binaryToBuffer: jest.fn().mockResolvedValue(Buffer.from('fakebufferdata')),
} as unknown as IExecuteFunctions['helpers'];
const ctx = createFakeExecuteFunctions({
getInputData: jest.fn().mockReturnValue([fakeItem]),
helpers: fakeHelpers,
});
const humanMsg: HumanMessage = await extractBinaryMessages(ctx, 0);
// eslint-disable-next-line @typescript-eslint/unbound-method
expect(fakeHelpers.getBinaryStream).toHaveBeenCalledWith('1234');
// eslint-disable-next-line @typescript-eslint/unbound-method
expect(fakeHelpers.binaryToBuffer).toHaveBeenCalled();
const expectedUrl = `data:image/jpeg;base64,${Buffer.from('fakebufferdata').toString(
BINARY_ENCODING,
)}`;
expect(humanMsg.content[0]).toEqual({
type: 'image_url',
image_url: { url: expectedUrl },
});
});
});
describe('fixEmptyContentMessage', () => {
it('should replace empty string inputs with empty objects', () => {
// Cast to any to bypass type issues with AgentFinish/AgentAction.
const fakeSteps: ToolsAgentAction[] = [
{
messageLog: [
{
content: [{ input: '' }, { input: { already: 'object' } }],
},
],
},
] as unknown as ToolsAgentAction[];
const fixed = fixEmptyContentMessage(fakeSteps) as ToolsAgentAction[];
const messageContent = fixed?.[0]?.messageLog?.[0].content;
// Type assertion needed since we're extending MessageContentComplex
expect((messageContent?.[0] as { input: unknown })?.input).toEqual({});
expect((messageContent?.[1] as { input: unknown })?.input).toEqual({ already: 'object' });
});
});
describe('handleParsedStepOutput', () => {
it('should stringify the output if memory is provided', () => {
const output = { key: 'value' };
const fakeMemory = mock<BaseChatMemory>();
const result = handleParsedStepOutput(output, fakeMemory);
expect(result.returnValues).toEqual({ output: JSON.stringify(output) });
expect(result.log).toEqual('Final response formatted');
});
it('should not stringify the output if memory is not provided', () => {
const output = { key: 'value' };
const result = handleParsedStepOutput(output);
expect(result.returnValues).toEqual(output);
});
});
describe('getChatModel', () => {
it('should return the model if it is a valid chat model', async () => {
// Cast fakeChatModel as any
const fakeChatModel = mock<BaseChatModel>();
fakeChatModel.bindTools = jest.fn();
fakeChatModel.lc_namespace = ['chat_models'];
const ctx = createFakeExecuteFunctions({
getInputConnectionData: jest.fn().mockResolvedValue(fakeChatModel),
});
const model = await getChatModel(ctx);
expect(model).toEqual(fakeChatModel);
});
it('should throw if the model is not a valid chat model', async () => {
const fakeInvalidModel = mock<BaseChatModel>(); // missing bindTools & lc_namespace
fakeInvalidModel.lc_namespace = [];
const ctx = createFakeExecuteFunctions({
getInputConnectionData: jest.fn().mockResolvedValue(fakeInvalidModel),
getNode: jest.fn().mockReturnValue({}),
});
await expect(getChatModel(ctx)).rejects.toThrow(NodeOperationError);
});
});
describe('getOptionalMemory', () => {
it('should return the memory if available', async () => {
const fakeMemory = { some: 'memory' };
const ctx = createFakeExecuteFunctions({
getInputConnectionData: jest.fn().mockResolvedValue(fakeMemory),
});
const memory = await getOptionalMemory(ctx);
expect(memory).toEqual(fakeMemory);
});
});
describe('getTools', () => {
it('should retrieve tools without appending if outputParser is not provided', async () => {
const ctx = createFakeExecuteFunctions();
const tools = await getTools(ctx);
expect(tools.length).toEqual(1);
});
it('should retrieve tools and append the structured output parser tool if outputParser is provided', async () => {
const fakeOutputParser = getFakeOutputParser(z.object({ text: z.string() }));
const ctx = createFakeExecuteFunctions();
const tools = await getTools(ctx, fakeOutputParser);
// Our fake getConnectedTools returns one tool; with outputParser, one extra is appended.
expect(tools.length).toEqual(2);
const dynamicTool = tools.find((t) => t.name === 'format_final_response');
expect(dynamicTool).toBeDefined();
});
});
describe('prepareMessages', () => {
it('should include a binary message if binary data is present and passthroughBinaryImages is true', async () => {
const fakeItem = {
binary: {
img1: {
mimeType: 'image/png',
data: 'data:image/png;base64,sampledata',
},
},
};
const ctx = createFakeExecuteFunctions({
getInputData: jest.fn().mockReturnValue([fakeItem]),
});
const messages = await prepareMessages(ctx, 0, {
systemMessage: 'Test system',
passthroughBinaryImages: true,
});
// Check if any message is an instance of HumanMessage
const hasBinaryMessage = messages.some(
(m) => typeof m === 'object' && m instanceof HumanMessage,
);
expect(hasBinaryMessage).toBe(true);
});
it('should not include a binary message if no binary data is present', async () => {
const fakeItem = { json: {} }; // no binary key
const ctx = createFakeExecuteFunctions({
getInputData: jest.fn().mockReturnValue([fakeItem]),
});
const messages = await prepareMessages(ctx, 0, {
systemMessage: 'Test system',
passthroughBinaryImages: true,
});
const hasHumanMessage = messages.some((m) => m instanceof HumanMessage);
expect(hasHumanMessage).toBe(false);
});
});
describe('preparePrompt', () => {
it('should return a ChatPromptTemplate instance', () => {
const sampleMessages: BaseMessagePromptTemplateLike[] = [
['system', 'Test'],
['human', 'Hello'],
];
const prompt = preparePrompt(sampleMessages);
expect(prompt).toBeDefined();
});
});

View file

@ -524,16 +524,16 @@ export class ChainLlm implements INodeType {
const items = this.getInputData(); const items = this.getInputData();
const returnData: INodeExecutionData[] = []; const returnData: INodeExecutionData[] = [];
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
let prompt: string;
const llm = (await this.getInputConnectionData( const llm = (await this.getInputConnectionData(
NodeConnectionType.AiLanguageModel, NodeConnectionType.AiLanguageModel,
0, 0,
)) as BaseLanguageModel; )) as BaseLanguageModel;
const outputParsers = await getOptionalOutputParsers(this); const outputParsers = await getOptionalOutputParsers(this);
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
let prompt: string;
if (this.getNode().typeVersion <= 1.3) { if (this.getNode().typeVersion <= 1.3) {
prompt = this.getNodeParameter('prompt', itemIndex) as string; prompt = this.getNodeParameter('prompt', itemIndex) as string;
} else { } else {

View file

@ -163,6 +163,11 @@ export class ChainRetrievalQa implements INodeType {
async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> { async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
this.logger.debug('Executing Retrieval QA Chain'); this.logger.debug('Executing Retrieval QA Chain');
const items = this.getInputData();
const returnData: INodeExecutionData[] = [];
// Run for each item
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
const model = (await this.getInputConnectionData( const model = (await this.getInputConnectionData(
NodeConnectionType.AiLanguageModel, NodeConnectionType.AiLanguageModel,
0, 0,
@ -173,13 +178,6 @@ export class ChainRetrievalQa implements INodeType {
0, 0,
)) as BaseRetriever; )) as BaseRetriever;
const items = this.getInputData();
const returnData: INodeExecutionData[] = [];
// Run for each item
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
let query; let query;
if (this.getNode().typeVersion <= 1.2) { if (this.getNode().typeVersion <= 1.2) {
@ -226,7 +224,9 @@ export class ChainRetrievalQa implements INodeType {
const chain = RetrievalQAChain.fromLLM(model, retriever, chainParameters); const chain = RetrievalQAChain.fromLLM(model, retriever, chainParameters);
const response = await chain.withConfig(getTracingConfig(this)).invoke({ query }); const response = await chain
.withConfig(getTracingConfig(this))
.invoke({ query }, { signal: this.getExecutionCancelSignal() });
returnData.push({ json: { response } }); returnData.push({ json: { response } });
} catch (error) { } catch (error) {
if (this.continueOnFail()) { if (this.continueOnFail()) {

View file

@ -321,16 +321,16 @@ export class ChainSummarizationV2 implements INodeType {
| 'simple' | 'simple'
| 'advanced'; | 'advanced';
const model = (await this.getInputConnectionData(
NodeConnectionType.AiLanguageModel,
0,
)) as BaseLanguageModel;
const items = this.getInputData(); const items = this.getInputData();
const returnData: INodeExecutionData[] = []; const returnData: INodeExecutionData[] = [];
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try { try {
const model = (await this.getInputConnectionData(
NodeConnectionType.AiLanguageModel,
0,
)) as BaseLanguageModel;
const summarizationMethodAndPrompts = this.getNodeParameter( const summarizationMethodAndPrompts = this.getNodeParameter(
'options.summarizationMethodAndPrompts.values', 'options.summarizationMethodAndPrompts.values',
itemIndex, itemIndex,
@ -411,9 +411,12 @@ export class ChainSummarizationV2 implements INodeType {
} }
const processedItem = await processor.processItem(item, itemIndex); const processedItem = await processor.processItem(item, itemIndex);
const response = await chain.call({ const response = await chain.invoke(
{
input_documents: processedItem, input_documents: processedItem,
}); },
{ signal: this.getExecutionCancelSignal() },
);
returnData.push({ json: { response } }); returnData.push({ json: { response } });
} }
} catch (error) { } catch (error) {