2023-12-08 04:42:32 -08:00
|
|
|
import type { SummarizationChainParams } from 'langchain/chains';
|
2024-03-07 02:36:36 -08:00
|
|
|
import { PromptTemplate } from '@langchain/core/prompts';
|
2023-12-08 04:42:32 -08:00
|
|
|
/**
 * User-supplied prompt overrides for the summarization chain.
 *
 * Every field is a raw template string. When a field is missing or an empty
 * string, `getChainPromptsArgs` leaves the matching LangChain prompt key unset.
 */
interface ChainTypeOptions {
	/**
	 * Shared prompt. It becomes `prompt` for the `stuff` chain and `combinePrompt`
	 * for the `map_reduce` chain. It is built with input variable `text`.
	 */
	prompt?: string;

	/** `map_reduce` only: becomes `combineMapPrompt`, with input variable `text`. */
	combineMapPrompt?: string;

	/** `refine` only: becomes `questionPrompt`, with input variable `text`. */
	refineQuestionPrompt?: string;

	/** `refine` only: becomes `refinePrompt`, with input variables `existing_answer` and `text`. */
	refinePrompt?: string;
}
|
|
|
|
|
|
|
|
/**
 * Wraps a user-supplied template string in a LangChain `PromptTemplate`.
 *
 * @param template - Raw template text.
 * @param inputVariables - Variables the template receives; defaults to `['text']`.
 *   A fresh array is created on every call.
 */
function createPromptTemplate(template: string, inputVariables: string[] = ['text']) {
	return new PromptTemplate({ template, inputVariables });
}

/**
 * Converts the user's prompt overrides into `SummarizationChainParams` for the given chain type.
 *
 * Options are mapped as follows:
 * - `stuff`:      `prompt` -> `prompt`
 * - `map_reduce`: `combineMapPrompt` -> `combineMapPrompt`, `prompt` -> `combinePrompt`
 * - `refine`:     `refinePrompt` -> `refinePrompt` (vars `existing_answer`, `text`),
 *                 `refineQuestionPrompt` -> `questionPrompt`
 *
 * Options that are absent or empty strings produce no key in the result. Options that do
 * not apply to the selected chain type are ignored.
 *
 * @param type - Summarization chain strategy.
 * @param options - Optional prompt template strings.
 * @returns A params object whose `type` is always set, plus the mapped prompt templates.
 */
export function getChainPromptsArgs(
	type: 'stuff' | 'map_reduce' | 'refine',
	options: ChainTypeOptions,
): SummarizationChainParams {
	switch (type) {
		case 'map_reduce': {
			// Declared rather than asserted with `as`, so the compiler checks the
			// fields assigned below against the `map_reduce` variant of the union.
			const mapReduceArgs: SummarizationChainParams & { type: 'map_reduce' } = { type };
			if (options.combineMapPrompt) {
				mapReduceArgs.combineMapPrompt = createPromptTemplate(options.combineMapPrompt);
			}
			// The shared `prompt` option feeds the combine step for this chain type.
			if (options.prompt) {
				mapReduceArgs.combinePrompt = createPromptTemplate(options.prompt);
			}
			return mapReduceArgs;
		}

		case 'stuff': {
			const stuffArgs: SummarizationChainParams & { type: 'stuff' } = { type };
			if (options.prompt) {
				stuffArgs.prompt = createPromptTemplate(options.prompt);
			}
			return stuffArgs;
		}

		case 'refine': {
			const refineArgs: SummarizationChainParams & { type: 'refine' } = { type };
			if (options.refinePrompt) {
				// The refine template receives both the current answer and the text.
				refineArgs.refinePrompt = createPromptTemplate(options.refinePrompt, [
					'existing_answer',
					'text',
				]);
			}
			if (options.refineQuestionPrompt) {
				refineArgs.questionPrompt = createPromptTemplate(options.refineQuestionPrompt);
			}
			return refineArgs;
		}

		default:
			// Unreachable for type-checked callers. An untyped caller gets `{ type }`
			// back unchanged, matching the previous pass-through behaviour.
			return { type };
	}
}
|