feat: Session is selector for memory nodes (#8736)

This commit is contained in:
Michael Kret 2024-02-27 15:01:15 +02:00 committed by GitHub
parent 5f6da7b84e
commit 2aaf211dfc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 188 additions and 22 deletions

View file

@ -10,6 +10,8 @@ import type { BufferWindowMemoryInput } from 'langchain/memory';
import { BufferWindowMemory } from 'langchain/memory';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
class MemoryChatBufferSingleton {
private static instance: MemoryChatBufferSingleton;
@ -70,7 +72,7 @@ export class MemoryBufferWindow implements INodeType {
name: 'memoryBufferWindow',
icon: 'fa:database',
group: ['transform'],
version: [1, 1.1],
version: [1, 1.1, 1.2],
description: 'Stores in n8n memory, so no credentials required',
defaults: {
name: 'Window Buffer Memory',
@ -119,6 +121,15 @@ export class MemoryBufferWindow implements INodeType {
},
},
},
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
{
displayName: 'Context Window Length',
name: 'contextWindowLength',
@ -130,12 +141,21 @@ export class MemoryBufferWindow implements INodeType {
};
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const sessionKey = this.getNodeParameter('sessionKey', itemIndex) as string;
const contextWindowLength = this.getNodeParameter('contextWindowLength', itemIndex) as number;
const workflowId = this.getWorkflow().id;
const memoryInstance = MemoryChatBufferSingleton.getInstance();
const memory = await memoryInstance.getMemory(`${workflowId}__${sessionKey}`, {
const nodeVersion = this.getNode().typeVersion;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionKey', itemIndex) as string;
}
const memory = await memoryInstance.getMemory(`${workflowId}__${sessionId}`, {
k: contextWindowLength,
inputKey: 'input',
memoryKey: 'chat_history',

View file

@ -10,6 +10,8 @@ import {
import { MotorheadMemory } from 'langchain/memory';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
export class MemoryMotorhead implements INodeType {
description: INodeTypeDescription = {
@ -17,7 +19,7 @@ export class MemoryMotorhead implements INodeType {
name: 'memoryMotorhead',
icon: 'fa:file-export',
group: ['transform'],
version: [1, 1.1],
version: [1, 1.1, 1.2],
description: 'Use Motorhead Memory',
defaults: {
name: 'Motorhead',
@ -72,13 +74,29 @@ export class MemoryMotorhead implements INodeType {
},
},
},
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
],
};
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const credentials = await this.getCredentials('motorheadApi');
const nodeVersion = this.getNode().typeVersion;
const sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
}
const memory = new MotorheadMemory({
sessionId,

View file

@ -14,6 +14,8 @@ import type { RedisClientOptions } from 'redis';
import { createClient } from 'redis';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
export class MemoryRedisChat implements INodeType {
description: INodeTypeDescription = {
@ -21,7 +23,7 @@ export class MemoryRedisChat implements INodeType {
name: 'memoryRedisChat',
icon: 'file:redis.svg',
group: ['transform'],
version: [1, 1.1],
version: [1, 1.1, 1.2],
description: 'Stores the chat history in Redis.',
defaults: {
name: 'Redis Chat Memory',
@ -76,6 +78,15 @@ export class MemoryRedisChat implements INodeType {
},
},
},
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
{
displayName: 'Session Time To Live',
name: 'sessionTTL',
@ -89,9 +100,18 @@ export class MemoryRedisChat implements INodeType {
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const credentials = await this.getCredentials('redis');
const sessionKey = this.getNodeParameter('sessionKey', itemIndex) as string;
const nodeVersion = this.getNode().typeVersion;
const sessionTTL = this.getNodeParameter('sessionTTL', itemIndex, 0) as number;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionKey', itemIndex) as string;
}
const redisOptions: RedisClientOptions = {
socket: {
host: credentials.host as string,
@ -115,7 +135,7 @@ export class MemoryRedisChat implements INodeType {
const redisChatConfig: RedisChatMessageHistoryInput = {
client,
sessionId: sessionKey,
sessionId,
};
if (sessionTTL > 0) {

View file

@ -6,13 +6,16 @@ import { BufferMemory } from 'langchain/memory';
import { BaseClient } from '@xata.io/client';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
export class MemoryXata implements INodeType {
description: INodeTypeDescription = {
displayName: 'Xata',
name: 'memoryXata',
icon: 'file:xata.svg',
group: ['transform'],
version: [1, 1.1],
version: [1, 1.1, 1.2],
description: 'Use Xata Memory',
defaults: {
name: 'Xata',
@ -69,11 +72,29 @@ export class MemoryXata implements INodeType {
},
},
},
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
],
};
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const credentials = await this.getCredentials('xataApi');
const nodeVersion = this.getNode().typeVersion;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
}
const xataClient = new BaseClient({
apiKey: credentials.apiKey as string,
@ -81,8 +102,6 @@ export class MemoryXata implements INodeType {
databaseURL: credentials.databaseEndpoint as string,
});
const sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
const table = (credentials.databaseEndpoint as string).match(
/https:\/\/[^.]+\.[^.]+\.xata\.sh\/db\/([^\/:]+)/,
);
@ -94,18 +113,21 @@ export class MemoryXata implements INodeType {
);
}
const chatHistory = new XataChatMessageHistory({
table: table[1],
sessionId,
client: xataClient,
apiKey: credentials.apiKey as string,
});
const memory = new BufferMemory({
chatHistory: new XataChatMessageHistory({
table: table[1],
sessionId,
client: xataClient,
apiKey: credentials.apiKey as string,
}),
chatHistory,
memoryKey: 'chat_history',
returnMessages: true,
inputKey: 'input',
outputKey: 'output',
});
return {
response: logWrapper(memory, this),
};

View file

@ -9,6 +9,8 @@ import {
import { ZepMemory } from 'langchain/memory/zep';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
export class MemoryZep implements INodeType {
description: INodeTypeDescription = {
@ -17,7 +19,7 @@ export class MemoryZep implements INodeType {
// eslint-disable-next-line n8n-nodes-base/node-class-description-icon-not-svg
icon: 'file:zep.png',
group: ['transform'],
version: [1, 1.1],
version: [1, 1.1, 1.2],
description: 'Use Zep Memory',
defaults: {
name: 'Zep',
@ -72,6 +74,15 @@ export class MemoryZep implements INodeType {
},
},
},
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
],
};
@ -81,8 +92,15 @@ export class MemoryZep implements INodeType {
apiUrl: string;
};
// TODO: Should it get executed once per item or not?
const sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
const nodeVersion = this.getNode().typeVersion;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
}
const memory = new ZepMemory({
sessionId,

View file

@ -0,0 +1,35 @@
import type { INodeProperties } from 'n8n-workflow';
export const sessionIdOption: INodeProperties = {
displayName: 'Session ID',
name: 'sessionIdType',
type: 'options',
options: [
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Take from previous node automatically',
value: 'fromInput',
description: 'Looks for an input field called sessionId',
},
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Define below',
value: 'customKey',
description: 'Use an expression to reference data in previous nodes or enter static text',
},
],
default: 'fromInput',
};
export const sessionKeyProperty: INodeProperties = {
displayName: 'Key',
name: 'sessionKey',
type: 'string',
default: '',
description: 'The key to use to store session ID in the memory',
displayOptions: {
show: {
sessionIdType: ['customKey'],
},
},
};

View file

@ -3,7 +3,7 @@ import type { EventNamesAiNodesType, IDataObject, IExecuteFunctions } from 'n8n-
import { BaseChatModel } from 'langchain/chat_models/base';
import { BaseChatModel as BaseChatModelCore } from '@langchain/core/language_models/chat_models';
import type { BaseOutputParser } from '@langchain/core/output_parsers';
import { BaseMessage } from 'langchain/schema';
import type { BaseMessage } from 'langchain/schema';
export function getMetadataFiltersValues(
ctx: IExecuteFunctions,
@ -67,6 +67,39 @@ export function getPromptInputByType(options: {
return input;
}
export function getSessionId(
ctx: IExecuteFunctions,
itemIndex: number,
selectorKey = 'sessionIdType',
autoSelect = 'fromInput',
customKey = 'sessionKey',
) {
let sessionId = '';
const selectorType = ctx.getNodeParameter(selectorKey, itemIndex) as string;
if (selectorType === autoSelect) {
sessionId = ctx.evaluateExpression('{{ $json.sessionId }}', itemIndex) as string;
if (sessionId === '' || sessionId === undefined) {
throw new NodeOperationError(ctx.getNode(), 'No session ID found', {
description:
"Expected to find the session ID in an input field called 'sessionId' (this is what the chat trigger node outputs). To use something else, change the 'Session ID' parameter",
itemIndex,
});
}
} else {
sessionId = ctx.getNodeParameter(customKey, itemIndex, '') as string;
if (sessionId === '' || sessionId === undefined) {
throw new NodeOperationError(ctx.getNode(), 'Key parameter is empty', {
description:
"Provide a key to use as session ID in the 'Key' parameter or use the 'Take from previous node automatically' option to use the session ID from the previous node, e.t. chat trigger node",
itemIndex,
});
}
}
return sessionId;
}
export async function logAiEvent(
executeFunctions: IExecuteFunctions,
event: EventNamesAiNodesType,
@ -79,7 +112,7 @@ export async function logAiEvent(
}
}
export function serializeChatHistory (chatHistory: Array<BaseMessage>): string {
export function serializeChatHistory(chatHistory: BaseMessage[]): string {
return chatHistory
.map((chatMessage) => {
if (chatMessage._getType() === 'human') {