feat(Postgres Chat Memory, Redis Chat Memory, Xata): Add support for context window length (#10203)

This commit is contained in:
jeanpaul 2024-08-06 10:42:18 +02:00 committed by GitHub
parent 1eba7c3c76
commit e3edeaa035
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 55 additions and 20 deletions

View file

@ -10,7 +10,7 @@ import type { BufferWindowMemoryInput } from 'langchain/memory';
import { BufferWindowMemory } from 'langchain/memory'; import { BufferWindowMemory } from 'langchain/memory';
import { logWrapper } from '../../../utils/logWrapper'; import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions'; import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers'; import { getSessionId } from '../../../utils/helpers';
class MemoryChatBufferSingleton { class MemoryChatBufferSingleton {
@ -130,13 +130,7 @@ export class MemoryBufferWindow implements INodeType {
}, },
}, },
sessionKeyProperty, sessionKeyProperty,
{ contextWindowLengthProperty,
displayName: 'Context Window Length',
name: 'contextWindowLength',
type: 'number',
default: 5,
description: 'The number of previous messages to consider for context',
},
], ],
}; };

View file

@ -1,7 +1,7 @@
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */ /* eslint-disable n8n-nodes-base/node-dirname-against-convention */
import type { IExecuteFunctions, INodeType, INodeTypeDescription, SupplyData } from 'n8n-workflow'; import type { IExecuteFunctions, INodeType, INodeTypeDescription, SupplyData } from 'n8n-workflow';
import { NodeConnectionType } from 'n8n-workflow'; import { NodeConnectionType } from 'n8n-workflow';
import { BufferMemory } from 'langchain/memory'; import { BufferMemory, BufferWindowMemory } from 'langchain/memory';
import { PostgresChatMessageHistory } from '@langchain/community/stores/message/postgres'; import { PostgresChatMessageHistory } from '@langchain/community/stores/message/postgres';
import type pg from 'pg'; import type pg from 'pg';
import { configurePostgres } from 'n8n-nodes-base/dist/nodes/Postgres/v2/transport'; import { configurePostgres } from 'n8n-nodes-base/dist/nodes/Postgres/v2/transport';
@ -9,7 +9,7 @@ import type { PostgresNodeCredentials } from 'n8n-nodes-base/dist/nodes/Postgres
import { postgresConnectionTest } from 'n8n-nodes-base/dist/nodes/Postgres/v2/methods/credentialTest'; import { postgresConnectionTest } from 'n8n-nodes-base/dist/nodes/Postgres/v2/methods/credentialTest';
import { logWrapper } from '../../../utils/logWrapper'; import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions'; import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers'; import { getSessionId } from '../../../utils/helpers';
export class MemoryPostgresChat implements INodeType { export class MemoryPostgresChat implements INodeType {
@ -18,7 +18,7 @@ export class MemoryPostgresChat implements INodeType {
name: 'memoryPostgresChat', name: 'memoryPostgresChat',
icon: 'file:postgres.svg', icon: 'file:postgres.svg',
group: ['transform'], group: ['transform'],
version: [1], version: [1, 1.1],
description: 'Stores the chat history in Postgres table.', description: 'Stores the chat history in Postgres table.',
defaults: { defaults: {
name: 'Postgres Chat Memory', name: 'Postgres Chat Memory',
@ -60,6 +60,10 @@ export class MemoryPostgresChat implements INodeType {
description: description:
'The table name to store the chat history in. If table does not exist, it will be created.', 'The table name to store the chat history in. If table does not exist, it will be created.',
}, },
{
...contextWindowLengthProperty,
displayOptions: { hide: { '@version': [{ _cnd: { lt: 1.1 } }] } },
},
], ],
}; };
@ -83,12 +87,19 @@ export class MemoryPostgresChat implements INodeType {
tableName, tableName,
}); });
const memory = new BufferMemory({ const memClass = this.getNode().typeVersion < 1.1 ? BufferMemory : BufferWindowMemory;
const kOptions =
this.getNode().typeVersion < 1.1
? {}
: { k: this.getNodeParameter('contextWindowLength', itemIndex) };
const memory = new memClass({
memoryKey: 'chat_history', memoryKey: 'chat_history',
chatHistory: pgChatHistory, chatHistory: pgChatHistory,
returnMessages: true, returnMessages: true,
inputKey: 'input', inputKey: 'input',
outputKey: 'output', outputKey: 'output',
...kOptions,
}); });
async function closeFunction() { async function closeFunction() {

View file

@ -7,14 +7,14 @@ import {
type SupplyData, type SupplyData,
NodeConnectionType, NodeConnectionType,
} from 'n8n-workflow'; } from 'n8n-workflow';
import { BufferMemory } from 'langchain/memory'; import { BufferMemory, BufferWindowMemory } from 'langchain/memory';
import type { RedisChatMessageHistoryInput } from '@langchain/redis'; import type { RedisChatMessageHistoryInput } from '@langchain/redis';
import { RedisChatMessageHistory } from '@langchain/redis'; import { RedisChatMessageHistory } from '@langchain/redis';
import type { RedisClientOptions } from 'redis'; import type { RedisClientOptions } from 'redis';
import { createClient } from 'redis'; import { createClient } from 'redis';
import { logWrapper } from '../../../utils/logWrapper'; import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions'; import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers'; import { getSessionId } from '../../../utils/helpers';
export class MemoryRedisChat implements INodeType { export class MemoryRedisChat implements INodeType {
@ -23,7 +23,7 @@ export class MemoryRedisChat implements INodeType {
name: 'memoryRedisChat', name: 'memoryRedisChat',
icon: 'file:redis.svg', icon: 'file:redis.svg',
group: ['transform'], group: ['transform'],
version: [1, 1.1, 1.2], version: [1, 1.1, 1.2, 1.3],
description: 'Stores the chat history in Redis.', description: 'Stores the chat history in Redis.',
defaults: { defaults: {
name: 'Redis Chat Memory', name: 'Redis Chat Memory',
@ -95,6 +95,10 @@ export class MemoryRedisChat implements INodeType {
description: description:
'For how long the session should be stored in seconds. If set to 0 it will not expire.', 'For how long the session should be stored in seconds. If set to 0 it will not expire.',
}, },
{
...contextWindowLengthProperty,
displayOptions: { hide: { '@version': [{ _cnd: { lt: 1.3 } }] } },
},
], ],
}; };
@ -143,12 +147,19 @@ export class MemoryRedisChat implements INodeType {
} }
const redisChatHistory = new RedisChatMessageHistory(redisChatConfig); const redisChatHistory = new RedisChatMessageHistory(redisChatConfig);
const memory = new BufferMemory({ const memClass = this.getNode().typeVersion < 1.3 ? BufferMemory : BufferWindowMemory;
const kOptions =
this.getNode().typeVersion < 1.3
? {}
: { k: this.getNodeParameter('contextWindowLength', itemIndex) };
const memory = new memClass({
memoryKey: 'chat_history', memoryKey: 'chat_history',
chatHistory: redisChatHistory, chatHistory: redisChatHistory,
returnMessages: true, returnMessages: true,
inputKey: 'input', inputKey: 'input',
outputKey: 'output', outputKey: 'output',
...kOptions,
}); });
async function closeFunction() { async function closeFunction() {

View file

@ -2,11 +2,11 @@
import { NodeConnectionType, NodeOperationError } from 'n8n-workflow'; import { NodeConnectionType, NodeOperationError } from 'n8n-workflow';
import type { IExecuteFunctions, INodeType, INodeTypeDescription, SupplyData } from 'n8n-workflow'; import type { IExecuteFunctions, INodeType, INodeTypeDescription, SupplyData } from 'n8n-workflow';
import { XataChatMessageHistory } from '@langchain/community/stores/message/xata'; import { XataChatMessageHistory } from '@langchain/community/stores/message/xata';
import { BufferMemory } from 'langchain/memory'; import { BufferMemory, BufferWindowMemory } from 'langchain/memory';
import { BaseClient } from '@xata.io/client'; import { BaseClient } from '@xata.io/client';
import { logWrapper } from '../../../utils/logWrapper'; import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions'; import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers'; import { getSessionId } from '../../../utils/helpers';
export class MemoryXata implements INodeType { export class MemoryXata implements INodeType {
@ -15,7 +15,7 @@ export class MemoryXata implements INodeType {
name: 'memoryXata', name: 'memoryXata',
icon: 'file:xata.svg', icon: 'file:xata.svg',
group: ['transform'], group: ['transform'],
version: [1, 1.1, 1.2], version: [1, 1.1, 1.2, 1.3],
description: 'Use Xata Memory', description: 'Use Xata Memory',
defaults: { defaults: {
name: 'Xata', name: 'Xata',
@ -81,6 +81,10 @@ export class MemoryXata implements INodeType {
}, },
}, },
sessionKeyProperty, sessionKeyProperty,
{
...contextWindowLengthProperty,
displayOptions: { hide: { '@version': [{ _cnd: { lt: 1.3 } }] } },
},
], ],
}; };
@ -120,12 +124,19 @@ export class MemoryXata implements INodeType {
apiKey: credentials.apiKey as string, apiKey: credentials.apiKey as string,
}); });
const memory = new BufferMemory({ const memClass = this.getNode().typeVersion < 1.3 ? BufferMemory : BufferWindowMemory;
const kOptions =
this.getNode().typeVersion < 1.3
? {}
: { k: this.getNodeParameter('contextWindowLength', itemIndex) };
const memory = new memClass({
chatHistory, chatHistory,
memoryKey: 'chat_history', memoryKey: 'chat_history',
returnMessages: true, returnMessages: true,
inputKey: 'input', inputKey: 'input',
outputKey: 'output', outputKey: 'output',
...kOptions,
}); });
return { return {

View file

@ -33,3 +33,11 @@ export const sessionKeyProperty: INodeProperties = {
}, },
}, },
}; };
export const contextWindowLengthProperty: INodeProperties = {
displayName: 'Context Window Length',
name: 'contextWindowLength',
type: 'number',
default: 5,
hint: 'How many past interactions the model receives as context',
};