mirror of
https://github.com/n8n-io/n8n.git
synced 2025-01-05 18:07:27 -08:00
278 lines
7.8 KiB
TypeScript
278 lines
7.8 KiB
TypeScript
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
|
||
import { HumanMessage } from '@langchain/core/messages';
|
||
import { SystemMessagePromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts';
|
||
import { OutputFixingParser, StructuredOutputParser } from 'langchain/output_parsers';
|
||
import { NodeOperationError, NodeConnectionType } from 'n8n-workflow';
|
||
import type {
|
||
IDataObject,
|
||
IExecuteFunctions,
|
||
INodeExecutionData,
|
||
INodeParameters,
|
||
INodeType,
|
||
INodeTypeDescription,
|
||
} from 'n8n-workflow';
|
||
import { z } from 'zod';
|
||
|
||
import { getTracingConfig } from '../../../utils/tracing';
|
||
|
||
const SYSTEM_PROMPT_TEMPLATE =
|
||
"Please classify the text provided by the user into one of the following categories: {categories}, and use the provided formatting instructions below. Don't explain, and only output the json.";
|
||
|
||
const configuredOutputs = (parameters: INodeParameters) => {
|
||
const categories = ((parameters.categories as IDataObject)?.categories as IDataObject[]) ?? [];
|
||
const fallback = (parameters.options as IDataObject)?.fallback as string;
|
||
const ret = categories.map((cat) => {
|
||
return { type: NodeConnectionType.Main, displayName: cat.category };
|
||
});
|
||
if (fallback === 'other') ret.push({ type: NodeConnectionType.Main, displayName: 'Other' });
|
||
return ret;
|
||
};
|
||
|
||
export class TextClassifier implements INodeType {
|
||
description: INodeTypeDescription = {
|
||
displayName: 'Text Classifier',
|
||
name: 'textClassifier',
|
||
icon: 'fa:tags',
|
||
iconColor: 'black',
|
||
group: ['transform'],
|
||
version: 1,
|
||
description: 'Classify your text into distinct categories',
|
||
codex: {
|
||
categories: ['AI'],
|
||
subcategories: {
|
||
AI: ['Chains', 'Root Nodes'],
|
||
},
|
||
resources: {
|
||
primaryDocumentation: [
|
||
{
|
||
url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/root-nodes/n8n-nodes-langchain.text-classifier/',
|
||
},
|
||
],
|
||
},
|
||
},
|
||
defaults: {
|
||
name: 'Text Classifier',
|
||
},
|
||
inputs: [
|
||
{ displayName: '', type: NodeConnectionType.Main },
|
||
{
|
||
displayName: 'Model',
|
||
maxConnections: 1,
|
||
type: NodeConnectionType.AiLanguageModel,
|
||
required: true,
|
||
},
|
||
],
|
||
outputs: `={{(${configuredOutputs})($parameter)}}`,
|
||
properties: [
|
||
{
|
||
displayName: 'Text to Classify',
|
||
name: 'inputText',
|
||
type: 'string',
|
||
required: true,
|
||
default: '',
|
||
description: 'Use an expression to reference data in previous nodes or enter static text',
|
||
typeOptions: {
|
||
rows: 2,
|
||
},
|
||
},
|
||
{
|
||
displayName: 'Categories',
|
||
name: 'categories',
|
||
placeholder: 'Add Category',
|
||
type: 'fixedCollection',
|
||
default: {},
|
||
typeOptions: {
|
||
multipleValues: true,
|
||
},
|
||
options: [
|
||
{
|
||
name: 'categories',
|
||
displayName: 'Categories',
|
||
values: [
|
||
{
|
||
displayName: 'Category',
|
||
name: 'category',
|
||
type: 'string',
|
||
default: '',
|
||
description: 'Category to add',
|
||
required: true,
|
||
},
|
||
{
|
||
displayName: 'Description',
|
||
name: 'description',
|
||
type: 'string',
|
||
default: '',
|
||
description: "Describe your category if it's not obvious",
|
||
},
|
||
],
|
||
},
|
||
],
|
||
},
|
||
{
|
||
displayName: 'Options',
|
||
name: 'options',
|
||
type: 'collection',
|
||
default: {},
|
||
placeholder: 'Add Option',
|
||
options: [
|
||
{
|
||
displayName: 'Allow Multiple Classes To Be True',
|
||
name: 'multiClass',
|
||
type: 'boolean',
|
||
default: false,
|
||
},
|
||
{
|
||
displayName: 'When No Clear Match',
|
||
name: 'fallback',
|
||
type: 'options',
|
||
default: 'discard',
|
||
description: 'What to do with items that don’t match the categories exactly',
|
||
options: [
|
||
{
|
||
name: 'Discard Item',
|
||
value: 'discard',
|
||
description: 'Ignore the item and drop it from the output',
|
||
},
|
||
{
|
||
name: "Output on Extra, 'Other' Branch",
|
||
value: 'other',
|
||
description: "Create a separate output branch called 'Other'",
|
||
},
|
||
],
|
||
},
|
||
{
|
||
displayName: 'System Prompt Template',
|
||
name: 'systemPromptTemplate',
|
||
type: 'string',
|
||
default: SYSTEM_PROMPT_TEMPLATE,
|
||
description: 'String to use directly as the system prompt template',
|
||
typeOptions: {
|
||
rows: 6,
|
||
},
|
||
},
|
||
{
|
||
displayName: 'Enable Auto-Fixing',
|
||
name: 'enableAutoFixing',
|
||
type: 'boolean',
|
||
default: true,
|
||
description:
|
||
'Whether to enable auto-fixing (may trigger an additional LLM call if output is broken)',
|
||
},
|
||
],
|
||
},
|
||
],
|
||
};
|
||
|
||
async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
|
||
const items = this.getInputData();
|
||
|
||
const llm = (await this.getInputConnectionData(
|
||
NodeConnectionType.AiLanguageModel,
|
||
0,
|
||
)) as BaseLanguageModel;
|
||
|
||
const categories = this.getNodeParameter('categories.categories', 0, []) as Array<{
|
||
category: string;
|
||
description: string;
|
||
}>;
|
||
|
||
if (categories.length === 0) {
|
||
throw new NodeOperationError(this.getNode(), 'At least one category must be defined');
|
||
}
|
||
|
||
const options = this.getNodeParameter('options', 0, {}) as {
|
||
multiClass: boolean;
|
||
fallback?: string;
|
||
systemPromptTemplate?: string;
|
||
enableAutoFixing: boolean;
|
||
};
|
||
const multiClass = options?.multiClass ?? false;
|
||
const fallback = options?.fallback ?? 'discard';
|
||
|
||
const schemaEntries = categories.map((cat) => [
|
||
cat.category,
|
||
z
|
||
.boolean()
|
||
.describe(
|
||
`Should be true if the input has category "${cat.category}" (description: ${cat.description})`,
|
||
),
|
||
]);
|
||
if (fallback === 'other')
|
||
schemaEntries.push([
|
||
'fallback',
|
||
z.boolean().describe('Should be true if none of the other categories apply'),
|
||
]);
|
||
const schema = z.object(Object.fromEntries(schemaEntries));
|
||
|
||
const structuredParser = StructuredOutputParser.fromZodSchema(schema);
|
||
|
||
const parser = options.enableAutoFixing
|
||
? OutputFixingParser.fromLLM(llm, structuredParser)
|
||
: structuredParser;
|
||
|
||
const multiClassPrompt = multiClass
|
||
? 'Categories are not mutually exclusive, and multiple can be true'
|
||
: 'Categories are mutually exclusive, and only one can be true';
|
||
|
||
const fallbackPrompt = {
|
||
other: 'If no categories apply, select the "fallback" option.',
|
||
discard: 'If there is not a very fitting category, select none of the categories.',
|
||
}[fallback];
|
||
|
||
const returnData: INodeExecutionData[][] = Array.from(
|
||
{ length: categories.length + (fallback === 'other' ? 1 : 0) },
|
||
(_) => [],
|
||
);
|
||
for (let itemIdx = 0; itemIdx < items.length; itemIdx++) {
|
||
const item = items[itemIdx];
|
||
item.pairedItem = { item: itemIdx };
|
||
const input = this.getNodeParameter('inputText', itemIdx) as string;
|
||
const inputPrompt = new HumanMessage(input);
|
||
|
||
const systemPromptTemplateOpt = this.getNodeParameter(
|
||
'options.systemPromptTemplate',
|
||
itemIdx,
|
||
SYSTEM_PROMPT_TEMPLATE,
|
||
) as string;
|
||
const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate(
|
||
`${systemPromptTemplateOpt ?? SYSTEM_PROMPT_TEMPLATE}
|
||
{format_instructions}
|
||
${multiClassPrompt}
|
||
${fallbackPrompt}`,
|
||
);
|
||
|
||
const messages = [
|
||
await systemPromptTemplate.format({
|
||
categories: categories.map((cat) => cat.category).join(', '),
|
||
format_instructions: parser.getFormatInstructions(),
|
||
}),
|
||
inputPrompt,
|
||
];
|
||
const prompt = ChatPromptTemplate.fromMessages(messages);
|
||
const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this));
|
||
|
||
try {
|
||
const output = await chain.invoke(messages);
|
||
|
||
categories.forEach((cat, idx) => {
|
||
if (output[cat.category]) returnData[idx].push(item);
|
||
});
|
||
if (fallback === 'other' && output.fallback) returnData[returnData.length - 1].push(item);
|
||
} catch (error) {
|
||
if (this.continueOnFail()) {
|
||
returnData[0].push({
|
||
json: { error: error.message },
|
||
pairedItem: { item: itemIdx },
|
||
});
|
||
|
||
continue;
|
||
}
|
||
|
||
throw error;
|
||
}
|
||
}
|
||
|
||
return returnData;
|
||
}
|
||
}
|