mirror of
https://github.com/n8n-io/n8n.git
synced 2025-03-05 20:50:17 -08:00
feat(Text Classifier Node): Add Text Classifier Node (#9997)
Co-authored-by: oleg <me@olegivaniv.com>
This commit is contained in:
parent
4a3b97cede
commit
28ca7d6a2d
|
@ -0,0 +1,223 @@
|
||||||
|
import type {
|
||||||
|
IDataObject,
|
||||||
|
IExecuteFunctions,
|
||||||
|
INodeExecutionData,
|
||||||
|
INodeParameters,
|
||||||
|
INodeType,
|
||||||
|
INodeTypeDescription,
|
||||||
|
} from 'n8n-workflow';
|
||||||
|
|
||||||
|
import { NodeConnectionType } from 'n8n-workflow';
|
||||||
|
|
||||||
|
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
|
||||||
|
import { HumanMessage } from '@langchain/core/messages';
|
||||||
|
import { SystemMessagePromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts';
|
||||||
|
import { StructuredOutputParser } from 'langchain/output_parsers';
|
||||||
|
import { z } from 'zod';
|
||||||
|
import { getTracingConfig } from '../../../utils/tracing';
|
||||||
|
|
||||||
|
const SYSTEM_PROMPT_TEMPLATE =
|
||||||
|
"Please classify the text provided by the user into one of the following categories: {categories}, and use the provided formatting instructions below. Don't explain, and only output the json.";
|
||||||
|
|
||||||
|
const configuredOutputs = (parameters: INodeParameters) => {
|
||||||
|
const categories = ((parameters.categories as IDataObject)?.categories as IDataObject[]) ?? [];
|
||||||
|
const fallback = (parameters.options as IDataObject)?.fallback as boolean;
|
||||||
|
const ret = categories.map((cat) => {
|
||||||
|
return { type: NodeConnectionType.Main, displayName: cat.category };
|
||||||
|
});
|
||||||
|
if (fallback) ret.push({ type: NodeConnectionType.Main, displayName: 'Other' });
|
||||||
|
return ret;
|
||||||
|
};
|
||||||
|
|
||||||
|
export class TextClassifier implements INodeType {
|
||||||
|
description: INodeTypeDescription = {
|
||||||
|
displayName: 'Text Classifier',
|
||||||
|
name: 'textClassifier',
|
||||||
|
icon: 'fa:tags',
|
||||||
|
group: ['transform'],
|
||||||
|
version: 1,
|
||||||
|
description: 'Classify your text into distinct categories',
|
||||||
|
codex: {
|
||||||
|
categories: ['AI'],
|
||||||
|
subcategories: {
|
||||||
|
AI: ['Chains', 'Root Nodes'],
|
||||||
|
},
|
||||||
|
resources: {
|
||||||
|
primaryDocumentation: [
|
||||||
|
{
|
||||||
|
url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/root-nodes/n8n-nodes-langchain.chainllm/',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
defaults: {
|
||||||
|
name: 'Text Classifier',
|
||||||
|
},
|
||||||
|
inputs: [
|
||||||
|
{ displayName: '', type: NodeConnectionType.Main },
|
||||||
|
{
|
||||||
|
displayName: 'Model',
|
||||||
|
maxConnections: 1,
|
||||||
|
type: NodeConnectionType.AiLanguageModel,
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
outputs: `={{(${configuredOutputs})($parameter)}}`,
|
||||||
|
properties: [
|
||||||
|
{
|
||||||
|
displayName: 'Text to Classify',
|
||||||
|
name: 'inputText',
|
||||||
|
type: 'string',
|
||||||
|
required: true,
|
||||||
|
default: '',
|
||||||
|
description: 'Use an expression to reference data in previous nodes or enter static text',
|
||||||
|
typeOptions: {
|
||||||
|
rows: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
displayName: 'Categories',
|
||||||
|
name: 'categories',
|
||||||
|
placeholder: 'Add Category',
|
||||||
|
type: 'fixedCollection',
|
||||||
|
default: {},
|
||||||
|
typeOptions: {
|
||||||
|
multipleValues: true,
|
||||||
|
},
|
||||||
|
options: [
|
||||||
|
{
|
||||||
|
name: 'categories',
|
||||||
|
displayName: 'Categories',
|
||||||
|
values: [
|
||||||
|
{
|
||||||
|
displayName: 'Category',
|
||||||
|
name: 'category',
|
||||||
|
type: 'string',
|
||||||
|
default: '',
|
||||||
|
description: 'Category to add',
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
displayName: 'Description',
|
||||||
|
name: 'description',
|
||||||
|
type: 'string',
|
||||||
|
default: '',
|
||||||
|
description: "Describe your category if it's not obvious",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
displayName: 'Options',
|
||||||
|
name: 'options',
|
||||||
|
type: 'collection',
|
||||||
|
default: {},
|
||||||
|
placeholder: 'Add Option',
|
||||||
|
options: [
|
||||||
|
{
|
||||||
|
displayName: 'Allow Multiple Classes To Be True',
|
||||||
|
name: 'multiClass',
|
||||||
|
type: 'boolean',
|
||||||
|
default: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
displayName: 'Add Fallback Option',
|
||||||
|
name: 'fallback',
|
||||||
|
type: 'boolean',
|
||||||
|
default: false,
|
||||||
|
description: 'Whether to add a "fallback" option if no other categories match',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
displayName: 'System Prompt Template',
|
||||||
|
name: 'systemPromptTemplate',
|
||||||
|
type: 'string',
|
||||||
|
default: SYSTEM_PROMPT_TEMPLATE,
|
||||||
|
description: 'String to use directly as the system prompt template',
|
||||||
|
typeOptions: {
|
||||||
|
rows: 6,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
|
||||||
|
const items = this.getInputData();
|
||||||
|
|
||||||
|
const llm = (await this.getInputConnectionData(
|
||||||
|
NodeConnectionType.AiLanguageModel,
|
||||||
|
0,
|
||||||
|
)) as BaseLanguageModel;
|
||||||
|
|
||||||
|
const categories = this.getNodeParameter('categories.categories', 0) as Array<{
|
||||||
|
category: string;
|
||||||
|
description: string;
|
||||||
|
}>;
|
||||||
|
|
||||||
|
const options = this.getNodeParameter('options', 0, {}) as {
|
||||||
|
multiClass: boolean;
|
||||||
|
fallback: boolean;
|
||||||
|
systemPromptTemplate?: string;
|
||||||
|
};
|
||||||
|
const multiClass = options?.multiClass ?? false;
|
||||||
|
const fallback = options?.fallback ?? false;
|
||||||
|
|
||||||
|
const schemaEntries = categories.map((cat) => [
|
||||||
|
cat.category,
|
||||||
|
z
|
||||||
|
.boolean()
|
||||||
|
.describe(
|
||||||
|
`Should be true if the input has category "${cat.category}" (description: ${cat.description})`,
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
if (fallback)
|
||||||
|
schemaEntries.push([
|
||||||
|
'fallback',
|
||||||
|
z.boolean().describe('Should be true if none of the other categories apply'),
|
||||||
|
]);
|
||||||
|
const schema = z.object(Object.fromEntries(schemaEntries));
|
||||||
|
|
||||||
|
const parser = StructuredOutputParser.fromZodSchema(schema);
|
||||||
|
|
||||||
|
const multiClassPrompt = multiClass
|
||||||
|
? 'Categories are not mutually exclusive, and multiple can be true'
|
||||||
|
: 'Categories are mutually exclusive, and only one can be true';
|
||||||
|
const fallbackPrompt = fallback
|
||||||
|
? 'If no categories apply, select the "fallback" option.'
|
||||||
|
: 'One of the options must always be true.';
|
||||||
|
|
||||||
|
const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate(
|
||||||
|
`${options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE}
|
||||||
|
{format_instructions}
|
||||||
|
${multiClassPrompt}
|
||||||
|
${fallbackPrompt}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
const returnData: INodeExecutionData[][] = Array.from(
|
||||||
|
{ length: categories.length + (fallback ? 1 : 0) },
|
||||||
|
(_) => [],
|
||||||
|
);
|
||||||
|
for (let itemIdx = 0; itemIdx < items.length; itemIdx++) {
|
||||||
|
const input = this.getNodeParameter('inputText', itemIdx) as string;
|
||||||
|
const inputPrompt = new HumanMessage(input);
|
||||||
|
const messages = [
|
||||||
|
await systemPromptTemplate.format({
|
||||||
|
categories: categories.map((cat) => cat.category).join(', '),
|
||||||
|
format_instructions: parser.getFormatInstructions(),
|
||||||
|
}),
|
||||||
|
inputPrompt,
|
||||||
|
];
|
||||||
|
const prompt = ChatPromptTemplate.fromMessages(messages);
|
||||||
|
const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this));
|
||||||
|
|
||||||
|
const output = await chain.invoke(messages);
|
||||||
|
categories.forEach((cat, idx) => {
|
||||||
|
if (output[cat.category]) returnData[idx].push(items[itemIdx]);
|
||||||
|
});
|
||||||
|
if (fallback && output.fallback) returnData[returnData.length - 1].push(items[itemIdx]);
|
||||||
|
}
|
||||||
|
return returnData;
|
||||||
|
}
|
||||||
|
}
|
|
@ -45,6 +45,7 @@
|
||||||
"dist/nodes/chains/ChainSummarization/ChainSummarization.node.js",
|
"dist/nodes/chains/ChainSummarization/ChainSummarization.node.js",
|
||||||
"dist/nodes/chains/ChainLLM/ChainLlm.node.js",
|
"dist/nodes/chains/ChainLLM/ChainLlm.node.js",
|
||||||
"dist/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.js",
|
"dist/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.js",
|
||||||
|
"dist/nodes/chains/TextClassifier/TextClassifier.node.js",
|
||||||
"dist/nodes/code/Code.node.js",
|
"dist/nodes/code/Code.node.js",
|
||||||
"dist/nodes/document_loaders/DocumentDefaultDataLoader/DocumentDefaultDataLoader.node.js",
|
"dist/nodes/document_loaders/DocumentDefaultDataLoader/DocumentDefaultDataLoader.node.js",
|
||||||
"dist/nodes/document_loaders/DocumentBinaryInputLoader/DocumentBinaryInputLoader.node.js",
|
"dist/nodes/document_loaders/DocumentBinaryInputLoader/DocumentBinaryInputLoader.node.js",
|
||||||
|
|
|
@ -129,6 +129,7 @@ import {
|
||||||
faSync,
|
faSync,
|
||||||
faSyncAlt,
|
faSyncAlt,
|
||||||
faTable,
|
faTable,
|
||||||
|
faTags,
|
||||||
faTasks,
|
faTasks,
|
||||||
faTerminal,
|
faTerminal,
|
||||||
faThLarge,
|
faThLarge,
|
||||||
|
@ -300,6 +301,7 @@ export const FontAwesomePlugin: Plugin = {
|
||||||
addIcon(faSync);
|
addIcon(faSync);
|
||||||
addIcon(faSyncAlt);
|
addIcon(faSyncAlt);
|
||||||
addIcon(faTable);
|
addIcon(faTable);
|
||||||
|
addIcon(faTags);
|
||||||
addIcon(faTasks);
|
addIcon(faTasks);
|
||||||
addIcon(faTerminal);
|
addIcon(faTerminal);
|
||||||
addIcon(faThLarge);
|
addIcon(faThLarge);
|
||||||
|
|
Loading…
Reference in a new issue