fix(editor): Fix retrieving of messages from memory in chat modal (#8807)

Signed-off-by: Oleg Ivaniv <me@olegivaniv.com>
This commit is contained in:
oleg 2024-03-05 13:53:46 +01:00 committed by GitHub
parent 16004331b1
commit bfda8ead0c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 40 additions and 50 deletions

View file

@ -15,7 +15,7 @@ import type { BaseDocumentLoader } from 'langchain/document_loaders/base';
import type { BaseCallbackConfig, Callbacks } from 'langchain/dist/callbacks/manager'; import type { BaseCallbackConfig, Callbacks } from 'langchain/dist/callbacks/manager';
import { BaseLLM } from 'langchain/llms/base'; import { BaseLLM } from 'langchain/llms/base';
import { BaseChatMemory } from 'langchain/memory'; import { BaseChatMemory } from 'langchain/memory';
import type { MemoryVariables } from 'langchain/dist/memory/base'; import type { MemoryVariables, OutputValues } from 'langchain/dist/memory/base';
import { BaseRetriever } from 'langchain/schema/retriever'; import { BaseRetriever } from 'langchain/schema/retriever';
import type { FormatInstructionsOptions } from 'langchain/schema/output_parser'; import type { FormatInstructionsOptions } from 'langchain/schema/output_parser';
import { BaseOutputParser, OutputParserException } from 'langchain/schema/output_parser'; import { BaseOutputParser, OutputParserException } from 'langchain/schema/output_parser';
@ -148,35 +148,37 @@ export function logWrapper(
arguments: [values], arguments: [values],
})) as MemoryVariables; })) as MemoryVariables;
const chatHistory = (response?.chat_history as BaseMessage[]) ?? response;
executeFunctions.addOutputData(connectionType, index, [ executeFunctions.addOutputData(connectionType, index, [
[{ json: { action: 'loadMemoryVariables', response } }], [{ json: { action: 'loadMemoryVariables', chatHistory } }],
]); ]);
return response; return response;
}; };
} else if ( } else if (prop === 'saveContext' && 'saveContext' in target) {
prop === 'outputKey' && return async (input: InputValues, output: OutputValues): Promise<MemoryVariables> => {
'outputKey' in target && connectionType = NodeConnectionType.AiMemory;
target.constructor.name === 'BufferWindowMemory'
) {
connectionType = NodeConnectionType.AiMemory;
const { index } = executeFunctions.addInputData(connectionType, [
[{ json: { action: 'chatHistory' } }],
]);
const response = target[prop];
target.chatHistory const { index } = executeFunctions.addInputData(connectionType, [
.getMessages() [{ json: { action: 'saveContext', input, output } }],
.then((messages) => { ]);
executeFunctions.addOutputData(NodeConnectionType.AiMemory, index, [
[{ json: { action: 'chatHistory', chatHistory: messages } }], const response = (await callMethodAsync.call(target, {
]); executeFunctions,
}) connectionType,
.catch((error: Error) => { currentNodeRunIndex: index,
executeFunctions.addOutputData(NodeConnectionType.AiMemory, index, [ method: target[prop],
[{ json: { action: 'chatHistory', error } }], arguments: [input, output],
]); })) as MemoryVariables;
});
return response; const chatHistory = await target.chatHistory.getMessages();
executeFunctions.addOutputData(connectionType, index, [
[{ json: { action: 'saveContext', chatHistory } }],
]);
return response;
};
} }
} }

View file

@ -171,6 +171,10 @@ interface LangChainMessage {
}; };
} }
interface MemoryOutput {
action: string;
chatHistory?: LangChainMessage[];
}
// TODO: // TODO:
// - display additional information like execution time, tokens used, ... // - display additional information like execution time, tokens used, ...
// - display errors better // - display errors better
@ -217,7 +221,10 @@ export default defineComponent({
this.messages = this.getChatMessages(); this.messages = this.getChatMessages();
this.setNode(); this.setNode();
setTimeout(() => this.$refs.inputField?.focus(), 0); setTimeout(() => {
this.scrollToLatestMessage();
this.$refs.inputField?.focus();
}, 0);
}, },
methods: { methods: {
displayExecution(executionId: string) { displayExecution(executionId: string) {
@ -353,32 +360,13 @@ export default defineComponent({
memoryConnection.node, memoryConnection.node,
); );
const memoryOutputData = nodeResultData const memoryOutputData = (nodeResultData ?? [])
?.map( .map(
( (data) => get(data, ['data', NodeConnectionType.AiMemory, 0, 0, 'json']) as MemoryOutput,
data,
): {
action: string;
chatHistory?: unknown[];
response?: {
sessionId?: unknown[];
};
} => get(data, ['data', NodeConnectionType.AiMemory, 0, 0, 'json'])!,
) )
?.find((data) => .find((data) => data.action === 'saveContext');
['chatHistory', 'loadMemoryVariables'].includes(data?.action) ? data : undefined,
);
let chatHistory: LangChainMessage[]; return (memoryOutputData?.chatHistory ?? []).map((message) => {
if (memoryOutputData?.chatHistory) {
chatHistory = memoryOutputData?.chatHistory as LangChainMessage[];
} else if (memoryOutputData?.response) {
chatHistory = memoryOutputData?.response.sessionId as LangChainMessage[];
} else {
return [];
}
return (chatHistory || []).map((message) => {
return { return {
text: message.kwargs.content, text: message.kwargs.content,
sender: last(message.id) === 'HumanMessage' ? 'user' : 'bot', sender: last(message.id) === 'HumanMessage' ? 'user' : 'bot',