feat: Session is selector for memory nodes (#8736)

This commit is contained in:
Michael Kret 2024-02-27 15:01:15 +02:00 committed by GitHub
parent 5f6da7b84e
commit 2aaf211dfc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 188 additions and 22 deletions

View file

@ -10,6 +10,8 @@ import type { BufferWindowMemoryInput } from 'langchain/memory';
import { BufferWindowMemory } from 'langchain/memory'; import { BufferWindowMemory } from 'langchain/memory';
import { logWrapper } from '../../../utils/logWrapper'; import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
class MemoryChatBufferSingleton { class MemoryChatBufferSingleton {
private static instance: MemoryChatBufferSingleton; private static instance: MemoryChatBufferSingleton;
@ -70,7 +72,7 @@ export class MemoryBufferWindow implements INodeType {
name: 'memoryBufferWindow', name: 'memoryBufferWindow',
icon: 'fa:database', icon: 'fa:database',
group: ['transform'], group: ['transform'],
version: [1, 1.1], version: [1, 1.1, 1.2],
description: 'Stores in n8n memory, so no credentials required', description: 'Stores in n8n memory, so no credentials required',
defaults: { defaults: {
name: 'Window Buffer Memory', name: 'Window Buffer Memory',
@ -119,6 +121,15 @@ export class MemoryBufferWindow implements INodeType {
}, },
}, },
}, },
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
{ {
displayName: 'Context Window Length', displayName: 'Context Window Length',
name: 'contextWindowLength', name: 'contextWindowLength',
@ -130,12 +141,21 @@ export class MemoryBufferWindow implements INodeType {
}; };
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> { async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const sessionKey = this.getNodeParameter('sessionKey', itemIndex) as string;
const contextWindowLength = this.getNodeParameter('contextWindowLength', itemIndex) as number; const contextWindowLength = this.getNodeParameter('contextWindowLength', itemIndex) as number;
const workflowId = this.getWorkflow().id; const workflowId = this.getWorkflow().id;
const memoryInstance = MemoryChatBufferSingleton.getInstance(); const memoryInstance = MemoryChatBufferSingleton.getInstance();
const memory = await memoryInstance.getMemory(`${workflowId}__${sessionKey}`, { const nodeVersion = this.getNode().typeVersion;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionKey', itemIndex) as string;
}
const memory = await memoryInstance.getMemory(`${workflowId}__${sessionId}`, {
k: contextWindowLength, k: contextWindowLength,
inputKey: 'input', inputKey: 'input',
memoryKey: 'chat_history', memoryKey: 'chat_history',

View file

@ -10,6 +10,8 @@ import {
import { MotorheadMemory } from 'langchain/memory'; import { MotorheadMemory } from 'langchain/memory';
import { logWrapper } from '../../../utils/logWrapper'; import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
export class MemoryMotorhead implements INodeType { export class MemoryMotorhead implements INodeType {
description: INodeTypeDescription = { description: INodeTypeDescription = {
@ -17,7 +19,7 @@ export class MemoryMotorhead implements INodeType {
name: 'memoryMotorhead', name: 'memoryMotorhead',
icon: 'fa:file-export', icon: 'fa:file-export',
group: ['transform'], group: ['transform'],
version: [1, 1.1], version: [1, 1.1, 1.2],
description: 'Use Motorhead Memory', description: 'Use Motorhead Memory',
defaults: { defaults: {
name: 'Motorhead', name: 'Motorhead',
@ -72,13 +74,29 @@ export class MemoryMotorhead implements INodeType {
}, },
}, },
}, },
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
], ],
}; };
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> { async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const credentials = await this.getCredentials('motorheadApi'); const credentials = await this.getCredentials('motorheadApi');
const nodeVersion = this.getNode().typeVersion;
const sessionId = this.getNodeParameter('sessionId', itemIndex) as string; let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
}
const memory = new MotorheadMemory({ const memory = new MotorheadMemory({
sessionId, sessionId,

View file

@ -14,6 +14,8 @@ import type { RedisClientOptions } from 'redis';
import { createClient } from 'redis'; import { createClient } from 'redis';
import { logWrapper } from '../../../utils/logWrapper'; import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
export class MemoryRedisChat implements INodeType { export class MemoryRedisChat implements INodeType {
description: INodeTypeDescription = { description: INodeTypeDescription = {
@ -21,7 +23,7 @@ export class MemoryRedisChat implements INodeType {
name: 'memoryRedisChat', name: 'memoryRedisChat',
icon: 'file:redis.svg', icon: 'file:redis.svg',
group: ['transform'], group: ['transform'],
version: [1, 1.1], version: [1, 1.1, 1.2],
description: 'Stores the chat history in Redis.', description: 'Stores the chat history in Redis.',
defaults: { defaults: {
name: 'Redis Chat Memory', name: 'Redis Chat Memory',
@ -76,6 +78,15 @@ export class MemoryRedisChat implements INodeType {
}, },
}, },
}, },
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
{ {
displayName: 'Session Time To Live', displayName: 'Session Time To Live',
name: 'sessionTTL', name: 'sessionTTL',
@ -89,9 +100,18 @@ export class MemoryRedisChat implements INodeType {
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> { async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const credentials = await this.getCredentials('redis'); const credentials = await this.getCredentials('redis');
const sessionKey = this.getNodeParameter('sessionKey', itemIndex) as string; const nodeVersion = this.getNode().typeVersion;
const sessionTTL = this.getNodeParameter('sessionTTL', itemIndex, 0) as number; const sessionTTL = this.getNodeParameter('sessionTTL', itemIndex, 0) as number;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionKey', itemIndex) as string;
}
const redisOptions: RedisClientOptions = { const redisOptions: RedisClientOptions = {
socket: { socket: {
host: credentials.host as string, host: credentials.host as string,
@ -115,7 +135,7 @@ export class MemoryRedisChat implements INodeType {
const redisChatConfig: RedisChatMessageHistoryInput = { const redisChatConfig: RedisChatMessageHistoryInput = {
client, client,
sessionId: sessionKey, sessionId,
}; };
if (sessionTTL > 0) { if (sessionTTL > 0) {

View file

@ -6,13 +6,16 @@ import { BufferMemory } from 'langchain/memory';
import { BaseClient } from '@xata.io/client'; import { BaseClient } from '@xata.io/client';
import { logWrapper } from '../../../utils/logWrapper'; import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
export class MemoryXata implements INodeType { export class MemoryXata implements INodeType {
description: INodeTypeDescription = { description: INodeTypeDescription = {
displayName: 'Xata', displayName: 'Xata',
name: 'memoryXata', name: 'memoryXata',
icon: 'file:xata.svg', icon: 'file:xata.svg',
group: ['transform'], group: ['transform'],
version: [1, 1.1], version: [1, 1.1, 1.2],
description: 'Use Xata Memory', description: 'Use Xata Memory',
defaults: { defaults: {
name: 'Xata', name: 'Xata',
@ -69,11 +72,29 @@ export class MemoryXata implements INodeType {
}, },
}, },
}, },
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
], ],
}; };
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> { async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const credentials = await this.getCredentials('xataApi'); const credentials = await this.getCredentials('xataApi');
const nodeVersion = this.getNode().typeVersion;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
}
const xataClient = new BaseClient({ const xataClient = new BaseClient({
apiKey: credentials.apiKey as string, apiKey: credentials.apiKey as string,
@ -81,8 +102,6 @@ export class MemoryXata implements INodeType {
databaseURL: credentials.databaseEndpoint as string, databaseURL: credentials.databaseEndpoint as string,
}); });
const sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
const table = (credentials.databaseEndpoint as string).match( const table = (credentials.databaseEndpoint as string).match(
/https:\/\/[^.]+\.[^.]+\.xata\.sh\/db\/([^\/:]+)/, /https:\/\/[^.]+\.[^.]+\.xata\.sh\/db\/([^\/:]+)/,
); );
@ -94,18 +113,21 @@ export class MemoryXata implements INodeType {
); );
} }
const memory = new BufferMemory({ const chatHistory = new XataChatMessageHistory({
chatHistory: new XataChatMessageHistory({
table: table[1], table: table[1],
sessionId, sessionId,
client: xataClient, client: xataClient,
apiKey: credentials.apiKey as string, apiKey: credentials.apiKey as string,
}), });
const memory = new BufferMemory({
chatHistory,
memoryKey: 'chat_history', memoryKey: 'chat_history',
returnMessages: true, returnMessages: true,
inputKey: 'input', inputKey: 'input',
outputKey: 'output', outputKey: 'output',
}); });
return { return {
response: logWrapper(memory, this), response: logWrapper(memory, this),
}; };

View file

@ -9,6 +9,8 @@ import {
import { ZepMemory } from 'langchain/memory/zep'; import { ZepMemory } from 'langchain/memory/zep';
import { logWrapper } from '../../../utils/logWrapper'; import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
export class MemoryZep implements INodeType { export class MemoryZep implements INodeType {
description: INodeTypeDescription = { description: INodeTypeDescription = {
@ -17,7 +19,7 @@ export class MemoryZep implements INodeType {
// eslint-disable-next-line n8n-nodes-base/node-class-description-icon-not-svg // eslint-disable-next-line n8n-nodes-base/node-class-description-icon-not-svg
icon: 'file:zep.png', icon: 'file:zep.png',
group: ['transform'], group: ['transform'],
version: [1, 1.1], version: [1, 1.1, 1.2],
description: 'Use Zep Memory', description: 'Use Zep Memory',
defaults: { defaults: {
name: 'Zep', name: 'Zep',
@ -72,6 +74,15 @@ export class MemoryZep implements INodeType {
}, },
}, },
}, },
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
], ],
}; };
@ -81,8 +92,15 @@ export class MemoryZep implements INodeType {
apiUrl: string; apiUrl: string;
}; };
// TODO: Should it get executed once per item or not? const nodeVersion = this.getNode().typeVersion;
const sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
}
const memory = new ZepMemory({ const memory = new ZepMemory({
sessionId, sessionId,

View file

@ -0,0 +1,35 @@
import type { INodeProperties } from 'n8n-workflow';
export const sessionIdOption: INodeProperties = {
displayName: 'Session ID',
name: 'sessionIdType',
type: 'options',
options: [
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Take from previous node automatically',
value: 'fromInput',
description: 'Looks for an input field called sessionId',
},
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Define below',
value: 'customKey',
description: 'Use an expression to reference data in previous nodes or enter static text',
},
],
default: 'fromInput',
};
export const sessionKeyProperty: INodeProperties = {
displayName: 'Key',
name: 'sessionKey',
type: 'string',
default: '',
description: 'The key to use to store session ID in the memory',
displayOptions: {
show: {
sessionIdType: ['customKey'],
},
},
};

View file

@ -3,7 +3,7 @@ import type { EventNamesAiNodesType, IDataObject, IExecuteFunctions } from 'n8n-
import { BaseChatModel } from 'langchain/chat_models/base'; import { BaseChatModel } from 'langchain/chat_models/base';
import { BaseChatModel as BaseChatModelCore } from '@langchain/core/language_models/chat_models'; import { BaseChatModel as BaseChatModelCore } from '@langchain/core/language_models/chat_models';
import type { BaseOutputParser } from '@langchain/core/output_parsers'; import type { BaseOutputParser } from '@langchain/core/output_parsers';
import { BaseMessage } from 'langchain/schema'; import type { BaseMessage } from 'langchain/schema';
export function getMetadataFiltersValues( export function getMetadataFiltersValues(
ctx: IExecuteFunctions, ctx: IExecuteFunctions,
@ -67,6 +67,39 @@ export function getPromptInputByType(options: {
return input; return input;
} }
export function getSessionId(
ctx: IExecuteFunctions,
itemIndex: number,
selectorKey = 'sessionIdType',
autoSelect = 'fromInput',
customKey = 'sessionKey',
) {
let sessionId = '';
const selectorType = ctx.getNodeParameter(selectorKey, itemIndex) as string;
if (selectorType === autoSelect) {
sessionId = ctx.evaluateExpression('{{ $json.sessionId }}', itemIndex) as string;
if (sessionId === '' || sessionId === undefined) {
throw new NodeOperationError(ctx.getNode(), 'No session ID found', {
description:
"Expected to find the session ID in an input field called 'sessionId' (this is what the chat trigger node outputs). To use something else, change the 'Session ID' parameter",
itemIndex,
});
}
} else {
sessionId = ctx.getNodeParameter(customKey, itemIndex, '') as string;
if (sessionId === '' || sessionId === undefined) {
throw new NodeOperationError(ctx.getNode(), 'Key parameter is empty', {
description:
"Provide a key to use as session ID in the 'Key' parameter or use the 'Take from previous node automatically' option to use the session ID from the previous node, e.t. chat trigger node",
itemIndex,
});
}
}
return sessionId;
}
export async function logAiEvent( export async function logAiEvent(
executeFunctions: IExecuteFunctions, executeFunctions: IExecuteFunctions,
event: EventNamesAiNodesType, event: EventNamesAiNodesType,
@ -79,7 +112,7 @@ export async function logAiEvent(
} }
} }
export function serializeChatHistory (chatHistory: Array<BaseMessage>): string { export function serializeChatHistory(chatHistory: BaseMessage[]): string {
return chatHistory return chatHistory
.map((chatMessage) => { .map((chatMessage) => {
if (chatMessage._getType() === 'human') { if (chatMessage._getType() === 'human') {