feat(AI Agent Node): Implement Tool calling agent (#9339)

Signed-off-by: Oleg Ivaniv <me@olegivaniv.com>
This commit is contained in:
oleg 2024-05-15 12:02:21 +02:00 committed by GitHub
parent 1081429a4d
commit 677f534661
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 344 additions and 75 deletions

View file

@ -7,6 +7,7 @@ import type {
INodeExecutionData,
INodeType,
INodeTypeDescription,
INodeProperties,
} from 'n8n-workflow';
import { getTemplateNoticeField } from '../../../utils/sharedFields';
import { promptTypeOptions, textInput } from '../../../utils/descriptions';
@ -20,11 +21,13 @@ import { reActAgentAgentProperties } from './agents/ReActAgent/description';
import { reActAgentAgentExecute } from './agents/ReActAgent/execute';
import { sqlAgentAgentProperties } from './agents/SqlAgent/description';
import { sqlAgentAgentExecute } from './agents/SqlAgent/execute';
import { toolsAgentProperties } from './agents/ToolsAgent/description';
import { toolsAgentExecute } from './agents/ToolsAgent/execute';
// Function used in the inputs expression to figure out which inputs to
// display based on the agent type
function getInputs(
agent: 'conversationalAgent' | 'openAiFunctionsAgent' | 'reActAgent' | 'sqlAgent',
agent: 'toolsAgent' | 'conversationalAgent' | 'openAiFunctionsAgent' | 'reActAgent' | 'sqlAgent',
hasOutputParser?: boolean,
): Array<ConnectionTypes | INodeInputConfiguration> {
interface SpecialInput {
@ -92,6 +95,31 @@ function getInputs(
type: NodeConnectionType.AiOutputParser,
},
];
} else if (agent === 'toolsAgent') {
specialInputs = [
{
type: NodeConnectionType.AiLanguageModel,
filter: {
nodes: [
'@n8n/n8n-nodes-langchain.lmChatAnthropic',
'@n8n/n8n-nodes-langchain.lmChatAzureOpenAi',
'@n8n/n8n-nodes-langchain.lmChatMistralCloud',
'@n8n/n8n-nodes-langchain.lmChatOpenAi',
'@n8n/n8n-nodes-langchain.lmChatGroq',
],
},
},
{
type: NodeConnectionType.AiMemory,
},
{
type: NodeConnectionType.AiTool,
required: true,
},
{
type: NodeConnectionType.AiOutputParser,
},
];
} else if (agent === 'openAiFunctionsAgent') {
specialInputs = [
{
@ -157,16 +185,60 @@ function getInputs(
return [NodeConnectionType.Main, ...getInputData(specialInputs)];
}
const agentTypeProperty: INodeProperties = {
displayName: 'Agent',
name: 'agent',
type: 'options',
noDataExpression: true,
options: [
{
name: 'Conversational Agent',
value: 'conversationalAgent',
description:
'Selects tools to accomplish its task and uses memory to recall previous conversations',
},
{
name: 'OpenAI Functions Agent',
value: 'openAiFunctionsAgent',
description:
"Utilizes OpenAI's Function Calling feature to select the appropriate tool and arguments for execution",
},
{
name: 'Plan and Execute Agent',
value: 'planAndExecuteAgent',
description:
'Plan and execute agents accomplish an objective by first planning what to do, then executing the sub tasks',
},
{
name: 'ReAct Agent',
value: 'reActAgent',
description: 'Strategically select tools to accomplish a given task',
},
{
name: 'SQL Agent',
value: 'sqlAgent',
description: 'Answers questions about data in an SQL database',
},
{
name: 'Tools Agent',
value: 'toolsAgent',
description:
'Utilized unified Tool calling interface to select the appropriate tools and argument for execution',
},
],
default: '',
};
export class Agent implements INodeType {
description: INodeTypeDescription = {
displayName: 'AI Agent',
name: 'agent',
icon: 'fa:robot',
group: ['transform'],
version: [1, 1.1, 1.2, 1.3, 1.4, 1.5],
version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6],
description: 'Generates an action plan and executes it. Can use external tools.',
subtitle:
"={{ { conversationalAgent: 'Conversational Agent', openAiFunctionsAgent: 'OpenAI Functions Agent', reActAgent: 'ReAct Agent', sqlAgent: 'SQL Agent', planAndExecuteAgent: 'Plan and Execute Agent' }[$parameter.agent] }}",
"={{ { toolsAgent: 'Tools Agent', conversationalAgent: 'Conversational Agent', openAiFunctionsAgent: 'OpenAI Functions Agent', reActAgent: 'ReAct Agent', sqlAgent: 'SQL Agent', planAndExecuteAgent: 'Plan and Execute Agent' }[$parameter.agent] }}",
defaults: {
name: 'AI Agent',
color: '#404040',
@ -225,43 +297,18 @@ export class Agent implements INodeType {
},
},
},
// Make Conversational Agent the default agent for versions 1.5 and below
{
displayName: 'Agent',
name: 'agent',
type: 'options',
noDataExpression: true,
options: [
{
name: 'Conversational Agent',
value: 'conversationalAgent',
description:
'Selects tools to accomplish its task and uses memory to recall previous conversations',
},
{
name: 'OpenAI Functions Agent',
value: 'openAiFunctionsAgent',
description:
"Utilizes OpenAI's Function Calling feature to select the appropriate tool and arguments for execution",
},
{
name: 'Plan and Execute Agent',
value: 'planAndExecuteAgent',
description:
'Plan and execute agents accomplish an objective by first planning what to do, then executing the sub tasks',
},
{
name: 'ReAct Agent',
value: 'reActAgent',
description: 'Strategically select tools to accomplish a given task',
},
{
name: 'SQL Agent',
value: 'sqlAgent',
description: 'Answers questions about data in an SQL database',
},
],
...agentTypeProperty,
displayOptions: { show: { '@version': [{ _cnd: { lte: 1.5 } }] } },
default: 'conversationalAgent',
},
// Make Tools Agent the default agent for versions 1.6 and above
{
...agentTypeProperty,
displayOptions: { show: { '@version': [{ _cnd: { gte: 1.6 } }] } },
default: 'toolsAgent',
},
{
...promptTypeOptions,
displayOptions: {
@ -307,6 +354,7 @@ export class Agent implements INodeType {
},
},
...toolsAgentProperties,
...conversationalAgentProperties,
...openAiFunctionsAgentProperties,
...reActAgentAgentProperties,
@ -321,6 +369,8 @@ export class Agent implements INodeType {
if (agentType === 'conversationalAgent') {
return await conversationalAgentExecute.call(this, nodeVersion);
} else if (agentType === 'toolsAgent') {
return await toolsAgentExecute.call(this, nodeVersion);
} else if (agentType === 'openAiFunctionsAgent') {
return await openAiFunctionsAgentExecute.call(this, nodeVersion);
} else if (agentType === 'reActAgent') {

View file

@ -0,0 +1,43 @@
import type { INodeProperties } from 'n8n-workflow';
import { SYSTEM_MESSAGE } from './prompt';
export const toolsAgentProperties: INodeProperties[] = [
{
displayName: 'Options',
name: 'options',
type: 'collection',
displayOptions: {
show: {
agent: ['toolsAgent'],
},
},
default: {},
placeholder: 'Add Option',
options: [
{
displayName: 'System Message',
name: 'systemMessage',
type: 'string',
default: SYSTEM_MESSAGE,
description: 'The message that will be sent to the agent before the conversation starts',
typeOptions: {
rows: 6,
},
},
{
displayName: 'Max Iterations',
name: 'maxIterations',
type: 'number',
default: 10,
description: 'The maximum number of iterations the agent will run before stopping',
},
{
displayName: 'Return Intermediate Steps',
name: 'returnIntermediateSteps',
type: 'boolean',
default: false,
description: 'Whether or not the output should include intermediate steps the agent took',
},
],
},
];

View file

@ -0,0 +1,189 @@
import { NodeConnectionType, NodeOperationError } from 'n8n-workflow';
import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
import type { AgentAction, AgentFinish, AgentStep } from 'langchain/agents';
import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
import { ChatPromptTemplate } from '@langchain/core/prompts';
import { omit } from 'lodash';
import type { Tool } from '@langchain/core/tools';
import { DynamicStructuredTool } from '@langchain/core/tools';
import { RunnableSequence } from '@langchain/core/runnables';
import type { ZodObject } from 'zod';
import { z } from 'zod';
import type { BaseOutputParser, StructuredOutputParser } from '@langchain/core/output_parsers';
import { OutputFixingParser } from 'langchain/output_parsers';
import {
isChatInstance,
getPromptInputByType,
getOptionalOutputParsers,
getConnectedTools,
} from '../../../../../utils/helpers';
import { SYSTEM_MESSAGE } from './prompt';
function getOutputParserSchema(outputParser: BaseOutputParser): ZodObject<any, any, any, any> {
const parserType = outputParser.lc_namespace[outputParser.lc_namespace.length - 1];
let schema: ZodObject<any, any, any, any>;
if (parserType === 'structured') {
// If the output parser is a structured output parser, we will use the schema from the parser
schema = (outputParser as StructuredOutputParser<ZodObject<any, any, any, any>>).schema;
} else if (parserType === 'fix' && outputParser instanceof OutputFixingParser) {
// If the output parser is a fixing parser, we will use the schema from the connected structured output parser
schema = (outputParser.parser as StructuredOutputParser<ZodObject<any, any, any, any>>).schema;
} else {
// If the output parser is not a structured output parser, we will use a fallback schema
schema = z.object({ text: z.string() });
}
return schema;
}
export async function toolsAgentExecute(
this: IExecuteFunctions,
nodeVersion: number,
): Promise<INodeExecutionData[][]> {
this.logger.verbose('Executing Tools Agent');
const model = await this.getInputConnectionData(NodeConnectionType.AiLanguageModel, 0);
if (!isChatInstance(model) || !model.bindTools) {
throw new NodeOperationError(
this.getNode(),
'Tools Agent requires Chat Model which supports Tools calling',
);
}
const memory = (await this.getInputConnectionData(NodeConnectionType.AiMemory, 0)) as
| BaseChatMemory
| undefined;
const tools = (await getConnectedTools(this, true)) as Array<DynamicStructuredTool | Tool>;
const outputParser = (await getOptionalOutputParsers(this))?.[0];
let structuredOutputParserTool: DynamicStructuredTool | undefined;
async function agentStepsParser(
steps: AgentFinish | AgentAction[],
): Promise<AgentFinish | AgentAction[]> {
if (Array.isArray(steps)) {
const responseParserTool = steps.find((step) => step.tool === 'format_final_response');
if (responseParserTool) {
const toolInput = responseParserTool?.toolInput;
const returnValues = (await outputParser.parse(toolInput as unknown as string)) as Record<
string,
unknown
>;
return {
returnValues,
log: 'Final response formatted',
};
}
}
// If the steps are an AgentFinish and the outputParser is defined it must mean that the LLM didn't use `format_final_response` tool so we will parse the output manually
if (outputParser && typeof steps === 'object' && (steps as AgentFinish).returnValues) {
const finalResponse = (steps as AgentFinish).returnValues;
const returnValues = (await outputParser.parse(finalResponse as unknown as string)) as Record<
string,
unknown
>;
return {
returnValues,
log: 'Final response formatted',
};
}
return steps;
}
if (outputParser) {
const schema = getOutputParserSchema(outputParser);
structuredOutputParserTool = new DynamicStructuredTool({
schema,
name: 'format_final_response',
description:
'Always use this tool for the final output to the user. It validates the output so only use it when you are sure the output is final.',
// We will not use the function here as we will use the parser to intercept & parse the output in the agentStepsParser
func: async () => '',
});
tools.push(structuredOutputParserTool);
}
const options = this.getNodeParameter('options', 0, {}) as {
systemMessage?: string;
maxIterations?: number;
returnIntermediateSteps?: boolean;
};
const prompt = ChatPromptTemplate.fromMessages([
['system', `{system_message}${outputParser ? '\n\n{formatting_instructions}' : ''}`],
['placeholder', '{chat_history}'],
['human', '{input}'],
['placeholder', '{agent_scratchpad}'],
]);
const agent = createToolCallingAgent({
llm: model,
tools,
prompt,
streamRunnable: false,
});
agent.streamRunnable = false;
const runnableAgent = RunnableSequence.from<{
steps: AgentStep[];
}>([agent, agentStepsParser]);
const executor = AgentExecutor.fromAgentAndTools({
agent: runnableAgent,
memory,
tools,
returnIntermediateSteps: options.returnIntermediateSteps === true,
maxIterations: options.maxIterations ?? 10,
});
const returnData: INodeExecutionData[] = [];
const items = this.getInputData();
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
const input = getPromptInputByType({
ctx: this,
i: itemIndex,
inputKey: 'text',
promptTypeKey: 'promptType',
});
if (input === undefined) {
throw new NodeOperationError(this.getNode(), 'The text parameter is empty.');
}
const response = await executor.invoke({
input,
system_message: options.systemMessage ?? SYSTEM_MESSAGE,
formatting_instructions:
'IMPORTANT: Always call `format_final_response` to format your final response!', //outputParser?.getFormatInstructions(),
});
returnData.push({
json: omit(
response,
'system_message',
'formatting_instructions',
'input',
'chat_history',
'agent_scratchpad',
),
});
} catch (error) {
if (this.continueOnFail()) {
returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } });
continue;
}
throw error;
}
}
return await this.prepareOutputData(returnData);
}

View file

@ -0,0 +1 @@
export const SYSTEM_MESSAGE = 'You are a helpful assistant';

View file

@ -13,7 +13,6 @@ import type { Document } from '@langchain/core/documents';
import { TextSplitter } from 'langchain/text_splitter';
import { BaseChatMemory } from '@langchain/community/memory/chat_memory';
import { BaseRetriever } from '@langchain/core/retrievers';
import type { FormatInstructionsOptions } from '@langchain/core/output_parsers';
import { BaseOutputParser, OutputParserException } from '@langchain/core/output_parsers';
import { isObject } from 'lodash';
import type { BaseDocumentLoader } from 'langchain/dist/document_loaders/base';
@ -222,31 +221,7 @@ export function logWrapper(
// ========== BaseOutputParser ==========
if (originalInstance instanceof BaseOutputParser) {
if (prop === 'getFormatInstructions' && 'getFormatInstructions' in target) {
return (options?: FormatInstructionsOptions): string => {
connectionType = NodeConnectionType.AiOutputParser;
const { index } = executeFunctions.addInputData(connectionType, [
[{ json: { action: 'getFormatInstructions' } }],
]);
// @ts-ignore
const response = callMethodSync.call(target, {
executeFunctions,
connectionType,
currentNodeRunIndex: index,
method: target[prop],
arguments: [options],
}) as string;
executeFunctions.addOutputData(connectionType, index, [
[{ json: { action: 'getFormatInstructions', response } }],
]);
void logAiEvent(executeFunctions, 'n8n.ai.output.parser.get.instructions', {
response,
});
return response;
};
} else if (prop === 'parse' && 'parse' in target) {
if (prop === 'parse' && 'parse' in target) {
return async (text: string | Record<string, unknown>): Promise<unknown> => {
connectionType = NodeConnectionType.AiOutputParser;
const stringifiedText = isObject(text) ? JSON.stringify(text) : text;
@ -254,19 +229,30 @@ export function logWrapper(
[{ json: { action: 'parse', text: stringifiedText } }],
]);
const response = (await callMethodAsync.call(target, {
executeFunctions,
connectionType,
currentNodeRunIndex: index,
method: target[prop],
arguments: [stringifiedText],
})) as object;
try {
const response = (await callMethodAsync.call(target, {
executeFunctions,
connectionType,
currentNodeRunIndex: index,
method: target[prop],
arguments: [stringifiedText],
})) as object;
void logAiEvent(executeFunctions, 'n8n.ai.output.parser.parsed', { text, response });
executeFunctions.addOutputData(connectionType, index, [
[{ json: { action: 'parse', response } }],
]);
return response;
void logAiEvent(executeFunctions, 'n8n.ai.output.parser.parsed', { text, response });
executeFunctions.addOutputData(connectionType, index, [
[{ json: { action: 'parse', response } }],
]);
return response;
} catch (error) {
void logAiEvent(executeFunctions, 'n8n.ai.output.parser.parsed', {
text,
response: error.message ?? error,
});
executeFunctions.addOutputData(connectionType, index, [
[{ json: { action: 'parse', response: error.message ?? error } }],
]);
throw error;
}
};
}
}