From e3edeaa03526f041d15d1099ea91869e38a0decc Mon Sep 17 00:00:00 2001 From: jeanpaul Date: Tue, 6 Aug 2024 10:42:18 +0200 Subject: [PATCH] feat(Postgres Chat Memory, Redis Chat Memory, Xata): Add support for context window length (#10203) --- .../MemoryBufferWindow.node.ts | 10 ++-------- .../MemoryPostgresChat.node.ts | 19 +++++++++++++++---- .../MemoryRedisChat/MemoryRedisChat.node.ts | 19 +++++++++++++++---- .../memory/MemoryXata/MemoryXata.node.ts | 19 +++++++++++++++---- .../nodes/memory/descriptions.ts | 8 ++++++++ 5 files changed, 55 insertions(+), 20 deletions(-) diff --git a/packages/@n8n/nodes-langchain/nodes/memory/MemoryBufferWindow/MemoryBufferWindow.node.ts b/packages/@n8n/nodes-langchain/nodes/memory/MemoryBufferWindow/MemoryBufferWindow.node.ts index 2b7e205de6..b8eea7a5e2 100644 --- a/packages/@n8n/nodes-langchain/nodes/memory/MemoryBufferWindow/MemoryBufferWindow.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/memory/MemoryBufferWindow/MemoryBufferWindow.node.ts @@ -10,7 +10,7 @@ import type { BufferWindowMemoryInput } from 'langchain/memory'; import { BufferWindowMemory } from 'langchain/memory'; import { logWrapper } from '../../../utils/logWrapper'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; -import { sessionIdOption, sessionKeyProperty } from '../descriptions'; +import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions'; import { getSessionId } from '../../../utils/helpers'; class MemoryChatBufferSingleton { @@ -130,13 +130,7 @@ export class MemoryBufferWindow implements INodeType { }, }, sessionKeyProperty, - { - displayName: 'Context Window Length', - name: 'contextWindowLength', - type: 'number', - default: 5, - description: 'The number of previous messages to consider for context', - }, + contextWindowLengthProperty, ], }; diff --git a/packages/@n8n/nodes-langchain/nodes/memory/MemoryPostgresChat/MemoryPostgresChat.node.ts b/packages/@n8n/nodes-langchain/nodes/memory/MemoryPostgresChat/MemoryPostgresChat.node.ts index ea3ed3c33e..b1a9cd7aea 100644 --- a/packages/@n8n/nodes-langchain/nodes/memory/MemoryPostgresChat/MemoryPostgresChat.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/memory/MemoryPostgresChat/MemoryPostgresChat.node.ts @@ -1,7 +1,7 @@ /* eslint-disable n8n-nodes-base/node-dirname-against-convention */ import type { IExecuteFunctions, INodeType, INodeTypeDescription, SupplyData } from 'n8n-workflow'; import { NodeConnectionType } from 'n8n-workflow'; -import { BufferMemory } from 'langchain/memory'; +import { BufferMemory, BufferWindowMemory } from 'langchain/memory'; import { PostgresChatMessageHistory } from '@langchain/community/stores/message/postgres'; import type pg from 'pg'; import { configurePostgres } from 'n8n-nodes-base/dist/nodes/Postgres/v2/transport'; @@ -9,7 +9,7 @@ import type { PostgresNodeCredentials } from 'n8n-nodes-base/dist/nodes/Postgres import { postgresConnectionTest } from 'n8n-nodes-base/dist/nodes/Postgres/v2/methods/credentialTest'; import { logWrapper } from '../../../utils/logWrapper'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; -import { sessionIdOption, sessionKeyProperty } from '../descriptions'; +import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions'; import { getSessionId } from '../../../utils/helpers'; export class MemoryPostgresChat implements INodeType { @@ -18,7 +18,7 @@ export class MemoryPostgresChat implements INodeType { name: 'memoryPostgresChat', icon: 'file:postgres.svg', group: ['transform'], - version: [1], + version: [1, 1.1], description: 'Stores the chat history in Postgres table.', defaults: { name: 'Postgres Chat Memory', @@ -60,6 +60,10 @@ export class MemoryPostgresChat implements INodeType { description: 'The table name to store the chat history in. If table does not exist, it will be created.', }, + { + ...contextWindowLengthProperty, + displayOptions: { hide: { '@version': [{ _cnd: { lt: 1.1 } }] } }, + }, ], }; @@ -83,12 +87,19 @@ export class MemoryPostgresChat implements INodeType { tableName, }); - const memory = new BufferMemory({ + const memClass = this.getNode().typeVersion < 1.1 ? BufferMemory : BufferWindowMemory; + const kOptions = + this.getNode().typeVersion < 1.1 + ? {} + : { k: this.getNodeParameter('contextWindowLength', itemIndex) }; + + const memory = new memClass({ memoryKey: 'chat_history', chatHistory: pgChatHistory, returnMessages: true, inputKey: 'input', outputKey: 'output', + ...kOptions, }); async function closeFunction() { diff --git a/packages/@n8n/nodes-langchain/nodes/memory/MemoryRedisChat/MemoryRedisChat.node.ts b/packages/@n8n/nodes-langchain/nodes/memory/MemoryRedisChat/MemoryRedisChat.node.ts index d139bd31e3..da57ede1d2 100644 --- a/packages/@n8n/nodes-langchain/nodes/memory/MemoryRedisChat/MemoryRedisChat.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/memory/MemoryRedisChat/MemoryRedisChat.node.ts @@ -7,14 +7,14 @@ import { type SupplyData, NodeConnectionType, } from 'n8n-workflow'; -import { BufferMemory } from 'langchain/memory'; +import { BufferMemory, BufferWindowMemory } from 'langchain/memory'; import type { RedisChatMessageHistoryInput } from '@langchain/redis'; import { RedisChatMessageHistory } from '@langchain/redis'; import type { RedisClientOptions } from 'redis'; import { createClient } from 'redis'; import { logWrapper } from '../../../utils/logWrapper'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; -import { sessionIdOption, sessionKeyProperty } from '../descriptions'; +import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions'; import { getSessionId } from '../../../utils/helpers'; export class MemoryRedisChat implements INodeType { @@ -23,7 +23,7 @@ export class MemoryRedisChat implements INodeType { name: 'memoryRedisChat', icon: 'file:redis.svg', group: ['transform'], - version: [1, 1.1, 1.2], + version: [1, 1.1, 1.2, 1.3], description: 'Stores the chat history in Redis.', defaults: { name: 'Redis Chat Memory', @@ -95,6 +95,10 @@ export class MemoryRedisChat implements INodeType { description: 'For how long the session should be stored in seconds. If set to 0 it will not expire.', }, + { + ...contextWindowLengthProperty, + displayOptions: { hide: { '@version': [{ _cnd: { lt: 1.3 } }] } }, + }, ], }; @@ -143,12 +147,19 @@ export class MemoryRedisChat implements INodeType { } const redisChatHistory = new RedisChatMessageHistory(redisChatConfig); - const memory = new BufferMemory({ + const memClass = this.getNode().typeVersion < 1.3 ? BufferMemory : BufferWindowMemory; + const kOptions = + this.getNode().typeVersion < 1.3 + ? {} + : { k: this.getNodeParameter('contextWindowLength', itemIndex) }; + + const memory = new memClass({ memoryKey: 'chat_history', chatHistory: redisChatHistory, returnMessages: true, inputKey: 'input', outputKey: 'output', + ...kOptions, }); async function closeFunction() { diff --git a/packages/@n8n/nodes-langchain/nodes/memory/MemoryXata/MemoryXata.node.ts b/packages/@n8n/nodes-langchain/nodes/memory/MemoryXata/MemoryXata.node.ts index e5c9dc4c35..f0177d9e75 100644 --- a/packages/@n8n/nodes-langchain/nodes/memory/MemoryXata/MemoryXata.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/memory/MemoryXata/MemoryXata.node.ts @@ -2,11 +2,11 @@ import { NodeConnectionType, NodeOperationError } from 'n8n-workflow'; import type { IExecuteFunctions, INodeType, INodeTypeDescription, SupplyData } from 'n8n-workflow'; import { XataChatMessageHistory } from '@langchain/community/stores/message/xata'; -import { BufferMemory } from 'langchain/memory'; +import { BufferMemory, BufferWindowMemory } from 'langchain/memory'; import { BaseClient } from '@xata.io/client'; import { logWrapper } from '../../../utils/logWrapper'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; -import { sessionIdOption, sessionKeyProperty } from '../descriptions'; +import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions'; import { getSessionId } from '../../../utils/helpers'; export class MemoryXata implements INodeType { @@ -15,7 +15,7 @@ export class MemoryXata implements INodeType { name: 'memoryXata', icon: 'file:xata.svg', group: ['transform'], - version: [1, 1.1, 1.2], + version: [1, 1.1, 1.2, 1.3], description: 'Use Xata Memory', defaults: { name: 'Xata', @@ -81,6 +81,10 @@ export class MemoryXata implements INodeType { }, }, sessionKeyProperty, + { + ...contextWindowLengthProperty, + displayOptions: { hide: { '@version': [{ _cnd: { lt: 1.3 } }] } }, + }, ], }; @@ -120,12 +124,19 @@ export class MemoryXata implements INodeType { apiKey: credentials.apiKey as string, }); - const memory = new BufferMemory({ + const memClass = this.getNode().typeVersion < 1.3 ? BufferMemory : BufferWindowMemory; + const kOptions = + this.getNode().typeVersion < 1.3 + ? {} + : { k: this.getNodeParameter('contextWindowLength', itemIndex) }; + + const memory = new memClass({ chatHistory, memoryKey: 'chat_history', returnMessages: true, inputKey: 'input', outputKey: 'output', + ...kOptions, }); return { diff --git a/packages/@n8n/nodes-langchain/nodes/memory/descriptions.ts b/packages/@n8n/nodes-langchain/nodes/memory/descriptions.ts index 5f722c4647..354d134fb7 100644 --- a/packages/@n8n/nodes-langchain/nodes/memory/descriptions.ts +++ b/packages/@n8n/nodes-langchain/nodes/memory/descriptions.ts @@ -33,3 +33,11 @@ export const sessionKeyProperty: INodeProperties = { }, }, }; + +export const contextWindowLengthProperty: INodeProperties = { + displayName: 'Context Window Length', + name: 'contextWindowLength', + type: 'number', + default: 5, + hint: 'How many past interactions the model receives as context', +};