diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.ts index 8249cbd90c..dfbdb3e9d1 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.ts @@ -10,10 +10,21 @@ import { import { RetrievalQAChain } from 'langchain/chains'; import type { BaseLanguageModel } from '@langchain/core/language_models/base'; import type { BaseRetriever } from '@langchain/core/retrievers'; +import { + ChatPromptTemplate, + SystemMessagePromptTemplate, + HumanMessagePromptTemplate, + PromptTemplate, +} from '@langchain/core/prompts'; import { getTemplateNoticeField } from '../../../utils/sharedFields'; -import { getPromptInputByType } from '../../../utils/helpers'; +import { getPromptInputByType, isChatInstance } from '../../../utils/helpers'; import { getTracingConfig } from '../../../utils/tracing'; +const SYSTEM_PROMPT_TEMPLATE = `Use the following pieces of context to answer the users question. +If you don't know the answer, just say that you don't know, don't try to make up an answer. +---------------- +{context}`; + export class ChainRetrievalQa implements INodeType { description: INodeTypeDescription = { displayName: 'Question and Answer Chain', @@ -137,6 +148,26 @@ export class ChainRetrievalQa implements INodeType { }, }, }, + { + displayName: 'Options', + name: 'options', + type: 'collection', + default: {}, + placeholder: 'Add Option', + options: [ + { + displayName: 'System Prompt Template', + name: 'systemPromptTemplate', + type: 'string', + default: SYSTEM_PROMPT_TEMPLATE, + description: + 'Template string used for the system prompt. This should include the variable `{context}` for the provided context. For text completion models, you should also include the variable `{question}` for the user’s query.', + typeOptions: { + rows: 6, + }, + }, + ], + }, ], }; @@ -154,7 +185,6 @@ export class ChainRetrievalQa implements INodeType { )) as BaseRetriever; const items = this.getInputData(); - const chain = RetrievalQAChain.fromLLM(model, retriever); const returnData: INodeExecutionData[] = []; @@ -178,6 +208,35 @@ export class ChainRetrievalQa implements INodeType { throw new NodeOperationError(this.getNode(), 'The ‘query‘ parameter is empty.'); } + const options = this.getNodeParameter('options', itemIndex, {}) as { + systemPromptTemplate?: string; + }; + + const chainParameters = {} as { + prompt?: PromptTemplate | ChatPromptTemplate; + }; + + if (options.systemPromptTemplate !== undefined) { + if (isChatInstance(model)) { + const messages = [ + SystemMessagePromptTemplate.fromTemplate(options.systemPromptTemplate), + HumanMessagePromptTemplate.fromTemplate('{question}'), + ]; + const chatPromptTemplate = ChatPromptTemplate.fromMessages(messages); + + chainParameters.prompt = chatPromptTemplate; + } else { + const completionPromptTemplate = new PromptTemplate({ + template: options.systemPromptTemplate, + inputVariables: ['context', 'question'], + }); + + chainParameters.prompt = completionPromptTemplate; + } + } + + const chain = RetrievalQAChain.fromLLM(model, retriever, chainParameters); + const response = await chain.withConfig(getTracingConfig(this)).invoke({ query }); returnData.push({ json: { response } }); } catch (error) {