feat(AI Agent Node): Implement Tool calling agent (#9339)
Signed-off-by: Oleg Ivaniv <me@olegivaniv.com>
Parent: 1081429a4d
Commit: 677f534661
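For orientation before the diff: the new Tools Agent builds a LangChain tool-calling agent from the connected chat model, tools and memory, and runs it through an `AgentExecutor`. The sketch below is a simplified outline of `toolsAgentExecute` (added further down) and is not part of the commit; types are loosened, and the real code additionally wraps the agent in a `RunnableSequence` so it can post-process steps and handle a connected output parser.

```ts
import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
import { ChatPromptTemplate } from '@langchain/core/prompts';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { StructuredToolInterface } from '@langchain/core/tools';

// Simplified flow of the new Tools Agent. In the node, model and tools come
// from input connections at runtime; here they are passed in as parameters.
async function runToolsAgent(
	model: BaseChatModel,
	tools: StructuredToolInterface[],
	input: string,
) {
	const prompt = ChatPromptTemplate.fromMessages([
		['system', '{system_message}'],
		['placeholder', '{chat_history}'],
		['human', '{input}'],
		['placeholder', '{agent_scratchpad}'],
	]);

	const agent = createToolCallingAgent({ llm: model, tools, prompt, streamRunnable: false });
	const executor = AgentExecutor.fromAgentAndTools({ agent, tools, maxIterations: 10 });

	return await executor.invoke({ input, system_message: 'You are a helpful assistant' });
}
```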
@@ -7,6 +7,7 @@ import type {
 	INodeExecutionData,
 	INodeType,
 	INodeTypeDescription,
+	INodeProperties,
 } from 'n8n-workflow';
 import { getTemplateNoticeField } from '../../../utils/sharedFields';
 import { promptTypeOptions, textInput } from '../../../utils/descriptions';
@@ -20,11 +21,13 @@ import { reActAgentAgentProperties } from './agents/ReActAgent/description';
 import { reActAgentAgentExecute } from './agents/ReActAgent/execute';
 import { sqlAgentAgentProperties } from './agents/SqlAgent/description';
 import { sqlAgentAgentExecute } from './agents/SqlAgent/execute';
+import { toolsAgentProperties } from './agents/ToolsAgent/description';
+import { toolsAgentExecute } from './agents/ToolsAgent/execute';

 // Function used in the inputs expression to figure out which inputs to
 // display based on the agent type
 function getInputs(
-	agent: 'conversationalAgent' | 'openAiFunctionsAgent' | 'reActAgent' | 'sqlAgent',
+	agent: 'toolsAgent' | 'conversationalAgent' | 'openAiFunctionsAgent' | 'reActAgent' | 'sqlAgent',
 	hasOutputParser?: boolean,
 ): Array<ConnectionTypes | INodeInputConfiguration> {
 	interface SpecialInput {
@@ -92,6 +95,31 @@ function getInputs(
 				type: NodeConnectionType.AiOutputParser,
 			},
 		];
+	} else if (agent === 'toolsAgent') {
+		specialInputs = [
+			{
+				type: NodeConnectionType.AiLanguageModel,
+				filter: {
+					nodes: [
+						'@n8n/n8n-nodes-langchain.lmChatAnthropic',
+						'@n8n/n8n-nodes-langchain.lmChatAzureOpenAi',
+						'@n8n/n8n-nodes-langchain.lmChatMistralCloud',
+						'@n8n/n8n-nodes-langchain.lmChatOpenAi',
+						'@n8n/n8n-nodes-langchain.lmChatGroq',
+					],
+				},
+			},
+			{
+				type: NodeConnectionType.AiMemory,
+			},
+			{
+				type: NodeConnectionType.AiTool,
+				required: true,
+			},
+			{
+				type: NodeConnectionType.AiOutputParser,
+			},
+		];
 	} else if (agent === 'openAiFunctionsAgent') {
 		specialInputs = [
 			{
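In effect, the branch added above gives the Tools Agent node a Main input plus four AI connections, with the language model restricted to chat models that support tool calling. `getInputData` (not part of this diff) converts each `SpecialInput` into an input configuration, so the result is roughly the following; treat the object shapes as an approximation rather than the literal return value:

```ts
import { NodeConnectionType } from 'n8n-workflow';

// Approximate result of getInputs('toolsAgent'). The exact
// INodeInputConfiguration fields come from getInputData(), which is outside
// this diff, so this is an illustration only.
const toolsAgentInputs = [
	NodeConnectionType.Main,
	{
		type: NodeConnectionType.AiLanguageModel,
		// Only chat models with native tool calling are selectable.
		filter: {
			nodes: [
				'@n8n/n8n-nodes-langchain.lmChatAnthropic',
				'@n8n/n8n-nodes-langchain.lmChatAzureOpenAi',
				'@n8n/n8n-nodes-langchain.lmChatMistralCloud',
				'@n8n/n8n-nodes-langchain.lmChatOpenAi',
				'@n8n/n8n-nodes-langchain.lmChatGroq',
			],
		},
	},
	{ type: NodeConnectionType.AiMemory },
	{ type: NodeConnectionType.AiTool, required: true },
	{ type: NodeConnectionType.AiOutputParser },
];
```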
@@ -157,16 +185,60 @@ function getInputs(
 	return [NodeConnectionType.Main, ...getInputData(specialInputs)];
 }

+const agentTypeProperty: INodeProperties = {
+	displayName: 'Agent',
+	name: 'agent',
+	type: 'options',
+	noDataExpression: true,
+	options: [
+		{
+			name: 'Conversational Agent',
+			value: 'conversationalAgent',
+			description:
+				'Selects tools to accomplish its task and uses memory to recall previous conversations',
+		},
+		{
+			name: 'OpenAI Functions Agent',
+			value: 'openAiFunctionsAgent',
+			description:
+				"Utilizes OpenAI's Function Calling feature to select the appropriate tool and arguments for execution",
+		},
+		{
+			name: 'Plan and Execute Agent',
+			value: 'planAndExecuteAgent',
+			description:
+				'Plan and execute agents accomplish an objective by first planning what to do, then executing the sub tasks',
+		},
+		{
+			name: 'ReAct Agent',
+			value: 'reActAgent',
+			description: 'Strategically select tools to accomplish a given task',
+		},
+		{
+			name: 'SQL Agent',
+			value: 'sqlAgent',
+			description: 'Answers questions about data in an SQL database',
+		},
+		{
+			name: 'Tools Agent',
+			value: 'toolsAgent',
+			description:
+				'Utilized unified Tool calling interface to select the appropriate tools and argument for execution',
+		},
+	],
+	default: '',
+};
+
 export class Agent implements INodeType {
 	description: INodeTypeDescription = {
 		displayName: 'AI Agent',
 		name: 'agent',
 		icon: 'fa:robot',
 		group: ['transform'],
-		version: [1, 1.1, 1.2, 1.3, 1.4, 1.5],
+		version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6],
 		description: 'Generates an action plan and executes it. Can use external tools.',
 		subtitle:
-			"={{ { conversationalAgent: 'Conversational Agent', openAiFunctionsAgent: 'OpenAI Functions Agent', reActAgent: 'ReAct Agent', sqlAgent: 'SQL Agent', planAndExecuteAgent: 'Plan and Execute Agent' }[$parameter.agent] }}",
+			"={{ { toolsAgent: 'Tools Agent', conversationalAgent: 'Conversational Agent', openAiFunctionsAgent: 'OpenAI Functions Agent', reActAgent: 'ReAct Agent', sqlAgent: 'SQL Agent', planAndExecuteAgent: 'Plan and Execute Agent' }[$parameter.agent] }}",
 		defaults: {
 			name: 'AI Agent',
 			color: '#404040',
@@ -225,43 +297,18 @@ export class Agent implements INodeType {
 					},
 				},
 			},
+			// Make Conversational Agent the default agent for versions 1.5 and below
 			{
-				displayName: 'Agent',
-				name: 'agent',
-				type: 'options',
-				noDataExpression: true,
-				options: [
-					{
-						name: 'Conversational Agent',
-						value: 'conversationalAgent',
-						description:
-							'Selects tools to accomplish its task and uses memory to recall previous conversations',
-					},
-					{
-						name: 'OpenAI Functions Agent',
-						value: 'openAiFunctionsAgent',
-						description:
-							"Utilizes OpenAI's Function Calling feature to select the appropriate tool and arguments for execution",
-					},
-					{
-						name: 'Plan and Execute Agent',
-						value: 'planAndExecuteAgent',
-						description:
-							'Plan and execute agents accomplish an objective by first planning what to do, then executing the sub tasks',
-					},
-					{
-						name: 'ReAct Agent',
-						value: 'reActAgent',
-						description: 'Strategically select tools to accomplish a given task',
-					},
-					{
-						name: 'SQL Agent',
-						value: 'sqlAgent',
-						description: 'Answers questions about data in an SQL database',
-					},
-				],
+				...agentTypeProperty,
+				displayOptions: { show: { '@version': [{ _cnd: { lte: 1.5 } }] } },
 				default: 'conversationalAgent',
 			},
+			// Make Tools Agent the default agent for versions 1.6 and above
+			{
+				...agentTypeProperty,
+				displayOptions: { show: { '@version': [{ _cnd: { gte: 1.6 } }] } },
+				default: 'toolsAgent',
+			},
 			{
 				...promptTypeOptions,
 				displayOptions: {
@@ -307,6 +354,7 @@ export class Agent implements INodeType {
 				},
 			},

+			...toolsAgentProperties,
 			...conversationalAgentProperties,
 			...openAiFunctionsAgentProperties,
 			...reActAgentAgentProperties,
@@ -321,6 +369,8 @@ export class Agent implements INodeType {

 		if (agentType === 'conversationalAgent') {
 			return await conversationalAgentExecute.call(this, nodeVersion);
+		} else if (agentType === 'toolsAgent') {
+			return await toolsAgentExecute.call(this, nodeVersion);
 		} else if (agentType === 'openAiFunctionsAgent') {
 			return await openAiFunctionsAgentExecute.call(this, nodeVersion);
 		} else if (agentType === 'reActAgent') {
@@ -0,0 +1,43 @@
+import type { INodeProperties } from 'n8n-workflow';
+import { SYSTEM_MESSAGE } from './prompt';
+
+export const toolsAgentProperties: INodeProperties[] = [
+	{
+		displayName: 'Options',
+		name: 'options',
+		type: 'collection',
+		displayOptions: {
+			show: {
+				agent: ['toolsAgent'],
+			},
+		},
+		default: {},
+		placeholder: 'Add Option',
+		options: [
+			{
+				displayName: 'System Message',
+				name: 'systemMessage',
+				type: 'string',
+				default: SYSTEM_MESSAGE,
+				description: 'The message that will be sent to the agent before the conversation starts',
+				typeOptions: {
+					rows: 6,
+				},
+			},
+			{
+				displayName: 'Max Iterations',
+				name: 'maxIterations',
+				type: 'number',
+				default: 10,
+				description: 'The maximum number of iterations the agent will run before stopping',
+			},
+			{
+				displayName: 'Return Intermediate Steps',
+				name: 'returnIntermediateSteps',
+				type: 'boolean',
+				default: false,
+				description: 'Whether or not the output should include intermediate steps the agent took',
+			},
+		],
+	},
+];
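The three options above map directly onto the executor that `toolsAgentExecute` (next hunk) builds: System Message replaces the default `SYSTEM_MESSAGE`, Max Iterations caps the agent loop, and Return Intermediate Steps adds the step log to the node output. A small sketch of that mapping, with illustrative option values and the agent and tools stubbed out:

```ts
import { AgentExecutor } from 'langchain/agents';

// Illustrative values a user might set in the Options collection defined above.
const options = {
	systemMessage: 'You are a support assistant for the ACME workspace.',
	maxIterations: 5,
	returnIntermediateSteps: true,
};

// Stubs standing in for what the node resolves from its input connections.
declare const runnableAgent: any;
declare const tools: any[];

// Same mapping as in toolsAgentExecute (next hunk): an explicit boolean check
// and fallbacks that mirror the UI defaults (10 iterations, steps hidden).
const executor = AgentExecutor.fromAgentAndTools({
	agent: runnableAgent,
	tools,
	returnIntermediateSteps: options.returnIntermediateSteps === true,
	maxIterations: options.maxIterations ?? 10,
});
```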
@@ -0,0 +1,189 @@
+import { NodeConnectionType, NodeOperationError } from 'n8n-workflow';
+import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
+
+import type { AgentAction, AgentFinish, AgentStep } from 'langchain/agents';
+import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
+import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
+import { ChatPromptTemplate } from '@langchain/core/prompts';
+import { omit } from 'lodash';
+import type { Tool } from '@langchain/core/tools';
+import { DynamicStructuredTool } from '@langchain/core/tools';
+import { RunnableSequence } from '@langchain/core/runnables';
+import type { ZodObject } from 'zod';
+import { z } from 'zod';
+import type { BaseOutputParser, StructuredOutputParser } from '@langchain/core/output_parsers';
+import { OutputFixingParser } from 'langchain/output_parsers';
+import {
+	isChatInstance,
+	getPromptInputByType,
+	getOptionalOutputParsers,
+	getConnectedTools,
+} from '../../../../../utils/helpers';
+import { SYSTEM_MESSAGE } from './prompt';
+
+function getOutputParserSchema(outputParser: BaseOutputParser): ZodObject<any, any, any, any> {
+	const parserType = outputParser.lc_namespace[outputParser.lc_namespace.length - 1];
+	let schema: ZodObject<any, any, any, any>;
+
+	if (parserType === 'structured') {
+		// If the output parser is a structured output parser, we will use the schema from the parser
+		schema = (outputParser as StructuredOutputParser<ZodObject<any, any, any, any>>).schema;
+	} else if (parserType === 'fix' && outputParser instanceof OutputFixingParser) {
+		// If the output parser is a fixing parser, we will use the schema from the connected structured output parser
+		schema = (outputParser.parser as StructuredOutputParser<ZodObject<any, any, any, any>>).schema;
+	} else {
+		// If the output parser is not a structured output parser, we will use a fallback schema
+		schema = z.object({ text: z.string() });
+	}
+
+	return schema;
+}
+
+export async function toolsAgentExecute(
+	this: IExecuteFunctions,
+	nodeVersion: number,
+): Promise<INodeExecutionData[][]> {
+	this.logger.verbose('Executing Tools Agent');
+	const model = await this.getInputConnectionData(NodeConnectionType.AiLanguageModel, 0);
+
+	if (!isChatInstance(model) || !model.bindTools) {
+		throw new NodeOperationError(
+			this.getNode(),
+			'Tools Agent requires Chat Model which supports Tools calling',
+		);
+	}
+
+	const memory = (await this.getInputConnectionData(NodeConnectionType.AiMemory, 0)) as
+		| BaseChatMemory
+		| undefined;
+
+	const tools = (await getConnectedTools(this, true)) as Array<DynamicStructuredTool | Tool>;
+	const outputParser = (await getOptionalOutputParsers(this))?.[0];
+	let structuredOutputParserTool: DynamicStructuredTool | undefined;
+
+	async function agentStepsParser(
+		steps: AgentFinish | AgentAction[],
+	): Promise<AgentFinish | AgentAction[]> {
+		if (Array.isArray(steps)) {
+			const responseParserTool = steps.find((step) => step.tool === 'format_final_response');
+			if (responseParserTool) {
+				const toolInput = responseParserTool?.toolInput;
+				const returnValues = (await outputParser.parse(toolInput as unknown as string)) as Record<
+					string,
+					unknown
+				>;
+
+				return {
+					returnValues,
+					log: 'Final response formatted',
+				};
+			}
+		}
+
+		// If the steps are an AgentFinish and the outputParser is defined it must mean that the LLM didn't use `format_final_response` tool so we will parse the output manually
+		if (outputParser && typeof steps === 'object' && (steps as AgentFinish).returnValues) {
+			const finalResponse = (steps as AgentFinish).returnValues;
+			const returnValues = (await outputParser.parse(finalResponse as unknown as string)) as Record<
+				string,
+				unknown
+			>;
+
+			return {
+				returnValues,
+				log: 'Final response formatted',
+			};
+		}
+		return steps;
+	}
+
+	if (outputParser) {
+		const schema = getOutputParserSchema(outputParser);
+		structuredOutputParserTool = new DynamicStructuredTool({
+			schema,
+			name: 'format_final_response',
+			description:
+				'Always use this tool for the final output to the user. It validates the output so only use it when you are sure the output is final.',
+			// We will not use the function here as we will use the parser to intercept & parse the output in the agentStepsParser
+			func: async () => '',
+		});
+
+		tools.push(structuredOutputParserTool);
+	}
+
+	const options = this.getNodeParameter('options', 0, {}) as {
+		systemMessage?: string;
+		maxIterations?: number;
+		returnIntermediateSteps?: boolean;
+	};
+
+	const prompt = ChatPromptTemplate.fromMessages([
+		['system', `{system_message}${outputParser ? '\n\n{formatting_instructions}' : ''}`],
+		['placeholder', '{chat_history}'],
+		['human', '{input}'],
+		['placeholder', '{agent_scratchpad}'],
+	]);
+
+	const agent = createToolCallingAgent({
+		llm: model,
+		tools,
+		prompt,
+		streamRunnable: false,
+	});
+	agent.streamRunnable = false;
+
+	const runnableAgent = RunnableSequence.from<{
+		steps: AgentStep[];
+	}>([agent, agentStepsParser]);
+
+	const executor = AgentExecutor.fromAgentAndTools({
+		agent: runnableAgent,
+		memory,
+		tools,
+		returnIntermediateSteps: options.returnIntermediateSteps === true,
+		maxIterations: options.maxIterations ?? 10,
+	});
+	const returnData: INodeExecutionData[] = [];
+
+	const items = this.getInputData();
+	for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
+		try {
+			const input = getPromptInputByType({
+				ctx: this,
+				i: itemIndex,
+				inputKey: 'text',
+				promptTypeKey: 'promptType',
+			});
+
+			if (input === undefined) {
+				throw new NodeOperationError(this.getNode(), 'The ‘text’ parameter is empty.');
+			}
+
+			const response = await executor.invoke({
+				input,
+				system_message: options.systemMessage ?? SYSTEM_MESSAGE,
+				formatting_instructions:
+					'IMPORTANT: Always call `format_final_response` to format your final response!', //outputParser?.getFormatInstructions(),
+			});
+
+			returnData.push({
+				json: omit(
+					response,
+					'system_message',
+					'formatting_instructions',
+					'input',
+					'chat_history',
+					'agent_scratchpad',
+				),
+			});
+		} catch (error) {
+			if (this.continueOnFail()) {
+				returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } });
+				continue;
+			}
+
+			throw error;
+		}
+	}
+
+	return await this.prepareOutputData(returnData);
+}
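Worth calling out in the file above: when an output parser is connected, its schema is wrapped in a dummy `format_final_response` tool, the system prompt tells the model to call that tool for its final answer, and `agentStepsParser` intercepts the call and feeds the arguments to the real parser instead of running the tool. A minimal sketch of just that dummy tool, with a hypothetical schema standing in for whatever `getOutputParserSchema` extracts at runtime:

```ts
import { z } from 'zod';
import { DynamicStructuredTool } from '@langchain/core/tools';

// Hypothetical schema; in the node this comes from the connected output parser
// via getOutputParserSchema() in the file above.
const schema = z.object({
	answer: z.string().describe('The final answer to return to the user'),
	sources: z.array(z.string()).optional(),
});

const formatFinalResponse = new DynamicStructuredTool({
	name: 'format_final_response',
	description:
		'Always use this tool for the final output to the user. It validates the output so only use it when you are sure the output is final.',
	schema,
	// Deliberately a no-op: agentStepsParser intercepts the call and passes the
	// tool arguments to the connected output parser instead of executing this.
	func: async () => '',
});
```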
@@ -0,0 +1 @@
+export const SYSTEM_MESSAGE = 'You are a helpful assistant';
@@ -13,7 +13,6 @@ import type { Document } from '@langchain/core/documents';
 import { TextSplitter } from 'langchain/text_splitter';
 import { BaseChatMemory } from '@langchain/community/memory/chat_memory';
 import { BaseRetriever } from '@langchain/core/retrievers';
-import type { FormatInstructionsOptions } from '@langchain/core/output_parsers';
 import { BaseOutputParser, OutputParserException } from '@langchain/core/output_parsers';
 import { isObject } from 'lodash';
 import type { BaseDocumentLoader } from 'langchain/dist/document_loaders/base';
@@ -222,31 +221,7 @@ export function logWrapper(

 		// ========== BaseOutputParser ==========
 		if (originalInstance instanceof BaseOutputParser) {
-			if (prop === 'getFormatInstructions' && 'getFormatInstructions' in target) {
-				return (options?: FormatInstructionsOptions): string => {
-					connectionType = NodeConnectionType.AiOutputParser;
-					const { index } = executeFunctions.addInputData(connectionType, [
-						[{ json: { action: 'getFormatInstructions' } }],
-					]);
-
-					// @ts-ignore
-					const response = callMethodSync.call(target, {
-						executeFunctions,
-						connectionType,
-						currentNodeRunIndex: index,
-						method: target[prop],
-						arguments: [options],
-					}) as string;
-
-					executeFunctions.addOutputData(connectionType, index, [
-						[{ json: { action: 'getFormatInstructions', response } }],
-					]);
-					void logAiEvent(executeFunctions, 'n8n.ai.output.parser.get.instructions', {
-						response,
-					});
-					return response;
-				};
-			} else if (prop === 'parse' && 'parse' in target) {
+			if (prop === 'parse' && 'parse' in target) {
 				return async (text: string | Record<string, unknown>): Promise<unknown> => {
 					connectionType = NodeConnectionType.AiOutputParser;
 					const stringifiedText = isObject(text) ? JSON.stringify(text) : text;
@@ -254,19 +229,30 @@ export function logWrapper(
 						[{ json: { action: 'parse', text: stringifiedText } }],
 					]);

-					const response = (await callMethodAsync.call(target, {
-						executeFunctions,
-						connectionType,
-						currentNodeRunIndex: index,
-						method: target[prop],
-						arguments: [stringifiedText],
-					})) as object;
+					try {
+						const response = (await callMethodAsync.call(target, {
+							executeFunctions,
+							connectionType,
+							currentNodeRunIndex: index,
+							method: target[prop],
+							arguments: [stringifiedText],
+						})) as object;

-					void logAiEvent(executeFunctions, 'n8n.ai.output.parser.parsed', { text, response });
-					executeFunctions.addOutputData(connectionType, index, [
-						[{ json: { action: 'parse', response } }],
-					]);
-					return response;
+						void logAiEvent(executeFunctions, 'n8n.ai.output.parser.parsed', { text, response });
+						executeFunctions.addOutputData(connectionType, index, [
+							[{ json: { action: 'parse', response } }],
+						]);
+						return response;
+					} catch (error) {
+						void logAiEvent(executeFunctions, 'n8n.ai.output.parser.parsed', {
+							text,
+							response: error.message ?? error,
+						});
+						executeFunctions.addOutputData(connectionType, index, [
+							[{ json: { action: 'parse', response: error.message ?? error } }],
+						]);
+						throw error;
+					}
 				};
 			}
 		}