feat(Question and Answer Chain Node): Customize question and answer system prompt (#10385)

This commit is contained in:
Miguel Prytoluk 2024-09-27 07:09:39 -03:00 committed by GitHub
parent 7073ec6fe5
commit 08a27b3148
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -10,10 +10,21 @@ import {
import { RetrievalQAChain } from 'langchain/chains';
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import type { BaseRetriever } from '@langchain/core/retrievers';
import {
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
} from '@langchain/core/prompts';
import { getTemplateNoticeField } from '../../../utils/sharedFields';
import { getPromptInputByType } from '../../../utils/helpers';
import { getPromptInputByType, isChatInstance } from '../../../utils/helpers';
import { getTracingConfig } from '../../../utils/tracing';
const SYSTEM_PROMPT_TEMPLATE = `Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----------------
{context}`;
export class ChainRetrievalQa implements INodeType {
description: INodeTypeDescription = {
displayName: 'Question and Answer Chain',
@ -137,6 +148,26 @@ export class ChainRetrievalQa implements INodeType {
},
},
},
{
displayName: 'Options',
name: 'options',
type: 'collection',
default: {},
placeholder: 'Add Option',
options: [
{
displayName: 'System Prompt Template',
name: 'systemPromptTemplate',
type: 'string',
default: SYSTEM_PROMPT_TEMPLATE,
description:
'Template string used for the system prompt. This should include the variable `{context}` for the provided context. For text completion models, you should also include the variable `{question}` for the users query.',
typeOptions: {
rows: 6,
},
},
],
},
],
};
@ -154,7 +185,6 @@ export class ChainRetrievalQa implements INodeType {
)) as BaseRetriever;
const items = this.getInputData();
const chain = RetrievalQAChain.fromLLM(model, retriever);
const returnData: INodeExecutionData[] = [];
@ -178,6 +208,35 @@ export class ChainRetrievalQa implements INodeType {
throw new NodeOperationError(this.getNode(), 'The query parameter is empty.');
}
const options = this.getNodeParameter('options', itemIndex, {}) as {
systemPromptTemplate?: string;
};
const chainParameters = {} as {
prompt?: PromptTemplate | ChatPromptTemplate;
};
if (options.systemPromptTemplate !== undefined) {
if (isChatInstance(model)) {
const messages = [
SystemMessagePromptTemplate.fromTemplate(options.systemPromptTemplate),
HumanMessagePromptTemplate.fromTemplate('{question}'),
];
const chatPromptTemplate = ChatPromptTemplate.fromMessages(messages);
chainParameters.prompt = chatPromptTemplate;
} else {
const completionPromptTemplate = new PromptTemplate({
template: options.systemPromptTemplate,
inputVariables: ['context', 'question'],
});
chainParameters.prompt = completionPromptTemplate;
}
}
const chain = RetrievalQAChain.fromLLM(model, retriever, chainParameters);
const response = await chain.withConfig(getTracingConfig(this)).invoke({ query });
returnData.push({ json: { response } });
} catch (error) {