fix(Text Classifier Node): Use proper documentation URL and respect continueOnFail (#10216)

This commit is contained in:
jeanpaul 2024-07-30 16:19:47 +02:00 committed by GitHub
parent 3ccb9df2f9
commit 452f52c124
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -46,7 +46,7 @@ export class TextClassifier implements INodeType {
resources: { resources: {
primaryDocumentation: [ primaryDocumentation: [
{ {
url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/root-nodes/n8n-nodes-langchain.chainllm/', url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/root-nodes/n8n-nodes-langchain.text-classifier/',
}, },
], ],
}, },
@ -203,20 +203,27 @@ export class TextClassifier implements INodeType {
discard: 'If there is not a very fitting category, select none of the categories.', discard: 'If there is not a very fitting category, select none of the categories.',
}[fallback]; }[fallback];
const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate(
`${options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE}
{format_instructions}
${multiClassPrompt}
${fallbackPrompt}`,
);
const returnData: INodeExecutionData[][] = Array.from( const returnData: INodeExecutionData[][] = Array.from(
{ length: categories.length + (fallback === 'other' ? 1 : 0) }, { length: categories.length + (fallback === 'other' ? 1 : 0) },
(_) => [], (_) => [],
); );
for (let itemIdx = 0; itemIdx < items.length; itemIdx++) { for (let itemIdx = 0; itemIdx < items.length; itemIdx++) {
const item = items[itemIdx];
item.pairedItem = { item: itemIdx };
const input = this.getNodeParameter('inputText', itemIdx) as string; const input = this.getNodeParameter('inputText', itemIdx) as string;
const inputPrompt = new HumanMessage(input); const inputPrompt = new HumanMessage(input);
const systemPromptTemplateOpt = this.getNodeParameter(
'options.systemPromptTemplate',
itemIdx,
) as string;
const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate(
`${systemPromptTemplateOpt ?? SYSTEM_PROMPT_TEMPLATE}
{format_instructions}
${multiClassPrompt}
${fallbackPrompt}`,
);
const messages = [ const messages = [
await systemPromptTemplate.format({ await systemPromptTemplate.format({
categories: categories.map((cat) => cat.category).join(', '), categories: categories.map((cat) => cat.category).join(', '),
@ -227,13 +234,27 @@ ${fallbackPrompt}`,
const prompt = ChatPromptTemplate.fromMessages(messages); const prompt = ChatPromptTemplate.fromMessages(messages);
const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this)); const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this));
const output = await chain.invoke(messages); try {
categories.forEach((cat, idx) => { const output = await chain.invoke(messages);
if (output[cat.category]) returnData[idx].push(items[itemIdx]);
}); categories.forEach((cat, idx) => {
if (fallback === 'other' && output.fallback) if (output[cat.category]) returnData[idx].push(item);
returnData[returnData.length - 1].push(items[itemIdx]); });
if (fallback === 'other' && output.fallback) returnData[returnData.length - 1].push(item);
} catch (error) {
if (this.continueOnFail(error)) {
returnData[0].push({
json: { error: error.message },
pairedItem: { item: itemIdx },
});
continue;
}
throw error;
}
} }
return returnData; return returnData;
} }
} }