feat: Introduce prompt type option for Agent, Basic LLM Chain, and QA Chain nodes (#8697)

Signed-off-by: Oleg Ivaniv <me@olegivaniv.com>
Co-authored-by: Michael Kret <michael.k@radency.com>
This commit is contained in:
oleg 2024-02-21 14:59:37 +01:00 committed by GitHub
parent 40aecd1715
commit 2068f186ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 348 additions and 55 deletions

View file

@ -2,6 +2,8 @@
* Getters
*/
import { getVisibleSelect } from "../utils";
/**
 * Returns the credential select element rendered in the NDV.
 *
 * @param eq - zero-based index of the select to pick when several are shown
 */
export function getCredentialSelect(eq = 0) {
	const credentialSelects = cy.getByTestId('node-credentials-select');
	return credentialSelects.eq(eq);
}
@ -71,3 +73,12 @@ export function clickExecuteNode() {
/**
 * Clears the parameter input identified by its name and types the given value.
 *
 * @param name - the node parameter name the input belongs to
 * @param value - the text to type into the input
 */
export function setParameterInputByName(name: string, value: string) {
	const input = getParameterInputByName(name);
	input.clear();
	input.type(value);
}
/**
 * Toggles the checkbox input rendered inside the parameter with the given name.
 * Uses realClick so the checkbox receives a native-like pointer event.
 *
 * @param name - the node parameter name the checkbox belongs to
 */
export function toggleParameterCheckboxInputByName(name: string) {
	// Added the missing trailing semicolon for consistency with the rest of the file.
	getParameterInputByName(name).find('input[type="checkbox"]').realClick();
}
/**
 * Opens the select for the named parameter and clicks the option whose
 * headline text contains the given content.
 *
 * @param name - the node parameter name the select belongs to
 * @param content - text to match against the option headline
 */
export function setParameterSelectByContent(name: string, content: string) {
	getParameterInputByName(name).realClick();
	const optionHeadlines = getVisibleSelect().find('.option-headline');
	optionHeadlines.contains(content).click();
}

View file

@ -26,7 +26,10 @@ import {
clickExecuteNode,
clickGetBackToCanvas,
getOutputPanelTable,
getParameterInputByName,
setParameterInputByName,
setParameterSelectByContent,
toggleParameterCheckboxInputByName,
} from '../composables/ndv';
import { setCredentialValues } from '../composables/modals/credential-modal';
import {
@ -45,7 +48,9 @@ describe('Langchain Integration', () => {
it('should add nodes to all Agent node input types', () => {
addNodeToCanvas(MANUAL_TRIGGER_NODE_NAME, true);
addNodeToCanvas(AGENT_NODE_NAME, true);
addNodeToCanvas(AGENT_NODE_NAME, true, true);
toggleParameterCheckboxInputByName('hasOutputParser');
clickGetBackToCanvas();
addLanguageModelNodeToParent(AI_LANGUAGE_MODEL_OPENAI_CHAT_MODEL_NODE_NAME, AGENT_NODE_NAME, true);
clickGetBackToCanvas();
@ -94,10 +99,11 @@ describe('Langchain Integration', () => {
openNode(BASIC_LLM_CHAIN_NODE_NAME);
setParameterSelectByContent('promptType', 'Define below')
const inputMessage = 'Hello!';
const outputMessage = 'Hi there! How can I assist you today?';
setParameterInputByName('prompt', inputMessage);
setParameterInputByName('text', inputMessage);
runMockWorkflowExcution({
trigger: () => clickExecuteNode(),
@ -135,6 +141,7 @@ describe('Langchain Integration', () => {
const inputMessage = 'Hello!';
const outputMessage = 'Hi there! How can I assist you today?';
setParameterSelectByContent('promptType', 'Define below')
setParameterInputByName('text', inputMessage);
runMockWorkflowExcution({

View file

@ -24,10 +24,12 @@ import { sqlAgentAgentExecute } from './agents/SqlAgent/execute';
// display based on the agent type
function getInputs(
agent: 'conversationalAgent' | 'openAiFunctionsAgent' | 'reActAgent' | 'sqlAgent',
hasOutputParser?: boolean,
): Array<ConnectionTypes | INodeInputConfiguration> {
interface SpecialInput {
type: ConnectionTypes;
filter?: INodeInputFilter;
required?: boolean;
}
const getInputData = (
@ -40,7 +42,7 @@ function getInputs(
[NodeConnectionType.AiOutputParser]: 'Output Parser',
};
return inputs.map(({ type, filter }) => {
return inputs.map(({ type, filter, required }) => {
const input: INodeInputConfiguration = {
type,
displayName: type in displayNames ? displayNames[type] : undefined,
@ -100,6 +102,7 @@ function getInputs(
},
{
type: NodeConnectionType.AiTool,
required: true,
},
{
type: NodeConnectionType.AiOutputParser,
@ -137,6 +140,11 @@ function getInputs(
];
}
if (hasOutputParser === false) {
specialInputs = specialInputs.filter(
(input) => input.type !== NodeConnectionType.AiOutputParser,
);
}
return [NodeConnectionType.Main, ...getInputData(specialInputs)];
}
@ -146,7 +154,7 @@ export class Agent implements INodeType {
name: 'agent',
icon: 'fa:robot',
group: ['transform'],
version: [1, 1.1, 1.2],
version: [1, 1.1, 1.2, 1.3],
description: 'Generates an action plan and executes it. Can use external tools.',
subtitle:
"={{ { conversationalAgent: 'Conversational Agent', openAiFunctionsAgent: 'OpenAI Functions Agent', reactAgent: 'ReAct Agent', sqlAgent: 'SQL Agent' }[$parameter.agent] }}",
@ -168,7 +176,12 @@ export class Agent implements INodeType {
],
},
},
inputs: `={{ ((agent) => { ${getInputs.toString()}; return getInputs(agent) })($parameter.agent) }}`,
inputs: `={{
((agent, hasOutputParser) => {
${getInputs.toString()};
return getInputs(agent, hasOutputParser)
})($parameter.agent, $parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true)
}}`,
outputs: [NodeConnectionType.Main],
credentials: [
{
@ -240,6 +253,71 @@ export class Agent implements INodeType {
],
default: 'conversationalAgent',
},
{
displayName: 'Prompt',
name: 'promptType',
type: 'options',
options: [
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Take from previous node automatically',
value: 'auto',
description: 'Looks for an input field called chatInput',
},
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Define below',
value: 'define',
description:
'Use an expression to reference data in previous nodes or enter static text',
},
],
displayOptions: {
hide: {
'@version': [{ _cnd: { lte: 1.2 } }],
},
},
default: 'auto',
},
{
displayName: 'Text',
name: 'text',
type: 'string',
required: true,
default: '',
placeholder: 'e.g. Hello, how can you help me?',
typeOptions: {
rows: 2,
},
displayOptions: {
show: {
promptType: ['define'],
},
},
},
{
displayName: 'Require Specific Output Format',
name: 'hasOutputParser',
type: 'boolean',
default: false,
displayOptions: {
hide: {
'@version': [{ _cnd: { lte: 1.2 } }],
agent: ['sqlAgent'],
},
},
},
{
displayName: `Connect an <a data-action='openSelectiveNodeCreator' data-action-parameter-connectiontype='${NodeConnectionType.AiOutputParser}'>output parser</a> on the canvas to specify the output format you require`,
name: 'notice',
type: 'notice',
default: '',
displayOptions: {
show: {
hasOutputParser: [true],
},
},
},
...conversationalAgentProperties,
...openAiFunctionsAgentProperties,

View file

@ -11,7 +11,11 @@ import type { BaseChatMemory } from 'langchain/memory';
import type { BaseOutputParser } from 'langchain/schema/output_parser';
import { PromptTemplate } from 'langchain/prompts';
import { CombiningOutputParser } from 'langchain/output_parsers';
import { isChatInstance } from '../../../../../utils/helpers';
import {
isChatInstance,
getPromptInputByType,
getOptionalOutputParsers,
} from '../../../../../utils/helpers';
export async function conversationalAgentExecute(
this: IExecuteFunctions,
@ -28,10 +32,7 @@ export async function conversationalAgentExecute(
| BaseChatMemory
| undefined;
const tools = (await this.getInputConnectionData(NodeConnectionType.AiTool, 0)) as Tool[];
const outputParsers = (await this.getInputConnectionData(
NodeConnectionType.AiOutputParser,
0,
)) as BaseOutputParser[];
const outputParsers = await getOptionalOutputParsers(this);
// TODO: Make it possible in the future to use values for other items than just 0
const options = this.getNodeParameter('options', 0, {}) as {
@ -80,7 +81,18 @@ export async function conversationalAgentExecute(
const items = this.getInputData();
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
let input = this.getNodeParameter('text', itemIndex) as string;
let input;
if (this.getNode().typeVersion <= 1.2) {
input = this.getNodeParameter('text', itemIndex) as string;
} else {
input = getPromptInputByType({
ctx: this,
i: itemIndex,
inputKey: 'text',
promptTypeKey: 'promptType',
});
}
if (input === undefined) {
throw new NodeOperationError(this.getNode(), 'The text parameter is empty.');

View file

@ -13,6 +13,7 @@ import { PromptTemplate } from 'langchain/prompts';
import { CombiningOutputParser } from 'langchain/output_parsers';
import { BufferMemory, type BaseChatMemory } from 'langchain/memory';
import { ChatOpenAI } from 'langchain/chat_models/openai';
import { getOptionalOutputParsers, getPromptInputByType } from '../../../../../utils/helpers';
export async function openAiFunctionsAgentExecute(
this: IExecuteFunctions,
@ -33,10 +34,7 @@ export async function openAiFunctionsAgentExecute(
| BaseChatMemory
| undefined;
const tools = (await this.getInputConnectionData(NodeConnectionType.AiTool, 0)) as Tool[];
const outputParsers = (await this.getInputConnectionData(
NodeConnectionType.AiOutputParser,
0,
)) as BaseOutputParser[];
const outputParsers = await getOptionalOutputParsers(this);
const options = this.getNodeParameter('options', 0, {}) as {
systemMessage?: string;
maxIterations?: number;
@ -82,7 +80,17 @@ export async function openAiFunctionsAgentExecute(
const items = this.getInputData();
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
let input = this.getNodeParameter('text', itemIndex) as string;
let input;
if (this.getNode().typeVersion <= 1.2) {
input = this.getNodeParameter('text', itemIndex) as string;
} else {
input = getPromptInputByType({
ctx: this,
i: itemIndex,
inputKey: 'text',
promptTypeKey: 'promptType',
});
}
if (input === undefined) {
throw new NodeOperationError(this.getNode(), 'The text parameter is empty.');

View file

@ -39,7 +39,7 @@ export const planAndExecuteAgentProperties: INodeProperties[] = [
'@version': [1.2],
},
},
default: '={{ $json.chatInput } }',
default: '={{ $json.chatInput }}',
},
{
displayName: 'Options',

View file

@ -11,6 +11,7 @@ import { PromptTemplate } from 'langchain/prompts';
import { CombiningOutputParser } from 'langchain/output_parsers';
import type { BaseChatModel } from 'langchain/chat_models/base';
import { PlanAndExecuteAgentExecutor } from 'langchain/experimental/plan_and_execute';
import { getOptionalOutputParsers, getPromptInputByType } from '../../../../../utils/helpers';
export async function planAndExecuteAgentExecute(
this: IExecuteFunctions,
@ -23,10 +24,7 @@ export async function planAndExecuteAgentExecute(
const tools = (await this.getInputConnectionData(NodeConnectionType.AiTool, 0)) as Tool[];
const outputParsers = (await this.getInputConnectionData(
NodeConnectionType.AiOutputParser,
0,
)) as BaseOutputParser[];
const outputParsers = await getOptionalOutputParsers(this);
const options = this.getNodeParameter('options', 0, {}) as {
humanMessageTemplate?: string;
@ -57,7 +55,17 @@ export async function planAndExecuteAgentExecute(
const items = this.getInputData();
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
let input = this.getNodeParameter('text', itemIndex) as string;
let input;
if (this.getNode().typeVersion <= 1.2) {
input = this.getNodeParameter('text', itemIndex) as string;
} else {
input = getPromptInputByType({
ctx: this,
i: itemIndex,
inputKey: 'text',
promptTypeKey: 'promptType',
});
}
if (input === undefined) {
throw new NodeOperationError(this.getNode(), 'The text parameter is empty.');

View file

@ -12,7 +12,11 @@ import type { BaseOutputParser } from 'langchain/schema/output_parser';
import { PromptTemplate } from 'langchain/prompts';
import { CombiningOutputParser } from 'langchain/output_parsers';
import type { BaseChatModel } from 'langchain/chat_models/base';
import { isChatInstance } from '../../../../../utils/helpers';
import {
getOptionalOutputParsers,
getPromptInputByType,
isChatInstance,
} from '../../../../../utils/helpers';
export async function reActAgentAgentExecute(
this: IExecuteFunctions,
@ -25,10 +29,7 @@ export async function reActAgentAgentExecute(
const tools = (await this.getInputConnectionData(NodeConnectionType.AiTool, 0)) as Tool[];
const outputParsers = (await this.getInputConnectionData(
NodeConnectionType.AiOutputParser,
0,
)) as BaseOutputParser[];
const outputParsers = await getOptionalOutputParsers(this);
const options = this.getNodeParameter('options', 0, {}) as {
prefix?: string;
@ -77,7 +78,18 @@ export async function reActAgentAgentExecute(
const items = this.getInputData();
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
let input = this.getNodeParameter('text', itemIndex) as string;
let input;
if (this.getNode().typeVersion <= 1.2) {
input = this.getNodeParameter('text', itemIndex) as string;
} else {
input = getPromptInputByType({
ctx: this,
i: itemIndex,
inputKey: 'text',
promptTypeKey: 'promptType',
});
}
if (input === undefined) {
throw new NodeOperationError(this.getNode(), 'The text parameter is empty.');

View file

@ -38,6 +38,7 @@ export const sqlAgentAgentProperties: INodeProperties[] = [
displayOptions: {
show: {
agent: ['sqlAgent'],
'@version': [{ _cnd: { lte: 1.2 } }],
},
},
default: '',

View file

@ -11,6 +11,7 @@ import { SqlToolkit, createSqlAgent } from 'langchain/agents/toolkits/sql';
import type { BaseLanguageModel } from 'langchain/dist/base_language';
import type { DataSource } from '@n8n/typeorm';
import { getPromptInputByType } from '../../../../../utils/helpers';
import { getSqliteDataSource } from './other/handlers/sqlite';
import { getPostgresDataSource } from './other/handlers/postgres';
import { SQL_PREFIX, SQL_SUFFIX } from './other/prompts';
@ -37,7 +38,17 @@ export async function sqlAgentAgentExecute(
for (let i = 0; i < items.length; i++) {
const item = items[i];
const input = this.getNodeParameter('input', i) as string;
let input;
if (this.getNode().typeVersion <= 1.2) {
input = this.getNodeParameter('input', i) as string;
} else {
input = getPromptInputByType({
ctx: this,
i,
inputKey: 'input',
promptTypeKey: 'promptType',
});
}
if (input === undefined) {
throw new NodeOperationError(this.getNode(), 'The prompt parameter is empty.');

View file

@ -22,7 +22,11 @@ import { LLMChain } from 'langchain/chains';
import type { BaseChatModel } from 'langchain/chat_models/base';
import { HumanMessage } from 'langchain/schema';
import { getTemplateNoticeField } from '../../../utils/sharedFields';
import { isChatInstance } from '../../../utils/helpers';
import {
getOptionalOutputParsers,
getPromptInputByType,
isChatInstance,
} from '../../../utils/helpers';
interface MessagesTemplate {
type: string;
@ -204,7 +208,7 @@ function getInputs(parameters: IDataObject) {
},
];
// If `hasOutputParser` is undefined it must be version 1.1 or earlier so we
// If `hasOutputParser` is undefined it must be version 1.3 or earlier so we
// always add the output parser input
if (hasOutputParser === undefined || hasOutputParser === true) {
inputs.push({ displayName: 'Output Parser', type: NodeConnectionType.AiOutputParser });
@ -218,7 +222,7 @@ export class ChainLlm implements INodeType {
name: 'chainLlm',
icon: 'fa:link',
group: ['transform'],
version: [1, 1.1, 1.2, 1.3],
version: [1, 1.1, 1.2, 1.3, 1.4],
description: 'A simple chain to prompt a large language model',
defaults: {
name: 'Basic LLM Chain',
@ -279,6 +283,59 @@ export class ChainLlm implements INodeType {
},
},
},
{
displayName: 'Prompt',
name: 'promptType',
type: 'options',
options: [
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Take from previous node automatically',
value: 'auto',
description: 'Looks for an input field called chatInput',
},
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Define below',
value: 'define',
description:
'Use an expression to reference data in previous nodes or enter static text',
},
],
displayOptions: {
hide: {
'@version': [1, 1.1, 1.2, 1.3],
},
},
default: 'auto',
},
{
displayName: 'Text',
name: 'text',
type: 'string',
required: true,
default: '',
placeholder: 'e.g. Hello, how can you help me?',
typeOptions: {
rows: 2,
},
displayOptions: {
show: {
promptType: ['define'],
},
},
},
{
displayName: 'Require Specific Output Format',
name: 'hasOutputParser',
type: 'boolean',
default: false,
displayOptions: {
hide: {
'@version': [1, 1.1, 1.3],
},
},
},
{
displayName: 'Chat Messages (if Using a Chat Model)',
name: 'messages',
@ -419,17 +476,6 @@ export class ChainLlm implements INodeType {
},
],
},
{
displayName: 'Require Specific Output Format',
name: 'hasOutputParser',
type: 'boolean',
default: false,
displayOptions: {
show: {
'@version': [1.2],
},
},
},
{
displayName: `Connect an <a data-action='openSelectiveNodeCreator' data-action-parameter-connectiontype='${NodeConnectionType.AiOutputParser}'>output parser</a> on the canvas to specify the output format you require`,
name: 'notice',
@ -454,17 +500,20 @@ export class ChainLlm implements INodeType {
0,
)) as BaseLanguageModel;
let outputParsers: BaseOutputParser[] = [];
if (this.getNodeParameter('hasOutputParser', 0, true) === true) {
outputParsers = (await this.getInputConnectionData(
NodeConnectionType.AiOutputParser,
0,
)) as BaseOutputParser[];
}
const outputParsers = await getOptionalOutputParsers(this);
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
const prompt = this.getNodeParameter('prompt', itemIndex) as string;
let prompt: string;
if (this.getNode().typeVersion <= 1.2) {
prompt = this.getNodeParameter('prompt', itemIndex) as string;
} else {
prompt = getPromptInputByType({
ctx: this,
i: itemIndex,
inputKey: 'text',
promptTypeKey: 'promptType',
});
}
const messages = this.getNodeParameter(
'messages.messageValues',
itemIndex,

View file

@ -11,6 +11,7 @@ import { RetrievalQAChain } from 'langchain/chains';
import type { BaseLanguageModel } from 'langchain/dist/base_language';
import type { BaseRetriever } from 'langchain/schema/retriever';
import { getTemplateNoticeField } from '../../../utils/sharedFields';
import { getPromptInputByType } from '../../../utils/helpers';
export class ChainRetrievalQa implements INodeType {
description: INodeTypeDescription = {
@ -18,7 +19,7 @@ export class ChainRetrievalQa implements INodeType {
name: 'chainRetrievalQa',
icon: 'fa:link',
group: ['transform'],
version: [1, 1.1, 1.2],
version: [1, 1.1, 1.2, 1.3],
description: 'Answer questions about retrieved documents',
defaults: {
name: 'Question and Answer Chain',
@ -94,6 +95,47 @@ export class ChainRetrievalQa implements INodeType {
},
},
},
{
displayName: 'Prompt',
name: 'promptType',
type: 'options',
options: [
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Take from previous node automatically',
value: 'auto',
description: 'Looks for an input field called chatInput',
},
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Define below',
value: 'define',
description:
'Use an expression to reference data in previous nodes or enter static text',
},
],
displayOptions: {
hide: {
'@version': [{ _cnd: { lte: 1.2 } }],
},
},
default: 'auto',
},
{
displayName: 'Text',
name: 'text',
type: 'string',
required: true,
default: '',
typeOptions: {
rows: 2,
},
displayOptions: {
show: {
promptType: ['define'],
},
},
},
],
};
@ -117,7 +159,18 @@ export class ChainRetrievalQa implements INodeType {
// Run for each item
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
const query = this.getNodeParameter('query', itemIndex) as string;
let query;
if (this.getNode().typeVersion <= 1.2) {
query = this.getNodeParameter('query', itemIndex) as string;
} else {
query = getPromptInputByType({
ctx: this,
i: itemIndex,
inputKey: 'text',
promptTypeKey: 'promptType',
});
}
if (query === undefined) {
throw new NodeOperationError(this.getNode(), 'The query parameter is empty.');

View file

@ -1,6 +1,7 @@
import type { IExecuteFunctions } from 'n8n-workflow';
import { NodeConnectionType, type IExecuteFunctions, NodeOperationError } from 'n8n-workflow';
import { BaseChatModel } from 'langchain/chat_models/base';
import { BaseChatModel as BaseChatModelCore } from '@langchain/core/language_models/chat_models';
import type { BaseOutputParser } from '@langchain/core/output_parsers';
export function getMetadataFiltersValues(
ctx: IExecuteFunctions,
@ -18,6 +19,48 @@ export function getMetadataFiltersValues(
}
// TODO: Remove this function once langchain package is updated to 0.1.x
/**
 * Type guard: returns true when the given model is a chat model instance,
 * from either the bundled langchain package or @langchain/core.
 * Parameter widened from `any` to `unknown` — backward-compatible for callers
 * and type-safe: `instanceof` narrows `unknown` without disabling checking.
 */
// eslint-disable-next-line @typescript-eslint/no-duplicate-type-constituents
export function isChatInstance(model: unknown): model is BaseChatModel | BaseChatModelCore {
	return model instanceof BaseChatModel || model instanceof BaseChatModelCore;
}
/**
 * Resolves the output parsers connected to the node, but only when the
 * `hasOutputParser` parameter is enabled. The parameter defaults to true so
 * that node versions which do not expose it still load their parsers.
 */
export async function getOptionalOutputParsers(
	ctx: IExecuteFunctions,
): Promise<Array<BaseOutputParser<unknown>>> {
	if (ctx.getNodeParameter('hasOutputParser', 0, true) !== true) {
		return [];
	}
	const connectedParsers = await ctx.getInputConnectionData(
		NodeConnectionType.AiOutputParser,
		0,
	);
	return connectedParsers as BaseOutputParser[];
}
/**
 * Reads the prompt text for the given item, honoring the prompt-type parameter:
 * - 'auto': evaluates the incoming item's `chatInput` field via expression
 * - anything else: reads the dedicated text parameter named by `inputKey`
 *
 * @throws NodeOperationError when no prompt value could be resolved
 */
export function getPromptInputByType(options: {
	ctx: IExecuteFunctions;
	i: number;
	promptTypeKey: string;
	inputKey: string;
}) {
	const { ctx, i, promptTypeKey, inputKey } = options;
	const promptType = ctx.getNodeParameter(promptTypeKey, i) as string;

	const input =
		promptType === 'auto'
			? (ctx.evaluateExpression('{{ $json["chatInput"] }}', i) as string)
			: (ctx.getNodeParameter(inputKey, i) as string);

	if (input === undefined) {
		throw new NodeOperationError(ctx.getNode(), 'No prompt specified', {
			description:
				"Expected to find the prompt in an input field called 'chatInput' (this is what the chat trigger node outputs). To use something else, change the 'Prompt' parameter",
		});
	}

	return input;
}