feat(Information Extractor Node): Add new simplified AI-node for information extraction (#10149)

This commit is contained in:
Eugene 2024-07-25 14:47:18 +02:00 committed by GitHub
parent 49c7306feb
commit 3d235b0b2d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 566 additions and 0 deletions

View file

@ -0,0 +1,308 @@
import { jsonParse, NodeConnectionType, NodeOperationError } from 'n8n-workflow';
import type {
INodeType,
INodeTypeDescription,
IExecuteFunctions,
INodeExecutionData,
INodePropertyOptions,
} from 'n8n-workflow';
import type { JSONSchema7 } from 'json-schema';
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import { ChatPromptTemplate, SystemMessagePromptTemplate } from '@langchain/core/prompts';
import type { z } from 'zod';
import { OutputFixingParser, StructuredOutputParser } from 'langchain/output_parsers';
import { HumanMessage } from '@langchain/core/messages';
import { generateSchema, getSandboxWithZod } from '../../../utils/schemaParsing';
import {
inputSchemaField,
jsonSchemaExampleField,
schemaTypeField,
} from '../../../utils/descriptions';
import { getTracingConfig } from '../../../utils/tracing';
import type { AttributeDefinition } from './types';
import { makeZodSchemaFromAttributes } from './helpers';
/**
 * Default system prompt for the extraction chain. It is also used as the
 * default value of the "System Prompt Template" option of the node.
 */
const SYSTEM_PROMPT_TEMPLATE = [
	'You are an expert extraction algorithm.',
	'Only extract relevant information from the text.',
	"If you do not know the value of an attribute asked to extract, you may omit the attribute's value.",
].join('\n');
/**
 * Builds the self-correcting output parser for the schema configured on the node.
 *
 * Node parameters are read for item index 0 only, so one parser is shared by
 * every input item of the execution.
 *
 * @param context - Execution context the node parameters are read from.
 * @param llm - Language model handed to the `OutputFixingParser`.
 * @returns A parser wrapping a `StructuredOutputParser` for the resolved zod schema.
 * @throws NodeOperationError when `schemaType` is `fromAttributes` and no attribute is defined.
 */
async function buildOutputParser(
	context: IExecuteFunctions,
	llm: BaseLanguageModel,
): Promise<OutputFixingParser<object>> {
	const schemaType = context.getNodeParameter('schemaType', 0, '') as
		| 'fromAttributes'
		| 'fromJson'
		| 'manual';

	if (schemaType === 'fromAttributes') {
		const attributeList = context.getNodeParameter(
			'attributes.attributes',
			0,
			[],
		) as AttributeDefinition[];

		if (attributeList.length === 0) {
			throw new NodeOperationError(context.getNode(), 'At least one attribute must be specified');
		}

		const attributeSchema = makeZodSchemaFromAttributes(attributeList);
		return OutputFixingParser.fromLLM(llm, StructuredOutputParser.fromZodSchema(attributeSchema));
	}

	// Both remaining modes end up with a JSON schema: either generated from an
	// example document or parsed from the schema text the user entered.
	let jsonSchema: JSONSchema7;
	if (schemaType === 'fromJson') {
		const exampleJson = context.getNodeParameter('jsonSchemaExample', 0, '') as string;
		jsonSchema = generateSchema(exampleJson);
	} else {
		const schemaText = context.getNodeParameter('inputSchema', 0, '') as string;
		jsonSchema = jsonParse<JSONSchema7>(schemaText);
	}

	const sandbox = getSandboxWithZod(context, jsonSchema, 0);
	const zodSchema = (await sandbox.runCode()) as z.ZodSchema<object>;
	return OutputFixingParser.fromLLM(llm, StructuredOutputParser.fromZodSchema(zodSchema));
}

/**
 * Node that asks the connected language model to extract structured data from text.
 *
 * The output shape comes from a list of attribute descriptions, an example
 * JSON document, or a JSON Schema, depending on the `schemaType` parameter.
 */
export class InformationExtractor implements INodeType {
	description: INodeTypeDescription = {
		displayName: 'Information Extractor',
		name: 'informationExtractor',
		icon: 'fa:project-diagram',
		iconColor: 'black',
		group: ['transform'],
		version: 1,
		description: 'Extract information from text in a structured format',
		codex: {
			alias: ['NER', 'parse', 'parsing', 'JSON', 'data extraction', 'structured'],
			categories: ['AI'],
			subcategories: {
				AI: ['Chains', 'Root Nodes'],
			},
			resources: {
				primaryDocumentation: [
					{
						url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/root-nodes/n8n-nodes-langchain.information-extractor/',
					},
				],
			},
		},
		defaults: {
			name: 'Information Extractor',
		},
		inputs: [
			{ displayName: '', type: NodeConnectionType.Main },
			{
				displayName: 'Model',
				maxConnections: 1,
				type: NodeConnectionType.AiLanguageModel,
				required: true,
			},
		],
		outputs: [NodeConnectionType.Main],
		properties: [
			{
				displayName: 'Text',
				name: 'text',
				type: 'string',
				default: '',
				description: 'The text to extract information from',
				typeOptions: {
					rows: 2,
				},
			},
			{
				// Shared schema-type selector, extended with the attribute-based mode.
				...schemaTypeField,
				description: 'How to specify the schema for the desired output',
				options: [
					{
						name: 'From Attribute Descriptions',
						value: 'fromAttributes',
						description:
							'Extract specific attributes from the text based on types and descriptions',
					} as INodePropertyOptions,
					...(schemaTypeField.options as INodePropertyOptions[]),
				],
				default: 'fromAttributes',
			},
			{
				...jsonSchemaExampleField,
				default: `{
"state": "California",
"cities": ["Los Angeles", "San Francisco", "San Diego"]
}`,
			},
			{
				...inputSchemaField,
				default: `{
"type": "object",
"properties": {
"state": {
"type": "string"
},
"cities": {
"type": "array",
"items": {
"type": "string"
}
}
}
}`,
			},
			{
				displayName:
					'The schema has to be defined in the <a target="_blank" href="https://json-schema.org/">JSON Schema</a> format. Look at <a target="_blank" href="https://json-schema.org/learn/miscellaneous-examples.html">this</a> page for examples.',
				name: 'notice',
				type: 'notice',
				default: '',
				displayOptions: {
					show: {
						schemaType: ['manual'],
					},
				},
			},
			{
				displayName: 'Attributes',
				name: 'attributes',
				placeholder: 'Add Attribute',
				type: 'fixedCollection',
				default: {},
				displayOptions: {
					show: {
						schemaType: ['fromAttributes'],
					},
				},
				typeOptions: {
					multipleValues: true,
				},
				options: [
					{
						name: 'attributes',
						displayName: 'Attribute List',
						values: [
							{
								displayName: 'Name',
								name: 'name',
								type: 'string',
								default: '',
								description: 'Attribute to extract',
								placeholder: 'e.g. company_name',
								required: true,
							},
							{
								displayName: 'Type',
								name: 'type',
								type: 'options',
								description: 'Data type of the attribute',
								required: true,
								options: [
									{
										name: 'Boolean',
										value: 'boolean',
									},
									{
										name: 'Date',
										value: 'date',
									},
									{
										name: 'Number',
										value: 'number',
									},
									{
										name: 'String',
										value: 'string',
									},
								],
								default: 'string',
							},
							{
								displayName: 'Description',
								name: 'description',
								type: 'string',
								default: '',
								description: 'Describe your attribute',
								placeholder: 'Add description for the attribute',
								required: true,
							},
							{
								displayName: 'Required',
								name: 'required',
								type: 'boolean',
								default: false,
								description: 'Whether attribute is required',
								required: true,
							},
						],
					},
				],
			},
			{
				displayName: 'Options',
				name: 'options',
				type: 'collection',
				default: {},
				placeholder: 'Add Option',
				options: [
					{
						displayName: 'System Prompt Template',
						name: 'systemPromptTemplate',
						type: 'string',
						default: SYSTEM_PROMPT_TEMPLATE,
						description: 'String to use directly as the system prompt template',
						typeOptions: {
							rows: 6,
						},
					},
				],
			},
		],
	};

	/**
	 * Runs the extraction chain once per input item.
	 *
	 * @returns A single output branch with one `{ output }` item per input item,
	 * or an `{ error }` item when the chain fails and continue-on-fail is enabled.
	 * @throws Whatever the chain throws when continue-on-fail is disabled.
	 */
	async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
		const items = this.getInputData();

		const llm = (await this.getInputConnectionData(
			NodeConnectionType.AiLanguageModel,
			0,
		)) as BaseLanguageModel;

		const parser = await buildOutputParser(this, llm);

		const results: INodeExecutionData[] = [];

		for (let index = 0; index < items.length; index += 1) {
			const text = this.getNodeParameter('text', index) as string;
			const userMessage = new HumanMessage(text);

			const nodeOptions = this.getNodeParameter('options', index, {}) as {
				systemPromptTemplate?: string;
			};

			// The parser's format instructions are appended to the (possibly custom) system prompt.
			const systemTemplate = SystemMessagePromptTemplate.fromTemplate(
				`${nodeOptions.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE}
{format_instructions}`,
			);
			const systemMessage = await systemTemplate.format({
				format_instructions: parser.getFormatInstructions(),
			});

			const messages = [systemMessage, userMessage];
			const chain = ChatPromptTemplate.fromMessages(messages)
				.pipe(llm)
				.pipe(parser)
				.withConfig(getTracingConfig(this));

			try {
				const output = await chain.invoke(messages);
				results.push({ json: { output } });
			} catch (error) {
				if (!this.continueOnFail(error)) {
					throw error;
				}

				results.push({ json: { error: error.message }, pairedItem: { item: index } });
			}
		}

		return [results];
	}
}

View file

@ -0,0 +1,33 @@
import { z } from 'zod';
import type { AttributeDefinition } from './types';
/**
 * Creates the zod schema for a single attribute definition.
 *
 * @param attributeDefinition - Attribute whose `type` selects the validator and
 * whose `description` is attached to the resulting schema.
 * @param required - When `false`, the schema is wrapped as optional. Defaults to `true`.
 * @returns The described zod schema for the attribute.
 */
function makeAttributeSchema(attributeDefinition: AttributeDefinition, required: boolean = true) {
	const resolveBaseSchema = (): z.ZodTypeAny => {
		switch (attributeDefinition.type) {
			case 'string':
				return z.string();
			case 'number':
				return z.number();
			case 'boolean':
				return z.boolean();
			case 'date':
				// Dates are validated as date-formatted strings, not `Date` objects.
				return z.string().date();
			default:
				// Unrecognised type values fall back to accepting anything.
				return z.unknown();
		}
	};

	const baseSchema = resolveBaseSchema();
	const schema: z.ZodTypeAny = required ? baseSchema : baseSchema.optional();

	return schema.describe(attributeDefinition.description);
}
/**
 * Builds a zod object schema with one property per attribute definition.
 *
 * @param attributes - Attribute definitions; each `name` becomes a key of the object schema.
 * @returns A zod object schema keyed by attribute name.
 */
export function makeZodSchemaFromAttributes(attributes: AttributeDefinition[]) {
	const toSchemaEntry = (attribute: AttributeDefinition) => [
		attribute.name,
		makeAttributeSchema(attribute, attribute.required),
	];

	const shape = Object.fromEntries(attributes.map(toSchemaEntry));

	return z.object(shape);
}

View file

@ -0,0 +1,218 @@
import type { IDataObject, IExecuteFunctions } from 'n8n-workflow/src';
import get from 'lodash/get';
import { FakeLLM, FakeListChatModel } from '@langchain/core/utils/testing';
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import { InformationExtractor } from '../InformationExtractor.node';
import { makeZodSchemaFromAttributes } from '../helpers';
import type { AttributeDefinition } from '../types';
// Person attributes where every field is optional.
const mockPersonAttributes: AttributeDefinition[] = [
	{ name: 'name', type: 'string', description: 'The name of the person', required: false },
	{ name: 'age', type: 'number', description: 'The age of the person', required: false },
];

// Same attributes, but every field is required.
const mockPersonAttributesRequired: AttributeDefinition[] = mockPersonAttributes.map(
	(attribute) => ({ ...attribute, required: true }),
);
/**
 * Renders an object the way the fake LLMs are expected to answer: pretty-printed
 * JSON (2-space indent) wrapped in a Markdown ```json code fence.
 */
function formatFakeLlmResponse(object: Record<string, any>) {
	const fence = '```';
	const prettyJson = JSON.stringify(object, null, 2);

	return [`${fence}json`, prettyJson, fence].join('\n');
}
/**
 * Creates a minimal `IExecuteFunctions` stand-in for running the node in tests.
 *
 * @param parameters - Node parameter values, looked up by (possibly dotted) path.
 * @param fakeLlm - Model returned for the language-model input connection.
 * @returns An object exposing only the execution-context methods the node calls,
 * cast to `IExecuteFunctions`.
 */
const createExecuteFunctionsMock = (parameters: IDataObject, fakeLlm: BaseLanguageModel) => {
	return {
		// Forward the fallback the way the node passes it, so a parameter missing
		// from `parameters` resolves to the node's own default instead of `undefined`.
		getNodeParameter(parameter: string, _itemIndex?: number, fallbackValue?: unknown) {
			return get(parameters, parameter, fallbackValue);
		},
		getNode() {
			return {};
		},
		getInputConnectionData() {
			return fakeLlm;
		},
		// A single empty input item: the node runs exactly once per test.
		getInputData() {
			return [{ json: {} }];
		},
		getWorkflow() {
			return {
				name: 'Test Workflow',
			};
		},
		getExecutionId() {
			return 'test_execution_id';
		},
		// Failures always propagate so the tests can assert on thrown errors.
		continueOnFail() {
			return false;
		},
	} as unknown as IExecuteFunctions;
};
// Tests for the "From Attribute Descriptions" schema mode, using fake LLMs
// that return canned responses.
describe('InformationExtractor', () => {
	describe('From Attribute Descriptions', () => {
		it('should generate a schema from attribute descriptions with optional fields', async () => {
			const schema = makeZodSchemaFromAttributes(mockPersonAttributes);

			expect(schema.parse({ name: 'John', age: 30 })).toEqual({ name: 'John', age: 30 });
			expect(schema.parse({ name: 'John' })).toEqual({ name: 'John' });
			expect(schema.parse({ age: 30 })).toEqual({ age: 30 });
		});

		it('should make a request to LLM and return the extracted attributes', async () => {
			const node = new InformationExtractor();

			const response = await node.execute.call(
				createExecuteFunctionsMock(
					{
						text: 'John is 30 years old',
						attributes: {
							attributes: mockPersonAttributes,
						},
						options: {},
						schemaType: 'fromAttributes',
					},
					new FakeLLM({ response: formatFakeLlmResponse({ name: 'John', age: 30 }) }),
				),
			);

			expect(response).toEqual([[{ json: { output: { name: 'John', age: 30 } } }]]);
		});

		it('should not fail if LLM could not extract some attribute', async () => {
			const node = new InformationExtractor();

			const response = await node.execute.call(
				createExecuteFunctionsMock(
					{
						text: 'John is 30 years old',
						attributes: {
							attributes: mockPersonAttributes,
						},
						options: {},
						schemaType: 'fromAttributes',
					},
					new FakeLLM({ response: formatFakeLlmResponse({ name: 'John' }) }),
				),
			);

			expect(response).toEqual([[{ json: { output: { name: 'John' } } }]]);
		});

		it('should fail if LLM could not extract some required attribute', async () => {
			const node = new InformationExtractor();

			// Assert on the rejection itself: with try/catch the test would pass
			// silently if the node did not throw at all.
			await expect(
				node.execute.call(
					createExecuteFunctionsMock(
						{
							text: 'John is 30 years old',
							attributes: {
								attributes: mockPersonAttributesRequired,
							},
							options: {},
							schemaType: 'fromAttributes',
						},
						new FakeLLM({ response: formatFakeLlmResponse({ name: 'John' }) }),
					),
				),
			).rejects.toThrow('Failed to parse');
		});

		it('should fail if LLM extracted an attribute with the wrong type', async () => {
			const node = new InformationExtractor();

			// Assert on the rejection itself: with try/catch the test would pass
			// silently if the node did not throw at all.
			await expect(
				node.execute.call(
					createExecuteFunctionsMock(
						{
							text: 'John is 30 years old',
							attributes: {
								attributes: mockPersonAttributes,
							},
							options: {},
							schemaType: 'fromAttributes',
						},
						new FakeLLM({ response: formatFakeLlmResponse({ name: 'John', age: '30' }) }),
					),
				),
			).rejects.toThrow('Failed to parse');
		});

		it('retries if LLM fails to extract some required attribute', async () => {
			const node = new InformationExtractor();

			const response = await node.execute.call(
				createExecuteFunctionsMock(
					{
						text: 'John is 30 years old',
						attributes: {
							attributes: mockPersonAttributesRequired,
						},
						options: {},
						schemaType: 'fromAttributes',
					},
					// First answer is invalid, the second (retry) answer is valid.
					new FakeListChatModel({
						responses: [
							formatFakeLlmResponse({ name: 'John' }),
							formatFakeLlmResponse({ name: 'John', age: 30 }),
						],
					}),
				),
			);

			expect(response).toEqual([[{ json: { output: { name: 'John', age: 30 } } }]]);
		});

		it('retries if LLM extracted an attribute with a wrong type', async () => {
			const node = new InformationExtractor();

			const response = await node.execute.call(
				createExecuteFunctionsMock(
					{
						text: 'John is 30 years old',
						attributes: {
							attributes: mockPersonAttributesRequired,
						},
						options: {},
						schemaType: 'fromAttributes',
					},
					// First answer has `age` as a string, the second (retry) answer is valid.
					new FakeListChatModel({
						responses: [
							formatFakeLlmResponse({ name: 'John', age: '30' }),
							formatFakeLlmResponse({ name: 'John', age: 30 }),
						],
					}),
				),
			);

			expect(response).toEqual([[{ json: { output: { name: 'John', age: 30 } } }]]);
		});
	});
});

View file

@ -0,0 +1,6 @@
/**
 * One attribute the Information Extractor should pull out of the input text.
 * Mirrors the fields of the node's "Attribute List" entries.
 */
export interface AttributeDefinition {
/** Attribute name; used as the property key in the generated zod object schema. */
name: string;
/** Free-text description attached to the attribute's schema. */
description: string;
/** Data type of the attribute; selects which zod validator is used. */
type: 'string' | 'number' | 'boolean' | 'date';
/** Whether the attribute must be present; when `false` its schema is made optional. */
required: boolean;
}

View file

@ -45,6 +45,7 @@
"dist/nodes/chains/ChainSummarization/ChainSummarization.node.js", "dist/nodes/chains/ChainSummarization/ChainSummarization.node.js",
"dist/nodes/chains/ChainLLM/ChainLlm.node.js", "dist/nodes/chains/ChainLLM/ChainLlm.node.js",
"dist/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.js", "dist/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.js",
"dist/nodes/chains/InformationExtractor/InformationExtractor.node.js",
"dist/nodes/chains/TextClassifier/TextClassifier.node.js", "dist/nodes/chains/TextClassifier/TextClassifier.node.js",
"dist/nodes/code/Code.node.js", "dist/nodes/code/Code.node.js",
"dist/nodes/document_loaders/DocumentDefaultDataLoader/DocumentDefaultDataLoader.node.js", "dist/nodes/document_loaders/DocumentDefaultDataLoader/DocumentDefaultDataLoader.node.js",