2024-07-11 07:24:03 -07:00
import type {
	IDataObject,
	IExecuteFunctions,
	INodeExecutionData,
	INodeParameters,
	INodeType,
	INodeTypeDescription,
} from 'n8n-workflow';
import { NodeConnectionType, NodeOperationError } from 'n8n-workflow';
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import { HumanMessage } from '@langchain/core/messages';
import { SystemMessagePromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts';
import { OutputFixingParser, StructuredOutputParser } from 'langchain/output_parsers';
import { z } from 'zod';
import { getTracingConfig } from '../../../utils/tracing';
// Default system prompt for the classifier. It is the default value of the
// 'System Prompt Template' option, and `{categories}` is a template variable
// that is supplied when the system message is formatted in `execute`.
const SYSTEM_PROMPT_TEMPLATE =
"Please classify the text provided by the user into one of the following categories: {categories}, and use the provided formatting instructions below. Don't explain, and only output the json." ;
/**
 * Derives the node's output connections from its current parameters.
 *
 * Produces one main output per configured category (labelled with the category
 * name), followed by an extra 'Other' output when the `fallback` option is set
 * to 'other'.
 *
 * NOTE(review): the source text of this function is interpolated into the
 * node's `outputs` expression (`${configuredOutputs}`), so everything it
 * references must be resolvable where that expression is evaluated — confirm
 * before adding new dependencies here.
 *
 * @param parameters - The node's parameter values.
 * @returns The list of output descriptors, in display order.
 */
const configuredOutputs = (parameters: INodeParameters) => {
	// Either level of the fixed collection may be missing while the node is being configured
	const categories = ((parameters.categories as IDataObject)?.categories as IDataObject[]) ?? [];
	const fallback = (parameters.options as IDataObject)?.fallback as string;

	const ret = categories.map((cat) => {
		return { type: NodeConnectionType.Main, displayName: cat.category };
	});

	// The catch-all branch is always the last output
	if (fallback === 'other') ret.push({ type: NodeConnectionType.Main, displayName: 'Other' });

	return ret;
};
/**
 * n8n node that asks a connected language model to sort each input item's text
 * into user-defined categories and routes the item to the matching output(s).
 */
export class TextClassifier implements INodeType {
	description: INodeTypeDescription = {
		displayName: 'Text Classifier',
		name: 'textClassifier',
		icon: 'fa:tags',
		iconColor: 'black',
		group: ['transform'],
		version: 1,
		description: 'Classify your text into distinct categories',
		codex: {
			categories: ['AI'],
			subcategories: {
				AI: ['Chains', 'Root Nodes'],
			},
			resources: {
				primaryDocumentation: [
					{
						url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/root-nodes/n8n-nodes-langchain.text-classifier/',
					},
				],
			},
		},
		defaults: {
			name: 'Text Classifier',
		},
		inputs: [
			{ displayName: '', type: NodeConnectionType.Main },
			{
				displayName: 'Model',
				maxConnections: 1,
				type: NodeConnectionType.AiLanguageModel,
				required: true,
			},
		],
		// The outputs are dynamic: the source of `configuredOutputs` is embedded in the
		// expression and invoked with `$parameter`. The string must start with `=`.
		outputs: `={{(${configuredOutputs})($parameter)}}`,
		properties: [
			{
				displayName: 'Text to Classify',
				name: 'inputText',
				type: 'string',
				required: true,
				default: '',
				description: 'Use an expression to reference data in previous nodes or enter static text',
				typeOptions: {
					rows: 2,
				},
			},
			{
				displayName: 'Categories',
				name: 'categories',
				placeholder: 'Add Category',
				type: 'fixedCollection',
				default: {},
				typeOptions: {
					multipleValues: true,
				},
				options: [
					{
						name: 'categories',
						displayName: 'Categories',
						values: [
							{
								displayName: 'Category',
								name: 'category',
								type: 'string',
								default: '',
								description: 'Category to add',
								required: true,
							},
							{
								displayName: 'Description',
								name: 'description',
								type: 'string',
								default: '',
								description: "Describe your category if it's not obvious",
							},
						],
					},
				],
			},
			{
				displayName: 'Options',
				name: 'options',
				type: 'collection',
				default: {},
				placeholder: 'Add Option',
				options: [
					{
						displayName: 'Allow Multiple Classes To Be True',
						name: 'multiClass',
						type: 'boolean',
						default: false,
					},
					{
						displayName: 'When No Clear Match',
						name: 'fallback',
						type: 'options',
						default: 'discard',
						description: 'What to do with items that don’t match the categories exactly',
						options: [
							{
								name: 'Discard Item',
								value: 'discard',
								description: 'Ignore the item and drop it from the output',
							},
							{
								name: "Output on Extra, 'Other' Branch",
								value: 'other',
								description: "Create a separate output branch called 'Other'",
							},
						],
					},
					{
						displayName: 'System Prompt Template',
						name: 'systemPromptTemplate',
						type: 'string',
						default: SYSTEM_PROMPT_TEMPLATE,
						description: 'String to use directly as the system prompt template',
						typeOptions: {
							rows: 6,
						},
					},
					{
						displayName: 'Enable Auto-Fixing',
						name: 'enableAutoFixing',
						type: 'boolean',
						default: true,
						description:
							'Whether to enable auto-fixing (may trigger an additional LLM call if output is broken)',
					},
				],
			},
		],
	};

	/**
	 * Classifies every input item and distributes the items across the outputs.
	 *
	 * Output `i` receives the items the model marked true for category `i`. When
	 * the `fallback` option is 'other', one additional, last output receives the
	 * items the model flagged as matching none of the categories.
	 *
	 * @returns One array of items per output, in the order described above.
	 * @throws NodeOperationError when no categories are configured.
	 * @throws Any error raised while classifying an item, unless `continueOnFail()`
	 * is true, in which case an `{ error }` item is pushed to the first output.
	 */
	async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
		const items = this.getInputData();

		const llm = (await this.getInputConnectionData(
			NodeConnectionType.AiLanguageModel,
			0,
		)) as BaseLanguageModel;

		// Default to an empty list so a missing collection is reported by the explicit check below
		const categories = this.getNodeParameter('categories.categories', 0, []) as Array<{
			category: string;
			description: string;
		}>;

		// Without categories there is no output to route to (and `returnData[0]` would not exist)
		if (categories.length === 0) {
			throw new NodeOperationError(this.getNode(), 'At least one category must be defined');
		}

		const options = this.getNodeParameter('options', 0, {}) as {
			multiClass: boolean;
			fallback?: 'discard' | 'other';
			systemPromptTemplate?: string;
			enableAutoFixing: boolean;
		};
		const multiClass = options?.multiClass ?? false;
		const fallback = options?.fallback ?? 'discard';

		// One boolean field per category; the model sets the ones that apply
		const schemaEntries = categories.map((cat) => [
			cat.category,
			z
				.boolean()
				.describe(
					`Should be true if the input has category "${cat.category}" (description: ${cat.description})`,
				),
		]);
		if (fallback === 'other')
			schemaEntries.push([
				'fallback',
				z.boolean().describe('Should be true if none of the other categories apply'),
			]);
		const schema = z.object(Object.fromEntries(schemaEntries));

		const structuredParser = StructuredOutputParser.fromZodSchema(schema);
		// Optionally wrap the parser so the model is asked to repair output that fails to parse
		const parser = options.enableAutoFixing
			? OutputFixingParser.fromLLM(llm, structuredParser)
			: structuredParser;

		const multiClassPrompt = multiClass
			? 'Categories are not mutually exclusive, and multiple can be true'
			: 'Categories are mutually exclusive, and only one can be true';
		const fallbackPrompt = {
			other: 'If no categories apply, select the "fallback" option.',
			discard: 'If there is not a very fitting category, select none of the categories.',
		}[fallback];

		// These do not depend on the item, so compute them once instead of per iteration
		const categoryNames = categories.map((cat) => cat.category).join(', ');
		const formatInstructions = parser.getFormatInstructions();

		// One bucket per category, plus a trailing bucket for the 'other' branch
		const returnData: INodeExecutionData[][] = Array.from(
			{ length: categories.length + (fallback === 'other' ? 1 : 0) },
			() => [],
		);

		for (let itemIdx = 0; itemIdx < items.length; itemIdx++) {
			const item = items[itemIdx];
			item.pairedItem = { item: itemIdx };

			const input = this.getNodeParameter('inputText', itemIdx) as string;
			const inputPrompt = new HumanMessage(input);

			// Pass the default as the fallback value so an unset option resolves to it
			const systemPromptTemplateOpt = this.getNodeParameter(
				'options.systemPromptTemplate',
				itemIdx,
				SYSTEM_PROMPT_TEMPLATE,
			) as string;
			// `{format_instructions}` (and `{categories}` in the default template) stay as
			// template variables; the two prompt fragments are interpolated right here.
			const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate(
				`${systemPromptTemplateOpt}
{format_instructions}
${multiClassPrompt}
${fallbackPrompt}`,
			);

			const messages = [
				await systemPromptTemplate.format({
					categories: categoryNames,
					format_instructions: formatInstructions,
				}),
				inputPrompt,
			];
			const prompt = ChatPromptTemplate.fromMessages(messages);
			const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this));

			try {
				const output = await chain.invoke(messages);

				// An item may land in several buckets when more than one flag is true
				categories.forEach((cat, idx) => {
					if (output[cat.category]) returnData[idx].push(item);
				});
				if (fallback === 'other' && output.fallback) returnData[returnData.length - 1].push(item);
			} catch (error) {
				if (this.continueOnFail()) {
					// Report the failure on the first output and keep processing the remaining items
					returnData[0].push({
						json: { error: error.message },
						pairedItem: { item: itemIdx },
					});

					continue;
				}

				throw error;
			}
		}

		return returnData;
	}
}