mirror of
https://github.com/n8n-io/n8n.git
synced 2024-12-25 12:44:07 -08:00
feat: Session is selector for memory nodes (#8736)
This commit is contained in:
parent
5f6da7b84e
commit
2aaf211dfc
|
@ -10,6 +10,8 @@ import type { BufferWindowMemoryInput } from 'langchain/memory';
|
|||
import { BufferWindowMemory } from 'langchain/memory';
|
||||
import { logWrapper } from '../../../utils/logWrapper';
|
||||
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
|
||||
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
|
||||
import { getSessionId } from '../../../utils/helpers';
|
||||
|
||||
class MemoryChatBufferSingleton {
|
||||
private static instance: MemoryChatBufferSingleton;
|
||||
|
@ -70,7 +72,7 @@ export class MemoryBufferWindow implements INodeType {
|
|||
name: 'memoryBufferWindow',
|
||||
icon: 'fa:database',
|
||||
group: ['transform'],
|
||||
version: [1, 1.1],
|
||||
version: [1, 1.1, 1.2],
|
||||
description: 'Stores in n8n memory, so no credentials required',
|
||||
defaults: {
|
||||
name: 'Window Buffer Memory',
|
||||
|
@ -119,6 +121,15 @@ export class MemoryBufferWindow implements INodeType {
|
|||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
...sessionIdOption,
|
||||
displayOptions: {
|
||||
show: {
|
||||
'@version': [{ _cnd: { gte: 1.2 } }],
|
||||
},
|
||||
},
|
||||
},
|
||||
sessionKeyProperty,
|
||||
{
|
||||
displayName: 'Context Window Length',
|
||||
name: 'contextWindowLength',
|
||||
|
@ -130,12 +141,21 @@ export class MemoryBufferWindow implements INodeType {
|
|||
};
|
||||
|
||||
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
|
||||
const sessionKey = this.getNodeParameter('sessionKey', itemIndex) as string;
|
||||
const contextWindowLength = this.getNodeParameter('contextWindowLength', itemIndex) as number;
|
||||
const workflowId = this.getWorkflow().id;
|
||||
const memoryInstance = MemoryChatBufferSingleton.getInstance();
|
||||
|
||||
const memory = await memoryInstance.getMemory(`${workflowId}__${sessionKey}`, {
|
||||
const nodeVersion = this.getNode().typeVersion;
|
||||
|
||||
let sessionId;
|
||||
|
||||
if (nodeVersion >= 1.2) {
|
||||
sessionId = getSessionId(this, itemIndex);
|
||||
} else {
|
||||
sessionId = this.getNodeParameter('sessionKey', itemIndex) as string;
|
||||
}
|
||||
|
||||
const memory = await memoryInstance.getMemory(`${workflowId}__${sessionId}`, {
|
||||
k: contextWindowLength,
|
||||
inputKey: 'input',
|
||||
memoryKey: 'chat_history',
|
||||
|
|
|
@ -10,6 +10,8 @@ import {
|
|||
import { MotorheadMemory } from 'langchain/memory';
|
||||
import { logWrapper } from '../../../utils/logWrapper';
|
||||
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
|
||||
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
|
||||
import { getSessionId } from '../../../utils/helpers';
|
||||
|
||||
export class MemoryMotorhead implements INodeType {
|
||||
description: INodeTypeDescription = {
|
||||
|
@ -17,7 +19,7 @@ export class MemoryMotorhead implements INodeType {
|
|||
name: 'memoryMotorhead',
|
||||
icon: 'fa:file-export',
|
||||
group: ['transform'],
|
||||
version: [1, 1.1],
|
||||
version: [1, 1.1, 1.2],
|
||||
description: 'Use Motorhead Memory',
|
||||
defaults: {
|
||||
name: 'Motorhead',
|
||||
|
@ -72,13 +74,29 @@ export class MemoryMotorhead implements INodeType {
|
|||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
...sessionIdOption,
|
||||
displayOptions: {
|
||||
show: {
|
||||
'@version': [{ _cnd: { gte: 1.2 } }],
|
||||
},
|
||||
},
|
||||
},
|
||||
sessionKeyProperty,
|
||||
],
|
||||
};
|
||||
|
||||
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
|
||||
const credentials = await this.getCredentials('motorheadApi');
|
||||
const nodeVersion = this.getNode().typeVersion;
|
||||
|
||||
const sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
|
||||
let sessionId;
|
||||
|
||||
if (nodeVersion >= 1.2) {
|
||||
sessionId = getSessionId(this, itemIndex);
|
||||
} else {
|
||||
sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
|
||||
}
|
||||
|
||||
const memory = new MotorheadMemory({
|
||||
sessionId,
|
||||
|
|
|
@ -14,6 +14,8 @@ import type { RedisClientOptions } from 'redis';
|
|||
import { createClient } from 'redis';
|
||||
import { logWrapper } from '../../../utils/logWrapper';
|
||||
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
|
||||
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
|
||||
import { getSessionId } from '../../../utils/helpers';
|
||||
|
||||
export class MemoryRedisChat implements INodeType {
|
||||
description: INodeTypeDescription = {
|
||||
|
@ -21,7 +23,7 @@ export class MemoryRedisChat implements INodeType {
|
|||
name: 'memoryRedisChat',
|
||||
icon: 'file:redis.svg',
|
||||
group: ['transform'],
|
||||
version: [1, 1.1],
|
||||
version: [1, 1.1, 1.2],
|
||||
description: 'Stores the chat history in Redis.',
|
||||
defaults: {
|
||||
name: 'Redis Chat Memory',
|
||||
|
@ -76,6 +78,15 @@ export class MemoryRedisChat implements INodeType {
|
|||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
...sessionIdOption,
|
||||
displayOptions: {
|
||||
show: {
|
||||
'@version': [{ _cnd: { gte: 1.2 } }],
|
||||
},
|
||||
},
|
||||
},
|
||||
sessionKeyProperty,
|
||||
{
|
||||
displayName: 'Session Time To Live',
|
||||
name: 'sessionTTL',
|
||||
|
@ -89,9 +100,18 @@ export class MemoryRedisChat implements INodeType {
|
|||
|
||||
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
|
||||
const credentials = await this.getCredentials('redis');
|
||||
const sessionKey = this.getNodeParameter('sessionKey', itemIndex) as string;
|
||||
const nodeVersion = this.getNode().typeVersion;
|
||||
|
||||
const sessionTTL = this.getNodeParameter('sessionTTL', itemIndex, 0) as number;
|
||||
|
||||
let sessionId;
|
||||
|
||||
if (nodeVersion >= 1.2) {
|
||||
sessionId = getSessionId(this, itemIndex);
|
||||
} else {
|
||||
sessionId = this.getNodeParameter('sessionKey', itemIndex) as string;
|
||||
}
|
||||
|
||||
const redisOptions: RedisClientOptions = {
|
||||
socket: {
|
||||
host: credentials.host as string,
|
||||
|
@ -115,7 +135,7 @@ export class MemoryRedisChat implements INodeType {
|
|||
|
||||
const redisChatConfig: RedisChatMessageHistoryInput = {
|
||||
client,
|
||||
sessionId: sessionKey,
|
||||
sessionId,
|
||||
};
|
||||
|
||||
if (sessionTTL > 0) {
|
||||
|
|
|
@ -6,13 +6,16 @@ import { BufferMemory } from 'langchain/memory';
|
|||
import { BaseClient } from '@xata.io/client';
|
||||
import { logWrapper } from '../../../utils/logWrapper';
|
||||
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
|
||||
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
|
||||
import { getSessionId } from '../../../utils/helpers';
|
||||
|
||||
export class MemoryXata implements INodeType {
|
||||
description: INodeTypeDescription = {
|
||||
displayName: 'Xata',
|
||||
name: 'memoryXata',
|
||||
icon: 'file:xata.svg',
|
||||
group: ['transform'],
|
||||
version: [1, 1.1],
|
||||
version: [1, 1.1, 1.2],
|
||||
description: 'Use Xata Memory',
|
||||
defaults: {
|
||||
name: 'Xata',
|
||||
|
@ -69,11 +72,29 @@ export class MemoryXata implements INodeType {
|
|||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
...sessionIdOption,
|
||||
displayOptions: {
|
||||
show: {
|
||||
'@version': [{ _cnd: { gte: 1.2 } }],
|
||||
},
|
||||
},
|
||||
},
|
||||
sessionKeyProperty,
|
||||
],
|
||||
};
|
||||
|
||||
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
|
||||
const credentials = await this.getCredentials('xataApi');
|
||||
const nodeVersion = this.getNode().typeVersion;
|
||||
|
||||
let sessionId;
|
||||
|
||||
if (nodeVersion >= 1.2) {
|
||||
sessionId = getSessionId(this, itemIndex);
|
||||
} else {
|
||||
sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
|
||||
}
|
||||
|
||||
const xataClient = new BaseClient({
|
||||
apiKey: credentials.apiKey as string,
|
||||
|
@ -81,8 +102,6 @@ export class MemoryXata implements INodeType {
|
|||
databaseURL: credentials.databaseEndpoint as string,
|
||||
});
|
||||
|
||||
const sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
|
||||
|
||||
const table = (credentials.databaseEndpoint as string).match(
|
||||
/https:\/\/[^.]+\.[^.]+\.xata\.sh\/db\/([^\/:]+)/,
|
||||
);
|
||||
|
@ -94,18 +113,21 @@ export class MemoryXata implements INodeType {
|
|||
);
|
||||
}
|
||||
|
||||
const memory = new BufferMemory({
|
||||
chatHistory: new XataChatMessageHistory({
|
||||
const chatHistory = new XataChatMessageHistory({
|
||||
table: table[1],
|
||||
sessionId,
|
||||
client: xataClient,
|
||||
apiKey: credentials.apiKey as string,
|
||||
}),
|
||||
});
|
||||
|
||||
const memory = new BufferMemory({
|
||||
chatHistory,
|
||||
memoryKey: 'chat_history',
|
||||
returnMessages: true,
|
||||
inputKey: 'input',
|
||||
outputKey: 'output',
|
||||
});
|
||||
|
||||
return {
|
||||
response: logWrapper(memory, this),
|
||||
};
|
||||
|
|
|
@ -9,6 +9,8 @@ import {
|
|||
import { ZepMemory } from 'langchain/memory/zep';
|
||||
import { logWrapper } from '../../../utils/logWrapper';
|
||||
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
|
||||
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
|
||||
import { getSessionId } from '../../../utils/helpers';
|
||||
|
||||
export class MemoryZep implements INodeType {
|
||||
description: INodeTypeDescription = {
|
||||
|
@ -17,7 +19,7 @@ export class MemoryZep implements INodeType {
|
|||
// eslint-disable-next-line n8n-nodes-base/node-class-description-icon-not-svg
|
||||
icon: 'file:zep.png',
|
||||
group: ['transform'],
|
||||
version: [1, 1.1],
|
||||
version: [1, 1.1, 1.2],
|
||||
description: 'Use Zep Memory',
|
||||
defaults: {
|
||||
name: 'Zep',
|
||||
|
@ -72,6 +74,15 @@ export class MemoryZep implements INodeType {
|
|||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
...sessionIdOption,
|
||||
displayOptions: {
|
||||
show: {
|
||||
'@version': [{ _cnd: { gte: 1.2 } }],
|
||||
},
|
||||
},
|
||||
},
|
||||
sessionKeyProperty,
|
||||
],
|
||||
};
|
||||
|
||||
|
@ -81,8 +92,15 @@ export class MemoryZep implements INodeType {
|
|||
apiUrl: string;
|
||||
};
|
||||
|
||||
// TODO: Should it get executed once per item or not?
|
||||
const sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
|
||||
const nodeVersion = this.getNode().typeVersion;
|
||||
|
||||
let sessionId;
|
||||
|
||||
if (nodeVersion >= 1.2) {
|
||||
sessionId = getSessionId(this, itemIndex);
|
||||
} else {
|
||||
sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
|
||||
}
|
||||
|
||||
const memory = new ZepMemory({
|
||||
sessionId,
|
||||
|
|
35
packages/@n8n/nodes-langchain/nodes/memory/descriptions.ts
Normal file
35
packages/@n8n/nodes-langchain/nodes/memory/descriptions.ts
Normal file
|
@ -0,0 +1,35 @@
|
|||
import type { INodeProperties } from 'n8n-workflow';
|
||||
|
||||
export const sessionIdOption: INodeProperties = {
|
||||
displayName: 'Session ID',
|
||||
name: 'sessionIdType',
|
||||
type: 'options',
|
||||
options: [
|
||||
{
|
||||
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
|
||||
name: 'Take from previous node automatically',
|
||||
value: 'fromInput',
|
||||
description: 'Looks for an input field called sessionId',
|
||||
},
|
||||
{
|
||||
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
|
||||
name: 'Define below',
|
||||
value: 'customKey',
|
||||
description: 'Use an expression to reference data in previous nodes or enter static text',
|
||||
},
|
||||
],
|
||||
default: 'fromInput',
|
||||
};
|
||||
|
||||
export const sessionKeyProperty: INodeProperties = {
|
||||
displayName: 'Key',
|
||||
name: 'sessionKey',
|
||||
type: 'string',
|
||||
default: '',
|
||||
description: 'The key to use to store session ID in the memory',
|
||||
displayOptions: {
|
||||
show: {
|
||||
sessionIdType: ['customKey'],
|
||||
},
|
||||
},
|
||||
};
|
|
@ -3,7 +3,7 @@ import type { EventNamesAiNodesType, IDataObject, IExecuteFunctions } from 'n8n-
|
|||
import { BaseChatModel } from 'langchain/chat_models/base';
|
||||
import { BaseChatModel as BaseChatModelCore } from '@langchain/core/language_models/chat_models';
|
||||
import type { BaseOutputParser } from '@langchain/core/output_parsers';
|
||||
import { BaseMessage } from 'langchain/schema';
|
||||
import type { BaseMessage } from 'langchain/schema';
|
||||
|
||||
export function getMetadataFiltersValues(
|
||||
ctx: IExecuteFunctions,
|
||||
|
@ -67,6 +67,39 @@ export function getPromptInputByType(options: {
|
|||
return input;
|
||||
}
|
||||
|
||||
export function getSessionId(
|
||||
ctx: IExecuteFunctions,
|
||||
itemIndex: number,
|
||||
selectorKey = 'sessionIdType',
|
||||
autoSelect = 'fromInput',
|
||||
customKey = 'sessionKey',
|
||||
) {
|
||||
let sessionId = '';
|
||||
const selectorType = ctx.getNodeParameter(selectorKey, itemIndex) as string;
|
||||
|
||||
if (selectorType === autoSelect) {
|
||||
sessionId = ctx.evaluateExpression('{{ $json.sessionId }}', itemIndex) as string;
|
||||
if (sessionId === '' || sessionId === undefined) {
|
||||
throw new NodeOperationError(ctx.getNode(), 'No session ID found', {
|
||||
description:
|
||||
"Expected to find the session ID in an input field called 'sessionId' (this is what the chat trigger node outputs). To use something else, change the 'Session ID' parameter",
|
||||
itemIndex,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
sessionId = ctx.getNodeParameter(customKey, itemIndex, '') as string;
|
||||
if (sessionId === '' || sessionId === undefined) {
|
||||
throw new NodeOperationError(ctx.getNode(), 'Key parameter is empty', {
|
||||
description:
|
||||
"Provide a key to use as session ID in the 'Key' parameter or use the 'Take from previous node automatically' option to use the session ID from the previous node, e.t. chat trigger node",
|
||||
itemIndex,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
export async function logAiEvent(
|
||||
executeFunctions: IExecuteFunctions,
|
||||
event: EventNamesAiNodesType,
|
||||
|
@ -79,7 +112,7 @@ export async function logAiEvent(
|
|||
}
|
||||
}
|
||||
|
||||
export function serializeChatHistory (chatHistory: Array<BaseMessage>): string {
|
||||
export function serializeChatHistory(chatHistory: BaseMessage[]): string {
|
||||
return chatHistory
|
||||
.map((chatMessage) => {
|
||||
if (chatMessage._getType() === 'human') {
|
||||
|
|
Loading…
Reference in a new issue