From 28ca7d6a2dd818c8795acda6ddf7329b8621d9de Mon Sep 17 00:00:00 2001 From: jeanpaul Date: Thu, 11 Jul 2024 16:24:03 +0200 Subject: [PATCH] feat(Text Classifier Node): Add Text Classifier Node (#9997) Co-authored-by: oleg --- .../TextClassifier/TextClassifier.node.ts | 223 ++++++++++++++++++ packages/@n8n/nodes-langchain/package.json | 1 + packages/editor-ui/src/plugins/icons/index.ts | 2 + 3 files changed, 226 insertions(+) create mode 100644 packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/TextClassifier.node.ts diff --git a/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/TextClassifier.node.ts b/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/TextClassifier.node.ts new file mode 100644 index 0000000000..9c1e9824d9 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/TextClassifier.node.ts @@ -0,0 +1,223 @@ +import type { + IDataObject, + IExecuteFunctions, + INodeExecutionData, + INodeParameters, + INodeType, + INodeTypeDescription, +} from 'n8n-workflow'; + +import { NodeConnectionType } from 'n8n-workflow'; + +import type { BaseLanguageModel } from '@langchain/core/language_models/base'; +import { HumanMessage } from '@langchain/core/messages'; +import { SystemMessagePromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts'; +import { StructuredOutputParser } from 'langchain/output_parsers'; +import { z } from 'zod'; +import { getTracingConfig } from '../../../utils/tracing'; + +const SYSTEM_PROMPT_TEMPLATE = + "Please classify the text provided by the user into one of the following categories: {categories}, and use the provided formatting instructions below. Don't explain, and only output the json."; + +const configuredOutputs = (parameters: INodeParameters) => { + const categories = ((parameters.categories as IDataObject)?.categories as IDataObject[]) ?? []; + const fallback = (parameters.options as IDataObject)?.fallback as boolean; + const ret = categories.map((cat) => { + return { type: NodeConnectionType.Main, displayName: cat.category }; + }); + if (fallback) ret.push({ type: NodeConnectionType.Main, displayName: 'Other' }); + return ret; +}; + +export class TextClassifier implements INodeType { + description: INodeTypeDescription = { + displayName: 'Text Classifier', + name: 'textClassifier', + icon: 'fa:tags', + group: ['transform'], + version: 1, + description: 'Classify your text into distinct categories', + codex: { + categories: ['AI'], + subcategories: { + AI: ['Chains', 'Root Nodes'], + }, + resources: { + primaryDocumentation: [ + { + url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/root-nodes/n8n-nodes-langchain.chainllm/', + }, + ], + }, + }, + defaults: { + name: 'Text Classifier', + }, + inputs: [ + { displayName: '', type: NodeConnectionType.Main }, + { + displayName: 'Model', + maxConnections: 1, + type: NodeConnectionType.AiLanguageModel, + required: true, + }, + ], + outputs: `={{(${configuredOutputs})($parameter)}}`, + properties: [ + { + displayName: 'Text to Classify', + name: 'inputText', + type: 'string', + required: true, + default: '', + description: 'Use an expression to reference data in previous nodes or enter static text', + typeOptions: { + rows: 2, + }, + }, + { + displayName: 'Categories', + name: 'categories', + placeholder: 'Add Category', + type: 'fixedCollection', + default: {}, + typeOptions: { + multipleValues: true, + }, + options: [ + { + name: 'categories', + displayName: 'Categories', + values: [ + { + displayName: 'Category', + name: 'category', + type: 'string', + default: '', + description: 'Category to add', + required: true, + }, + { + displayName: 'Description', + name: 'description', + type: 'string', + default: '', + description: "Describe your category if it's not obvious", + }, + ], + }, + ], + }, + { + displayName: 'Options', + name: 'options', + type: 'collection', + default: {}, + placeholder: 'Add Option', + options: [ + { + displayName: 'Allow Multiple Classes To Be True', + name: 'multiClass', + type: 'boolean', + default: false, + }, + { + displayName: 'Add Fallback Option', + name: 'fallback', + type: 'boolean', + default: false, + description: 'Whether to add a "fallback" option if no other categories match', + }, + { + displayName: 'System Prompt Template', + name: 'systemPromptTemplate', + type: 'string', + default: SYSTEM_PROMPT_TEMPLATE, + description: 'String to use directly as the system prompt template', + typeOptions: { + rows: 6, + }, + }, + ], + }, + ], + }; + + async execute(this: IExecuteFunctions): Promise { + const items = this.getInputData(); + + const llm = (await this.getInputConnectionData( + NodeConnectionType.AiLanguageModel, + 0, + )) as BaseLanguageModel; + + const categories = this.getNodeParameter('categories.categories', 0) as Array<{ + category: string; + description: string; + }>; + + const options = this.getNodeParameter('options', 0, {}) as { + multiClass: boolean; + fallback: boolean; + systemPromptTemplate?: string; + }; + const multiClass = options?.multiClass ?? false; + const fallback = options?.fallback ?? false; + + const schemaEntries = categories.map((cat) => [ + cat.category, + z + .boolean() + .describe( + `Should be true if the input has category "${cat.category}" (description: ${cat.description})`, + ), + ]); + if (fallback) + schemaEntries.push([ + 'fallback', + z.boolean().describe('Should be true if none of the other categories apply'), + ]); + const schema = z.object(Object.fromEntries(schemaEntries)); + + const parser = StructuredOutputParser.fromZodSchema(schema); + + const multiClassPrompt = multiClass + ? 'Categories are not mutually exclusive, and multiple can be true' + : 'Categories are mutually exclusive, and only one can be true'; + const fallbackPrompt = fallback + ? 'If no categories apply, select the "fallback" option.' + : 'One of the options must always be true.'; + + const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate( + `${options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE} +{format_instructions} +${multiClassPrompt} +${fallbackPrompt}`, + ); + + const returnData: INodeExecutionData[][] = Array.from( + { length: categories.length + (fallback ? 1 : 0) }, + (_) => [], + ); + for (let itemIdx = 0; itemIdx < items.length; itemIdx++) { + const input = this.getNodeParameter('inputText', itemIdx) as string; + const inputPrompt = new HumanMessage(input); + const messages = [ + await systemPromptTemplate.format({ + categories: categories.map((cat) => cat.category).join(', '), + format_instructions: parser.getFormatInstructions(), + }), + inputPrompt, + ]; + const prompt = ChatPromptTemplate.fromMessages(messages); + const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this)); + + const output = await chain.invoke(messages); + categories.forEach((cat, idx) => { + if (output[cat.category]) returnData[idx].push(items[itemIdx]); + }); + if (fallback && output.fallback) returnData[returnData.length - 1].push(items[itemIdx]); + } + return returnData; + } +} diff --git a/packages/@n8n/nodes-langchain/package.json b/packages/@n8n/nodes-langchain/package.json index 2ca142f2f1..74f3e92016 100644 --- a/packages/@n8n/nodes-langchain/package.json +++ b/packages/@n8n/nodes-langchain/package.json @@ -45,6 +45,7 @@ "dist/nodes/chains/ChainSummarization/ChainSummarization.node.js", "dist/nodes/chains/ChainLLM/ChainLlm.node.js", "dist/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.js", + "dist/nodes/chains/TextClassifier/TextClassifier.node.js", "dist/nodes/code/Code.node.js", "dist/nodes/document_loaders/DocumentDefaultDataLoader/DocumentDefaultDataLoader.node.js", "dist/nodes/document_loaders/DocumentBinaryInputLoader/DocumentBinaryInputLoader.node.js", diff --git a/packages/editor-ui/src/plugins/icons/index.ts b/packages/editor-ui/src/plugins/icons/index.ts index 4f86547847..6eeaef0a07 100644 --- a/packages/editor-ui/src/plugins/icons/index.ts +++ b/packages/editor-ui/src/plugins/icons/index.ts @@ -129,6 +129,7 @@ import { faSync, faSyncAlt, faTable, + faTags, faTasks, faTerminal, faThLarge, @@ -300,6 +301,7 @@ export const FontAwesomePlugin: Plugin = { addIcon(faSync); addIcon(faSyncAlt); addIcon(faTable); + addIcon(faTags); addIcon(faTasks); addIcon(faTerminal); addIcon(faThLarge);