mirror of
https://github.com/n8n-io/n8n.git
synced 2024-12-28 22:19:41 -08:00
feat(AI Transform Node): Reduce payload size (#11965)
This commit is contained in:
parent
aece4c497a
commit
d8ca8de13a
|
@ -1,8 +1,10 @@
|
|||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { generateCodeForAiTransform } from './utils';
|
||||
import { generateCodeForAiTransform, reducePayloadSizeOrThrow } from './utils';
|
||||
import { createPinia, setActivePinia } from 'pinia';
|
||||
import { generateCodeForPrompt } from '@/api/ai';
|
||||
import type { AskAiRequest } from '@/types/assistant.types';
|
||||
import type { Schema } from '@/Interface';
|
||||
|
||||
vi.mock('./utils', async () => {
|
||||
const actual = await vi.importActual('./utils');
|
||||
|
@ -86,3 +88,69 @@ describe('generateCodeForAiTransform - Retry Tests', () => {
|
|||
expect(generateCodeForPrompt).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
/**
 * Builds a fresh, minimal request payload for the tests below: two parent nodes,
 * two input-schema properties, and a question that mentions only `node1` and `prop1`.
 * A new object graph is created on every call so tests can mutate it freely.
 */
const mockPayload = () => {
	// Parent node schemas; only 'node1' is mentioned in the question.
	const parentNodes = [
		{ nodeName: 'node1', data: 'some data' },
		{ nodeName: 'node2', data: 'other data' },
	];

	// Input schema properties; only 'prop1' is mentioned in the question.
	const inputProperties = [
		{ key: 'prop1', value: 'value1' },
		{ key: 'prop2', value: 'value2' },
	];

	const partialPayload = {
		context: {
			schema: parentNodes,
			inputSchema: {
				schema: {
					value: inputProperties,
				},
			},
		},
		question: 'What is node1 and prop1?',
	};

	// The fixture intentionally contains only the fields the tests touch.
	return partialPayload as unknown as AskAiRequest.RequestPayload;
};
|
||||
|
||||
// Unit tests for reducePayloadSizeOrThrow: every case starts from a fresh mockPayload()
// and an Error whose message carries the token limit followed by the provided token count.
describe('reducePayloadSizeOrThrow', () => {
	it('reduces schema size when tokens exceed the limit', () => {
		// Arrange: only 4 tokens over the limit.
		const limitError = new Error('Limit is 100 tokens, but 104 were provided');
		const request = mockPayload();

		// Act
		reducePayloadSizeOrThrow(request, limitError);

		// Assert: the node mentioned in the question is the one left.
		expect(request.context.schema.length).toBe(1);
		expect(request.context.schema[0]).toEqual({ nodeName: 'node1', data: 'some data' });
	});

	it('removes unreferenced properties in input schema', () => {
		// Arrange: 50 tokens over the limit.
		const limitError = new Error('Limit is 100 tokens, but 150 were provided');
		const request = mockPayload();

		// Act
		reducePayloadSizeOrThrow(request, limitError);

		// Assert: only the property mentioned in the question is left.
		const remainingProps = request.context.inputSchema.schema.value;
		expect(remainingProps.length).toBe(1);
		expect((request.context.inputSchema.schema.value as Schema[])[0].key).toBe('prop1');
	});

	it('removes all parent nodes if needed', () => {
		// Arrange: an empty question mentions no node at all.
		const limitError = new Error('Limit is 100 tokens, but 150 were provided');
		const request = mockPayload();
		request.question = '';

		// Act
		reducePayloadSizeOrThrow(request, limitError);

		// Assert
		expect(request.context.schema.length).toBe(0);
	});

	it('throws error if tokens still exceed after reductions', () => {
		const limitError = new Error('Limit is 100 tokens, but 200 were provided');
		const request = mockPayload();

		const act = () => reducePayloadSizeOrThrow(request, limitError);

		expect(act).toThrowError(limitError);
	});

	it('throws error if message format is invalid', () => {
		// The message contains no numbers, so no token counts can be extracted.
		const malformedError = new Error('Invalid token message format');
		const request = mockPayload();

		const act = () => reducePayloadSizeOrThrow(request, malformedError);

		expect(act).toThrowError(malformedError);
	});
});
|
||||
|
|
|
@ -57,6 +57,134 @@ export function getSchemas() {
|
|||
};
|
||||
}
|
||||
|
||||
//------ Reduce payload ------
|
||||
|
||||
/**
 * Gives a rough token estimate for a value based on the length of its JSON form.
 *
 * @param item - Value to measure. Anything whose `typeof` is not `'object'` counts as 0 tokens.
 * @param averageTokenLength - Number of serialized characters assumed to make up one token.
 * @returns The serialized length divided by `averageTokenLength`, rounded up; 0 for non-objects.
 */
const estimateNumberOfTokens = (item: unknown, averageTokenLength: number): number => {
	// Guard clause: only `typeof === 'object'` values (objects, arrays, null) are measured.
	if (typeof item !== 'object') return 0;

	const serializedLength = JSON.stringify(item).length;
	const estimatedTokens = serializedLength / averageTokenLength;

	return Math.ceil(estimatedTokens);
};
|
||||
|
||||
/**
 * Works out by how many tokens a request exceeded the limit, using the numbers
 * found in the error message.
 *
 * Expected message format:
 * 'This model's maximum context length is 8192 tokens. However, your messages resulted in 10514 tokens.'
 *
 * The first number in the message is read as the limit and the second as the
 * number of tokens that were sent.
 *
 * @param error - Error whose message contains the two token counts.
 * @returns Tokens sent minus the token limit.
 * @throws The given `error` itself when the message has fewer than two numbers.
 */
const calculateRemainingTokens = (error: Error) => {
	const numbersInMessage = error.message.match(/\d+/g);

	const hasBothCounts = numbersInMessage !== null && numbersInMessage.length >= 2;
	if (!hasBothCounts) throw error;

	const [maxTokensText, currentTokensText] = numbersInMessage;

	return parseInt(currentTokensText, 10) - parseInt(maxTokensText, 10);
};
|
||||
|
||||
/**
 * Shrinks `payload.context.schema` (the parent nodes' schemas) in place.
 *
 * 1. If the tokens still to be removed exceed the estimated size of the whole parent
 *    schema, the schema is emptied and its estimate is deducted from the remaining count.
 * 2. Otherwise, nodes whose `nodeName` does not occur in `payload.question` are removed
 *    front to back until the remaining count drops to zero or below.
 *
 * @param payload - Request payload; `context.schema` is mutated.
 * @param remainingTokensToReduce - Estimated number of tokens that still have to be removed.
 * @param averageTokenLength - Characters per token used for the estimation.
 * @returns A tuple `[remainingTokensToReduce, parentNodesTokenCount]`: the tokens still to be
 *   removed, and the parent schema estimate reduced by every node removed in step 2.
 */
const trimParentNodesSchema = (
	payload: AskAiRequest.RequestPayload,
	remainingTokensToReduce: number,
	averageTokenLength: number,
): [number, number] => {
	// check if parent nodes schema takes more tokens than available
	let parentNodesTokenCount = estimateNumberOfTokens(payload.context.schema, averageTokenLength);

	if (remainingTokensToReduce > parentNodesTokenCount) {
		remainingTokensToReduce -= parentNodesTokenCount;
		payload.context.schema = [];
		// NOTE(review): parentNodesTokenCount is not reset here, so the value returned below
		// is still the pre-removal estimate. Kept as-is because reducePayloadSizeOrThrow's
		// existing accounting relies on it - confirm whether that is intended.
	}

	// remove parent nodes not referenced in the prompt
	const schema = payload.context.schema;
	let nodeIndex = 0;

	// Walk the live array: after a splice the next candidate slides into `nodeIndex`,
	// so the index only advances when the current node is kept. Indexing a snapshot copy
	// while splicing the live array would remove the wrong elements after the first removal.
	while (nodeIndex < schema.length) {
		const node = schema[nodeIndex];

		if (payload.question.includes(node.nodeName)) {
			nodeIndex++;
			continue;
		}

		const nodeTokens = estimateNumberOfTokens(node, averageTokenLength);
		remainingTokensToReduce -= nodeTokens;
		parentNodesTokenCount -= nodeTokens;
		schema.splice(nodeIndex, 1);

		if (remainingTokensToReduce <= 0) break;
	}

	return [remainingTokensToReduce, parentNodesTokenCount];
};
|
||||
|
||||
/**
 * Shrinks `payload.context.inputSchema.schema.value` in place and, as a last resort,
 * drops all parent node schemas.
 *
 * Properties whose `key` does not occur in `payload.question` (or that have no key) are
 * removed front to back until the remaining count drops to zero or below. If tokens still
 * remain after that, `payload.context.schema` is emptied and `parentNodesTokenCount` is
 * deducted from the remaining count.
 *
 * @param payload - Request payload; `context.inputSchema.schema.value` and possibly
 *   `context.schema` are mutated.
 * @param remainingTokensToReduce - Estimated tokens that still have to be removed; when it
 *   is already zero or negative the payload is left untouched.
 * @param averageTokenLength - Characters per token used for the estimation.
 * @param parentNodesTokenCount - Token estimate credited when the parent schemas are dropped.
 * @returns The tokens still to be removed after trimming (zero or negative means enough
 *   was removed).
 */
const trimInputSchemaProperties = (
	payload: AskAiRequest.RequestPayload,
	remainingTokensToReduce: number,
	averageTokenLength: number,
	parentNodesTokenCount: number,
): number => {
	if (remainingTokensToReduce <= 0) return remainingTokensToReduce;

	// remove properties not referenced in the prompt from the input schema
	const props = payload.context.inputSchema.schema.value;

	if (Array.isArray(props)) {
		let index = 0;

		// Walk the live array: after a splice the next candidate slides into `index`,
		// so the index only advances when the current property is kept. Indexing a snapshot
		// copy while splicing the live array would remove the wrong elements after the
		// first removal.
		while (index < props.length) {
			const prop = props[index];
			const key = prop.key;

			if (key && payload.question.includes(key)) {
				index++;
				continue;
			}

			remainingTokensToReduce -= estimateNumberOfTokens(prop, averageTokenLength);
			props.splice(index, 1);

			if (remainingTokensToReduce <= 0) break;
		}
	}

	// if tokensToReduce is still remaining, remove all parent nodes
	if (remainingTokensToReduce > 0) {
		payload.context.schema = [];
		remainingTokensToReduce -= parentNodesTokenCount;
	}

	return remainingTokensToReduce;
};
|
||||
|
||||
/**
 * Attempts to reduce the size of the payload to fit within token limits, or throws if that
 * is not possible. The payload is modified in place.
 *
 * The number of tokens to remove is derived from the numbers in `error.message`. Parent node
 * schemas are trimmed first, then input schema properties.
 *
 * @param payload - The request payload to be trimmed; `context.schema` and
 *   `context.inputSchema.schema` will be modified.
 * @param error - The token-limit error; its message supplies the token counts, and it is the
 *   error thrown if the reduction fails.
 * @param averageTokenLength - The average token length (in characters) used for estimation.
 *   Defaults to 4.
 * @throws The provided `error` if its message cannot be parsed or the payload cannot be
 *   reduced sufficiently.
 */
export function reducePayloadSizeOrThrow(
	payload: AskAiRequest.RequestPayload,
	error: Error,
	averageTokenLength = 4,
) {
	// No try/catch here: a handler that only rethrows changes nothing, so errors thrown by
	// the helpers simply propagate to the caller.
	const tokensOverLimit = calculateRemainingTokens(error);

	const [remainingAfterParentNodes, parentNodesTokenCount] = trimParentNodesSchema(
		payload,
		tokensOverLimit,
		averageTokenLength,
	);

	const remainingTokensToReduce = trimInputSchemaProperties(
		payload,
		remainingAfterParentNodes,
		averageTokenLength,
		parentNodesTokenCount,
	);

	if (remainingTokensToReduce > 0) throw error;
}
|
||||
|
||||
export async function generateCodeForAiTransform(prompt: string, path: string, retries = 1) {
|
||||
const schemas = getSchemas();
|
||||
|
||||
|
@ -83,6 +211,11 @@ export async function generateCodeForAiTransform(prompt: string, path: string, r
|
|||
code = generatedCode;
|
||||
break;
|
||||
} catch (e) {
|
||||
if (e.message.includes('maximum context length')) {
|
||||
reducePayloadSizeOrThrow(payload, e);
|
||||
continue;
|
||||
}
|
||||
|
||||
retries--;
|
||||
if (!retries) throw e;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue