n8n/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/TextClassifier.node.ts

261 lines
7.2 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import type {
	IDataObject,
	IExecuteFunctions,
	INodeExecutionData,
	INodeParameters,
	INodeType,
	INodeTypeDescription,
} from 'n8n-workflow';
import { NodeConnectionType, NodeOperationError } from 'n8n-workflow';
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import { HumanMessage } from '@langchain/core/messages';
import { SystemMessagePromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts';
import { StructuredOutputParser } from 'langchain/output_parsers';
import { z } from 'zod';
import { getTracingConfig } from '../../../utils/tracing';
/**
 * Default system prompt used when the "System Prompt Template" option is not customised.
 * `{categories}` is a prompt-template placeholder that is filled with the comma-separated
 * category names before the prompt is sent to the model.
 */
const SYSTEM_PROMPT_TEMPLATE =
	'Please classify the text provided by the user into one of the following categories: {categories}, ' +
	"and use the provided formatting instructions below. Don't explain, and only output the json.";
/**
 * Derives the node's output connections from its parameters: one `main` output per configured
 * category (labelled with the category name), plus a trailing 'Other' output when the
 * "When No Clear Match" option is set to 'other'.
 *
 * NOTE: the source text of this function is interpolated into the `outputs` expression of the
 * node description, so it must stay self-contained and must not reference other module-level
 * helpers. Keep comments out of its body.
 */
const configuredOutputs = (parameters: INodeParameters) => {
	const categoryList = ((parameters.categories as IDataObject)?.categories as IDataObject[]) ?? [];
	const fallbackMode = (parameters.options as IDataObject)?.fallback as string;
	const outputs = categoryList.map((entry) => ({
		type: NodeConnectionType.Main,
		displayName: entry.category,
	}));
	if (fallbackMode === 'other') {
		outputs.push({ type: NodeConnectionType.Main, displayName: 'Other' });
	}
	return outputs;
};
/**
 * Root node that asks a connected language model to classify text into user-defined categories.
 * Each category has its own output; an item is sent to every output whose category the model
 * marks as true. Optionally, an extra 'Other' output receives items for which the model reports
 * that no category applies.
 */
export class TextClassifier implements INodeType {
	description: INodeTypeDescription = {
		displayName: 'Text Classifier',
		name: 'textClassifier',
		icon: 'fa:tags',
		iconColor: 'black',
		group: ['transform'],
		version: 1,
		description: 'Classify your text into distinct categories',
		codex: {
			categories: ['AI'],
			subcategories: {
				AI: ['Chains', 'Root Nodes'],
			},
			resources: {
				primaryDocumentation: [
					{
						url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/root-nodes/n8n-nodes-langchain.text-classifier/',
					},
				],
			},
		},
		defaults: {
			name: 'Text Classifier',
		},
		inputs: [
			{ displayName: '', type: NodeConnectionType.Main },
			{
				displayName: 'Model',
				maxConnections: 1,
				type: NodeConnectionType.AiLanguageModel,
				required: true,
			},
		],
		// Outputs depend on the configured categories: the source text of `configuredOutputs` is
		// embedded in this expression string and invoked with `$parameter`.
		outputs: `={{(${configuredOutputs})($parameter)}}`,
		properties: [
			{
				displayName: 'Text to Classify',
				name: 'inputText',
				type: 'string',
				required: true,
				default: '',
				description: 'Use an expression to reference data in previous nodes or enter static text',
				typeOptions: {
					rows: 2,
				},
			},
			{
				displayName: 'Categories',
				name: 'categories',
				placeholder: 'Add Category',
				type: 'fixedCollection',
				default: {},
				typeOptions: {
					multipleValues: true,
				},
				options: [
					{
						name: 'categories',
						displayName: 'Categories',
						values: [
							{
								displayName: 'Category',
								name: 'category',
								type: 'string',
								default: '',
								description: 'Category to add',
								required: true,
							},
							{
								displayName: 'Description',
								name: 'description',
								type: 'string',
								default: '',
								description: "Describe your category if it's not obvious",
							},
						],
					},
				],
			},
			{
				displayName: 'Options',
				name: 'options',
				type: 'collection',
				default: {},
				placeholder: 'Add Option',
				options: [
					{
						displayName: 'Allow Multiple Classes To Be True',
						name: 'multiClass',
						type: 'boolean',
						default: false,
					},
					{
						displayName: 'When No Clear Match',
						name: 'fallback',
						type: 'options',
						default: 'discard',
						description: 'What to do with items that dont match the categories exactly',
						options: [
							{
								name: 'Discard Item',
								value: 'discard',
								description: 'Ignore the item and drop it from the output',
							},
							{
								name: "Output on Extra, 'Other' Branch",
								value: 'other',
								description: "Create a separate output branch called 'Other'",
							},
						],
					},
					{
						displayName: 'System Prompt Template',
						name: 'systemPromptTemplate',
						type: 'string',
						default: SYSTEM_PROMPT_TEMPLATE,
						description: 'String to use directly as the system prompt template',
						typeOptions: {
							rows: 6,
						},
					},
				],
			},
		],
	};

	/**
	 * Classifies every input item with the connected language model.
	 *
	 * The model is asked to answer with one boolean per category (plus a `fallback` flag when the
	 * 'Other' branch is enabled). Its JSON answer is parsed with a zod-based structured output
	 * parser, and the item is pushed to the output of every category that came back `true`.
	 * The incoming item object is mutated to record `pairedItem`.
	 *
	 * @returns One item list per category, in configuration order, followed by the 'Other'
	 *   list when `fallback` is 'other'.
	 * @throws NodeOperationError when no categories are configured.
	 * @throws Any error from building the prompt, invoking the model, or parsing its answer,
	 *   unless "continue on fail" is active. In that case `{ error: message }` is emitted on the
	 *   first output instead.
	 */
	async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
		const items = this.getInputData();

		const llm = (await this.getInputConnectionData(
			NodeConnectionType.AiLanguageModel,
			0,
		)) as BaseLanguageModel;

		// Use an explicit fallback so an unset collection yields an empty list, which is then
		// rejected with a clear message. Without any category there would be no first output
		// to report continue-on-fail errors on.
		const categories = this.getNodeParameter('categories.categories', 0, []) as Array<{
			category: string;
			description: string;
		}>;
		if (categories.length === 0) {
			throw new NodeOperationError(this.getNode(), 'At least one category must be defined');
		}

		const options = this.getNodeParameter('options', 0, {}) as {
			multiClass: boolean;
			fallback?: string;
			systemPromptTemplate?: string;
		};
		const multiClass = options?.multiClass ?? false;
		const fallback = options?.fallback ?? 'discard';

		// One boolean field per category. The descriptions end up in the format instructions
		// that the parser generates for the model.
		const schemaEntries = categories.map((cat) => [
			cat.category,
			z
				.boolean()
				.describe(
					`Should be true if the input has category "${cat.category}" (description: ${cat.description})`,
				),
		]);
		if (fallback === 'other')
			schemaEntries.push([
				'fallback',
				z.boolean().describe('Should be true if none of the other categories apply'),
			]);
		const schema = z.object(Object.fromEntries(schemaEntries));
		const parser = StructuredOutputParser.fromZodSchema(schema);

		const multiClassPrompt = multiClass
			? 'Categories are not mutually exclusive, and multiple can be true'
			: 'Categories are mutually exclusive, and only one can be true';

		// Unknown fallback modes add no extra instruction instead of the literal text "undefined".
		const fallbackPrompts: Record<string, string> = {
			other: 'If no categories apply, select the "fallback" option.',
			discard: 'If there is not a very fitting category, select none of the categories.',
		};
		const fallbackPrompt = fallbackPrompts[fallback] ?? '';

		// Item-independent prompt inputs, computed once instead of once per item.
		const categoryNames = categories.map((cat) => cat.category).join(', ');
		const formatInstructions = parser.getFormatInstructions();

		// One output per category, plus the trailing 'Other' output when enabled.
		const returnData: INodeExecutionData[][] = Array.from(
			{ length: categories.length + (fallback === 'other' ? 1 : 0) },
			() => [],
		);

		for (let itemIdx = 0; itemIdx < items.length; itemIdx++) {
			const item = items[itemIdx];
			item.pairedItem = { item: itemIdx };

			try {
				const input = this.getNodeParameter('inputText', itemIdx) as string;
				const inputPrompt = new HumanMessage(input);

				// Pass the built-in template as fallback so it applies when the option is not set.
				const systemPromptTemplateOpt = this.getNodeParameter(
					'options.systemPromptTemplate',
					itemIdx,
					SYSTEM_PROMPT_TEMPLATE,
				) as string;
				// The user-supplied template is parsed here, so malformed placeholders fail inside
				// the try and are covered by continue-on-fail.
				const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate(
					`${systemPromptTemplateOpt ?? SYSTEM_PROMPT_TEMPLATE}
{format_instructions}
${multiClassPrompt}
${fallbackPrompt}`,
				);

				const messages = [
					await systemPromptTemplate.format({
						categories: categoryNames,
						format_instructions: formatInstructions,
					}),
					inputPrompt,
				];
				const prompt = ChatPromptTemplate.fromMessages(messages);
				const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this));

				const output = await chain.invoke(messages);
				categories.forEach((cat, idx) => {
					if (output[cat.category]) returnData[idx].push(item);
				});
				if (fallback === 'other' && output.fallback) returnData[returnData.length - 1].push(item);
			} catch (error) {
				if (this.continueOnFail(error)) {
					returnData[0].push({
						json: { error: error.message },
						pairedItem: { item: itemIdx },
					});
					continue;
				}
				throw error;
			}
		}
		return returnData;
	}
}