n8n/packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/helpers.ts
कारतोफ्फेलस्क्रिप्ट™ 2ce1644d01
refactor(core): Shovel around more of AI code (no-changelog) (#12218)
2024-12-16 13:46:19 +01:00

72 lines
1.7 KiB
TypeScript

import { PromptTemplate } from '@langchain/core/prompts';
import type { SummarizationChainParams } from 'langchain/chains';
/**
 * Optional user-supplied prompt template strings for the summarization chain.
 * `getChainPromptsArgs` wraps only the fields relevant to the selected chain
 * type in a `PromptTemplate`; fields that are missing or empty strings are
 * ignored (no override is set).
 */
interface ChainTypeOptions {
	/** `map_reduce` only: becomes `combineMapPrompt` (declared input variables: `text`). */
	combineMapPrompt?: string;
	/**
	 * `stuff`: becomes `prompt`; `map_reduce`: becomes `combinePrompt`.
	 * Not used for `refine`. Declared input variables: `text`.
	 */
	prompt?: string;
	/** `refine` only: becomes `refinePrompt` (declared input variables: `existing_answer`, `text`). */
	refinePrompt?: string;
	/** `refine` only: becomes `questionPrompt` (declared input variables: `text`). */
	refineQuestionPrompt?: string;
}
/**
 * Builds `SummarizationChainParams` for the given chain type, overriding its
 * prompts with the user-supplied template strings from `options`.
 *
 * Only the options relevant to `type` are consulted. Missing or empty template
 * strings are skipped, so the corresponding key is not set on the result.
 *
 * @param type - Summarization strategy: `'stuff'`, `'map_reduce'` or `'refine'`.
 * @param options - Optional template overrides (see `ChainTypeOptions`).
 * @returns An object with `type` plus any prompt overrides that were provided.
 */
export function getChainPromptsArgs(
	type: 'stuff' | 'map_reduce' | 'refine',
	options: ChainTypeOptions,
): SummarizationChainParams {
	// Wrap a raw template string in a PromptTemplate declaring the given input variables.
	const toPrompt = (template: string, inputVariables: string[]) =>
		new PromptTemplate({ template, inputVariables });

	switch (type) {
		case 'map_reduce': {
			const mapReduce: SummarizationChainParams & { type: 'map_reduce' } = { type };
			const { combineMapPrompt, prompt } = options;
			if (combineMapPrompt) {
				mapReduce.combineMapPrompt = toPrompt(combineMapPrompt, ['text']);
			}
			// In map_reduce mode the generic `prompt` option overrides `combinePrompt`.
			if (prompt) {
				mapReduce.combinePrompt = toPrompt(prompt, ['text']);
			}
			return mapReduce;
		}
		case 'stuff': {
			const stuff: SummarizationChainParams & { type: 'stuff' } = { type };
			if (options.prompt) {
				stuff.prompt = toPrompt(options.prompt, ['text']);
			}
			return stuff;
		}
		case 'refine': {
			const refine: SummarizationChainParams & { type: 'refine' } = { type };
			const { refinePrompt, refineQuestionPrompt } = options;
			// The refine template additionally declares `existing_answer` as an input variable.
			if (refinePrompt) {
				refine.refinePrompt = toPrompt(refinePrompt, ['existing_answer', 'text']);
			}
			if (refineQuestionPrompt) {
				refine.questionPrompt = toPrompt(refineQuestionPrompt, ['text']);
			}
			return refine;
		}
		default:
			// Only reachable from untyped callers; return the bare `{ type }` with no overrides.
			return { type };
	}
}