Mirror of https://github.com/n8n-io/n8n.git, synced 2024-12-25 04:34:06 -08:00
feat(Question and Answer Chain Node): Customize question and answer system prompt (#10385)
parent 7073ec6fe5 · commit 08a27b3148
@@ -10,10 +10,21 @@ import {
 import { RetrievalQAChain } from 'langchain/chains';
 import type { BaseLanguageModel } from '@langchain/core/language_models/base';
 import type { BaseRetriever } from '@langchain/core/retrievers';
+import {
+	ChatPromptTemplate,
+	SystemMessagePromptTemplate,
+	HumanMessagePromptTemplate,
+	PromptTemplate,
+} from '@langchain/core/prompts';
 import { getTemplateNoticeField } from '../../../utils/sharedFields';
-import { getPromptInputByType } from '../../../utils/helpers';
+import { getPromptInputByType, isChatInstance } from '../../../utils/helpers';
 import { getTracingConfig } from '../../../utils/tracing';

+const SYSTEM_PROMPT_TEMPLATE = `Use the following pieces of context to answer the users question.
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+----------------
+{context}`;
+
 export class ChainRetrievalQa implements INodeType {
 	description: INodeTypeDescription = {
 		displayName: 'Question and Answer Chain',
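The imports and the SYSTEM_PROMPT_TEMPLATE constant added above carry the rest of the change. As a minimal standalone sketch (not part of the commit), this is how such a template resolves its placeholders for a completion-style model; the appended Question slot and the concrete values are hypothetical:

import { PromptTemplate } from '@langchain/core/prompts';

// The default template from the hunk above, with a {question} slot appended
// the way a completion model would need it (illustrative only).
const template = `Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----------------
{context}

Question: {question}`;

const prompt = new PromptTemplate({ template, inputVariables: ['context', 'question'] });

// format() substitutes the declared variables and resolves to a plain string.
const rendered = await prompt.format({
	context: 'n8n is a workflow automation platform.', // hypothetical value
	question: 'What is n8n?', // hypothetical value
});
console.log(rendered);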
@@ -137,6 +148,26 @@ export class ChainRetrievalQa implements INodeType {
 					},
 				},
 			},
+			{
+				displayName: 'Options',
+				name: 'options',
+				type: 'collection',
+				default: {},
+				placeholder: 'Add Option',
+				options: [
+					{
+						displayName: 'System Prompt Template',
+						name: 'systemPromptTemplate',
+						type: 'string',
+						default: SYSTEM_PROMPT_TEMPLATE,
+						description:
+							'Template string used for the system prompt. This should include the variable `{context}` for the provided context. For text completion models, you should also include the variable `{question}` for the user’s query.',
+						typeOptions: {
+							rows: 6,
+						},
+					},
+				],
+			},
 		],
 	};

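The option added above feeds the prompt-building logic in the execute hunk further down: for chat models the node pairs the user-supplied system template with a fixed '{question}' human message, so a custom template only needs {context}. A minimal sketch of what that pairing produces, with assumed example values:

import {
	ChatPromptTemplate,
	HumanMessagePromptTemplate,
	SystemMessagePromptTemplate,
} from '@langchain/core/prompts';

// A user-supplied template; {question} is omitted because the question
// travels as a separate human message for chat models.
const systemTemplate = 'Answer using only this context:\n{context}';

const chatPrompt = ChatPromptTemplate.fromMessages([
	SystemMessagePromptTemplate.fromTemplate(systemTemplate),
	HumanMessagePromptTemplate.fromTemplate('{question}'),
]);

// formatMessages() resolves to [SystemMessage, HumanMessage].
const messages = await chatPrompt.formatMessages({
	context: 'n8n is a workflow automation platform.', // hypothetical value
	question: 'What is n8n?', // hypothetical value
});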
@@ -154,7 +185,6 @@ export class ChainRetrievalQa implements INodeType {
 		)) as BaseRetriever;

 		const items = this.getInputData();
-		const chain = RetrievalQAChain.fromLLM(model, retriever);

 		const returnData: INodeExecutionData[] = [];

@@ -178,6 +208,35 @@ export class ChainRetrievalQa implements INodeType {
 				throw new NodeOperationError(this.getNode(), 'The ‘query‘ parameter is empty.');
 			}

+			const options = this.getNodeParameter('options', itemIndex, {}) as {
+				systemPromptTemplate?: string;
+			};
+
+			const chainParameters = {} as {
+				prompt?: PromptTemplate | ChatPromptTemplate;
+			};
+
+			if (options.systemPromptTemplate !== undefined) {
+				if (isChatInstance(model)) {
+					const messages = [
+						SystemMessagePromptTemplate.fromTemplate(options.systemPromptTemplate),
+						HumanMessagePromptTemplate.fromTemplate('{question}'),
+					];
+					const chatPromptTemplate = ChatPromptTemplate.fromMessages(messages);
+
+					chainParameters.prompt = chatPromptTemplate;
+				} else {
+					const completionPromptTemplate = new PromptTemplate({
+						template: options.systemPromptTemplate,
+						inputVariables: ['context', 'question'],
+					});
+
+					chainParameters.prompt = completionPromptTemplate;
+				}
+			}
+
+			const chain = RetrievalQAChain.fromLLM(model, retriever, chainParameters);
+
 			const response = await chain.withConfig(getTracingConfig(this)).invoke({ query });
 			returnData.push({ json: { response } });
 		} catch (error) {
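Taken together, the execute-time logic reduces to the following standalone LangChain pattern. This is a hedged sketch, not the node's code: the answer() helper and the way model and retriever are obtained are assumptions for illustration; passing a prompt through the options argument of RetrievalQAChain.fromLLM is the mechanism the diff above relies on.

import { RetrievalQAChain } from 'langchain/chains';
import { PromptTemplate } from '@langchain/core/prompts';
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import type { BaseRetriever } from '@langchain/core/retrievers';

// Hypothetical helper; in the node, model and retriever come from the
// connected sub-nodes and the template from the new 'options' parameter.
async function answer(
	model: BaseLanguageModel,
	retriever: BaseRetriever,
	systemPromptTemplate: string,
	query: string,
) {
	// Completion-model branch of the diff: the template must declare both
	// {context} and {question}.
	const prompt = new PromptTemplate({
		template: systemPromptTemplate,
		inputVariables: ['context', 'question'],
	});

	// The third argument lets the custom prompt override the chain's default.
	const chain = RetrievalQAChain.fromLLM(model, retriever, { prompt });
	return await chain.invoke({ query });
}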