feat: SQL agent improvements (#8709)

Signed-off-by: Oleg Ivaniv <me@olegivaniv.com>
Co-authored-by: Oleg Ivaniv <me@olegivaniv.com>
This commit is contained in:
Michael Kret 2024-02-26 15:35:00 +02:00 committed by GitHub
parent 7012577fce
commit 09524304e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 212 additions and 63 deletions

View file

@ -9,9 +9,9 @@ import type {
INodeTypeDescription,
} from 'n8n-workflow';
import { getTemplateNoticeField } from '../../../utils/sharedFields';
import { promptTypeOptions, textInput } from '../../../utils/descriptions';
import { conversationalAgentProperties } from './agents/ConversationalAgent/description';
import { conversationalAgentExecute } from './agents/ConversationalAgent/execute';
import { openAiFunctionsAgentProperties } from './agents/OpenAiFunctionsAgent/description';
import { openAiFunctionsAgentExecute } from './agents/OpenAiFunctionsAgent/execute';
import { planAndExecuteAgentProperties } from './agents/PlanAndExecuteAgent/description';
@ -20,6 +20,7 @@ import { reActAgentAgentProperties } from './agents/ReActAgent/description';
import { reActAgentAgentExecute } from './agents/ReActAgent/execute';
import { sqlAgentAgentProperties } from './agents/SqlAgent/description';
import { sqlAgentAgentExecute } from './agents/SqlAgent/execute';
// Function used in the inputs expression to figure out which inputs to
// display based on the agent type
function getInputs(
@ -128,6 +129,9 @@ function getInputs(
{
type: NodeConnectionType.AiLanguageModel,
},
{
type: NodeConnectionType.AiMemory,
},
];
} else if (agent === 'planAndExecuteAgent') {
specialInputs = [
@ -157,10 +161,10 @@ export class Agent implements INodeType {
name: 'agent',
icon: 'fa:robot',
group: ['transform'],
version: [1, 1.1, 1.2, 1.3],
version: [1, 1.1, 1.2, 1.3, 1.4],
description: 'Generates an action plan and executes it. Can use external tools.',
subtitle:
"={{ { conversationalAgent: 'Conversational Agent', openAiFunctionsAgent: 'OpenAI Functions Agent', reactAgent: 'ReAct Agent', sqlAgent: 'SQL Agent' }[$parameter.agent] }}",
"={{ { conversationalAgent: 'Conversational Agent', openAiFunctionsAgent: 'OpenAI Functions Agent', reActAgent: 'ReAct Agent', sqlAgent: 'SQL Agent', planAndExecuteAgent: 'Plan and Execute Agent' }[$parameter.agent] }}",
defaults: {
name: 'AI Agent',
color: '#404040',
@ -257,45 +261,23 @@ export class Agent implements INodeType {
default: 'conversationalAgent',
},
{
displayName: 'Prompt',
name: 'promptType',
type: 'options',
options: [
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Take from previous node automatically',
value: 'auto',
description: 'Looks for an input field called chatInput',
},
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Define below',
value: 'define',
description:
'Use an expression to reference data in previous nodes or enter static text',
},
],
...promptTypeOptions,
displayOptions: {
hide: {
'@version': [{ _cnd: { lte: 1.2 } }],
agent: ['sqlAgent'],
},
},
default: 'auto',
},
{
displayName: 'Text',
name: 'text',
type: 'string',
required: true,
default: '',
placeholder: 'e.g. Hello, how can you help me?',
typeOptions: {
rows: 2,
},
...textInput,
displayOptions: {
show: {
promptType: ['define'],
},
hide: {
agent: ['sqlAgent'],
},
},
},
{

View file

@ -1,35 +1,91 @@
import type { INodeProperties } from 'n8n-workflow';
import { promptTypeOptions, textInput } from '../../../../../utils/descriptions';
import { SQL_PREFIX, SQL_SUFFIX } from './other/prompts';
const dataSourceOptions: INodeProperties = {
displayName: 'Data Source',
name: 'dataSource',
type: 'options',
displayOptions: {
show: {
agent: ['sqlAgent'],
},
},
default: 'sqlite',
description: 'SQL database to connect to',
options: [
{
name: 'MySQL',
value: 'mysql',
description: 'Connect to a MySQL database',
},
{
name: 'Postgres',
value: 'postgres',
description: 'Connect to a Postgres database',
},
{
name: 'SQLite',
value: 'sqlite',
description: 'Use SQLite by connecting a database file as binary input',
},
],
};
export const sqlAgentAgentProperties: INodeProperties[] = [
{
displayName: 'Data Source',
name: 'dataSource',
type: 'options',
...dataSourceOptions,
displayOptions: {
show: {
agent: ['sqlAgent'],
'@version': [{ _cnd: { lt: 1.4 } }],
},
},
default: 'sqlite',
description: 'SQL database to connect to',
options: [
{
name: 'MySQL',
value: 'mysql',
description: 'Connect to a MySQL database',
},
{
...dataSourceOptions,
default: 'postgres',
displayOptions: {
show: {
agent: ['sqlAgent'],
'@version': [{ _cnd: { gte: 1.4 } }],
},
{
name: 'Postgres',
value: 'postgres',
description: 'Connect to a Postgres database',
},
},
{
displayName: 'Credentials',
name: 'credentials',
type: 'credentials',
default: '',
},
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
displayName:
"Pass the SQLite database into this node as binary data, e.g. by inserting a 'Read/Write Files from Disk' node beforehand",
name: 'sqLiteFileNotice',
type: 'notice',
default: '',
displayOptions: {
show: {
agent: ['sqlAgent'],
dataSource: ['sqlite'],
},
{
name: 'SQLite',
value: 'sqlite',
description: 'Use SQLite by connecting a database file as binary input',
},
},
{
displayName: 'Input Binary Field',
name: 'binaryPropertyName',
type: 'string',
default: 'data',
required: true,
placeholder: 'e.g data',
hint: 'The name of the input binary field containing the file to be extracted',
displayOptions: {
show: {
agent: ['sqlAgent'],
dataSource: ['sqlite'],
},
],
},
},
{
displayName: 'Prompt',
@ -47,6 +103,26 @@ export const sqlAgentAgentProperties: INodeProperties[] = [
rows: 5,
},
},
{
...promptTypeOptions,
displayOptions: {
hide: {
'@version': [{ _cnd: { lte: 1.2 } }],
},
show: {
agent: ['sqlAgent'],
},
},
},
{
...textInput,
displayOptions: {
show: {
promptType: ['define'],
agent: ['sqlAgent'],
},
},
},
{
displayName: 'Options',
name: 'options',

View file

@ -9,9 +9,10 @@ import { SqlDatabase } from 'langchain/sql_db';
import type { SqlCreatePromptArgs } from 'langchain/agents/toolkits/sql';
import { SqlToolkit, createSqlAgent } from 'langchain/agents/toolkits/sql';
import type { BaseLanguageModel } from 'langchain/dist/base_language';
import type { BaseChatMemory } from 'langchain/memory';
import type { DataSource } from '@n8n/typeorm';
import { getPromptInputByType } from '../../../../../utils/helpers';
import { getPromptInputByType, serializeChatHistory } from '../../../../../utils/helpers';
import { getSqliteDataSource } from './other/handlers/sqlite';
import { getPostgresDataSource } from './other/handlers/postgres';
import { SQL_PREFIX, SQL_SUFFIX } from './other/prompts';
@ -73,7 +74,8 @@ export async function sqlAgentAgentExecute(
);
}
dataSource = getSqliteDataSource.call(this, item.binary);
const binaryPropertyName = this.getNodeParameter('binaryPropertyName', i, 'data');
dataSource = await getSqliteDataSource.call(this, item.binary, binaryPropertyName);
}
if (selectedDataSource === 'postgres') {
@ -95,6 +97,7 @@ export async function sqlAgentAgentExecute(
topK: (options.topK as number) ?? 10,
prefix: (options.prefixPrompt as string) ?? SQL_PREFIX,
suffix: (options.suffixPrompt as string) ?? SQL_SUFFIX,
inputVariables: ['chatHistory', 'input', 'agent_scratchpad'],
};
const dbInstance = await SqlDatabase.fromDataSourceParams({
@ -107,7 +110,32 @@ export async function sqlAgentAgentExecute(
const toolkit = new SqlToolkit(dbInstance, model);
const agentExecutor = createSqlAgent(model, toolkit, agentOptions);
const response = await agentExecutor.call({ input, signal: this.getExecutionCancelSignal() });
const memory = (await this.getInputConnectionData(NodeConnectionType.AiMemory, 0)) as
| BaseChatMemory
| undefined;
agentExecutor.memory = memory;
let chatHistory = '';
if (memory) {
const messages = await memory.chatHistory.getMessages();
chatHistory = serializeChatHistory(messages);
}
let response;
try {
response = await agentExecutor.call({
input,
signal: this.getExecutionCancelSignal(),
chatHistory,
});
} catch (error) {
if (error.message?.output) {
response = error.message;
} else {
throw new NodeOperationError(this.getNode(), error.message, { itemIndex: i });
}
}
returnData.push({ json: response });
}

View file

@ -5,17 +5,28 @@ import * as temp from 'temp';
import * as sqlite3 from 'sqlite3';
import { DataSource } from '@n8n/typeorm';
export function getSqliteDataSource(
export async function getSqliteDataSource(
this: IExecuteFunctions,
binary: INodeExecutionData['binary'],
): DataSource {
const binaryData = binary?.data;
binaryPropertyName = 'data',
): Promise<DataSource> {
const binaryData = binary?.[binaryPropertyName];
if (!binaryData) {
throw new NodeOperationError(this.getNode(), 'No binary data received.');
}
const bufferString = Buffer.from(binaryData.data, BINARY_ENCODING);
let fileBase64;
if (binaryData.id) {
const chunkSize = 256 * 1024;
const stream = await this.helpers.getBinaryStream(binaryData.id, chunkSize);
const buffer = await this.helpers.binaryToBuffer(stream);
fileBase64 = buffer.toString('base64');
} else {
fileBase64 = binaryData.data;
}
const bufferString = Buffer.from(fileBase64, BINARY_ENCODING);
// Track and cleanup temp files at exit
temp.track();

View file

@ -12,6 +12,8 @@ DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the databa
If the question does not seem related to the database, just return "I don't know" as the answer.`;
export const SQL_SUFFIX = `Begin!
Chat History:
{chatHistory}
Question: {input}
Thought: I should look at the tables in the database to see what I can query.

View file

@ -85,8 +85,7 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
} catch (error) {
if (
error.message.includes('Bad request') &&
error.description &&
error.description.includes('Expected file to have JSONL format')
error.description?.includes('Expected file to have JSONL format')
) {
throw new NodeOperationError(this.getNode(), 'The file content is not in JSONL format', {
description:

View file

@ -1,14 +1,13 @@
import type { IDataObject, IExecuteFunctions } from 'n8n-workflow';
import get from 'lodash/get';
import * as assistant from '../actions/assistant';
import * as audio from '../actions/audio';
import * as file from '../actions/file';
import * as image from '../actions/image';
import * as text from '../actions/text';
import type { IDataObject, IExecuteFunctions } from 'n8n-workflow';
import * as transport from '../transport';
import get from 'lodash/get';
const createExecuteFunctionsMock = (parameters: IDataObject) => {
const nodeParameters = parameters;
return {

View file

@ -0,0 +1,34 @@
import type { INodeProperties } from 'n8n-workflow';
export const promptTypeOptions: INodeProperties = {
displayName: 'Prompt',
name: 'promptType',
type: 'options',
options: [
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Take from previous node automatically',
value: 'auto',
description: 'Looks for an input field called chatInput',
},
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Define below',
value: 'define',
description: 'Use an expression to reference data in previous nodes or enter static text',
},
],
default: 'auto',
};
export const textInput: INodeProperties = {
displayName: 'Text',
name: 'text',
type: 'string',
required: true,
default: '',
placeholder: 'e.g. Hello, how can you help me?',
typeOptions: {
rows: 2,
},
};

View file

@ -3,6 +3,7 @@ import type { EventNamesAiNodesType, IDataObject, IExecuteFunctions } from 'n8n-
import { BaseChatModel } from 'langchain/chat_models/base';
import { BaseChatModel as BaseChatModelCore } from '@langchain/core/language_models/chat_models';
import type { BaseOutputParser } from '@langchain/core/output_parsers';
import { BaseMessage } from 'langchain/schema';
export function getMetadataFiltersValues(
ctx: IExecuteFunctions,
@ -77,3 +78,17 @@ export async function logAiEvent(
executeFunctions.logger.debug(`Error logging AI event: ${event}`);
}
}
export function serializeChatHistory (chatHistory: Array<BaseMessage>): string {
return chatHistory
.map((chatMessage) => {
if (chatMessage._getType() === 'human') {
return `Human: ${chatMessage.content}`;
} else if (chatMessage._getType() === 'ai') {
return `Assistant: ${chatMessage.content}`;
} else {
return `${chatMessage.content}`;
}
})
.join('\n');
}

View file

@ -2237,6 +2237,7 @@ export interface INodeGraphItem {
method?: string; // HTTP Request node v2
src_node_id?: string;
src_instance_id?: string;
agent?: string; //@n8n/n8n-nodes-langchain.agent
}
export interface INodeNameIndex {

View file

@ -158,7 +158,9 @@ export function generateNodesGraph(
nodeItem.src_node_id = options.nodeIdMap[node.id];
}
if (node.type === 'n8n-nodes-base.httpRequest' && node.typeVersion === 1) {
if (node.type === '@n8n/n8n-nodes-langchain.agent') {
nodeItem.agent = (node.parameters.agent as string) || 'conversationalAgent';
} else if (node.type === 'n8n-nodes-base.httpRequest' && node.typeVersion === 1) {
try {
nodeItem.domain = new URL(node.parameters.url as string).hostname;
} catch {