mirror of
https://github.com/n8n-io/n8n.git
synced 2024-12-31 15:37:26 -08:00
feat(AI Transform Node): Reduce payload size (#11965)
This commit is contained in:
parent
aece4c497a
commit
d8ca8de13a
|
@ -1,8 +1,10 @@
|
||||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||||
import { generateCodeForAiTransform } from './utils';
|
import { generateCodeForAiTransform, reducePayloadSizeOrThrow } from './utils';
|
||||||
import { createPinia, setActivePinia } from 'pinia';
|
import { createPinia, setActivePinia } from 'pinia';
|
||||||
import { generateCodeForPrompt } from '@/api/ai';
|
import { generateCodeForPrompt } from '@/api/ai';
|
||||||
|
import type { AskAiRequest } from '@/types/assistant.types';
|
||||||
|
import type { Schema } from '@/Interface';
|
||||||
|
|
||||||
vi.mock('./utils', async () => {
|
vi.mock('./utils', async () => {
|
||||||
const actual = await vi.importActual('./utils');
|
const actual = await vi.importActual('./utils');
|
||||||
|
@ -86,3 +88,69 @@ describe('generateCodeForAiTransform - Retry Tests', () => {
|
||||||
expect(generateCodeForPrompt).toHaveBeenCalledTimes(1);
|
expect(generateCodeForPrompt).toHaveBeenCalledTimes(1);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const mockPayload = () =>
|
||||||
|
({
|
||||||
|
context: {
|
||||||
|
schema: [
|
||||||
|
{ nodeName: 'node1', data: 'some data' },
|
||||||
|
{ nodeName: 'node2', data: 'other data' },
|
||||||
|
],
|
||||||
|
inputSchema: {
|
||||||
|
schema: {
|
||||||
|
value: [
|
||||||
|
{ key: 'prop1', value: 'value1' },
|
||||||
|
{ key: 'prop2', value: 'value2' },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
question: 'What is node1 and prop1?',
|
||||||
|
}) as unknown as AskAiRequest.RequestPayload;
|
||||||
|
|
||||||
|
describe('reducePayloadSizeOrThrow', () => {
|
||||||
|
it('reduces schema size when tokens exceed the limit', () => {
|
||||||
|
const payload = mockPayload();
|
||||||
|
const error = new Error('Limit is 100 tokens, but 104 were provided');
|
||||||
|
|
||||||
|
reducePayloadSizeOrThrow(payload, error);
|
||||||
|
|
||||||
|
expect(payload.context.schema.length).toBe(1);
|
||||||
|
expect(payload.context.schema[0]).toEqual({ nodeName: 'node1', data: 'some data' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('removes unreferenced properties in input schema', () => {
|
||||||
|
const payload = mockPayload();
|
||||||
|
const error = new Error('Limit is 100 tokens, but 150 were provided');
|
||||||
|
|
||||||
|
reducePayloadSizeOrThrow(payload, error);
|
||||||
|
|
||||||
|
expect(payload.context.inputSchema.schema.value.length).toBe(1);
|
||||||
|
expect((payload.context.inputSchema.schema.value as Schema[])[0].key).toBe('prop1');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('removes all parent nodes if needed', () => {
|
||||||
|
const payload = mockPayload();
|
||||||
|
const error = new Error('Limit is 100 tokens, but 150 were provided');
|
||||||
|
|
||||||
|
payload.question = '';
|
||||||
|
|
||||||
|
reducePayloadSizeOrThrow(payload, error);
|
||||||
|
|
||||||
|
expect(payload.context.schema.length).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('throws error if tokens still exceed after reductions', () => {
|
||||||
|
const payload = mockPayload();
|
||||||
|
const error = new Error('Limit is 100 tokens, but 200 were provided');
|
||||||
|
|
||||||
|
expect(() => reducePayloadSizeOrThrow(payload, error)).toThrowError(error);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('throws error if message format is invalid', () => {
|
||||||
|
const payload = mockPayload();
|
||||||
|
const error = new Error('Invalid token message format');
|
||||||
|
|
||||||
|
expect(() => reducePayloadSizeOrThrow(payload, error)).toThrowError(error);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
|
@ -57,6 +57,134 @@ export function getSchemas() {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//------ Reduce payload ------
|
||||||
|
|
||||||
|
const estimateNumberOfTokens = (item: unknown, averageTokenLength: number): number => {
|
||||||
|
if (typeof item === 'object') {
|
||||||
|
return Math.ceil(JSON.stringify(item).length / averageTokenLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
const calculateRemainingTokens = (error: Error) => {
|
||||||
|
// Expected message format:
|
||||||
|
//'This model's maximum context length is 8192 tokens. However, your messages resulted in 10514 tokens.'
|
||||||
|
const tokens = error.message.match(/\d+/g);
|
||||||
|
|
||||||
|
if (!tokens || tokens.length < 2) throw error;
|
||||||
|
|
||||||
|
const maxTokens = parseInt(tokens[0], 10);
|
||||||
|
const currentTokens = parseInt(tokens[1], 10);
|
||||||
|
|
||||||
|
return currentTokens - maxTokens;
|
||||||
|
};
|
||||||
|
|
||||||
|
const trimParentNodesSchema = (
|
||||||
|
payload: AskAiRequest.RequestPayload,
|
||||||
|
remainingTokensToReduce: number,
|
||||||
|
averageTokenLength: number,
|
||||||
|
) => {
|
||||||
|
//check if parent nodes schema takes more tokens than available
|
||||||
|
let parentNodesTokenCount = estimateNumberOfTokens(payload.context.schema, averageTokenLength);
|
||||||
|
|
||||||
|
if (remainingTokensToReduce > parentNodesTokenCount) {
|
||||||
|
remainingTokensToReduce -= parentNodesTokenCount;
|
||||||
|
payload.context.schema = [];
|
||||||
|
}
|
||||||
|
|
||||||
|
//remove parent nodes not referenced in the prompt
|
||||||
|
if (payload.context.schema.length) {
|
||||||
|
const nodes = [...payload.context.schema];
|
||||||
|
|
||||||
|
for (let nodeIndex = 0; nodeIndex < nodes.length; nodeIndex++) {
|
||||||
|
if (payload.question.includes(nodes[nodeIndex].nodeName)) continue;
|
||||||
|
|
||||||
|
const nodeTokens = estimateNumberOfTokens(nodes[nodeIndex], averageTokenLength);
|
||||||
|
remainingTokensToReduce -= nodeTokens;
|
||||||
|
parentNodesTokenCount -= nodeTokens;
|
||||||
|
payload.context.schema.splice(nodeIndex, 1);
|
||||||
|
|
||||||
|
if (remainingTokensToReduce <= 0) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return [remainingTokensToReduce, parentNodesTokenCount];
|
||||||
|
};
|
||||||
|
|
||||||
|
const trimInputSchemaProperties = (
|
||||||
|
payload: AskAiRequest.RequestPayload,
|
||||||
|
remainingTokensToReduce: number,
|
||||||
|
averageTokenLength: number,
|
||||||
|
parentNodesTokenCount: number,
|
||||||
|
) => {
|
||||||
|
if (remainingTokensToReduce <= 0) return remainingTokensToReduce;
|
||||||
|
|
||||||
|
//remove properties not referenced in the prompt from the input schema
|
||||||
|
if (Array.isArray(payload.context.inputSchema.schema.value)) {
|
||||||
|
const props = [...payload.context.inputSchema.schema.value];
|
||||||
|
|
||||||
|
for (let index = 0; index < props.length; index++) {
|
||||||
|
const key = props[index].key;
|
||||||
|
|
||||||
|
if (key && payload.question.includes(key)) continue;
|
||||||
|
|
||||||
|
const propTokens = estimateNumberOfTokens(props[index], averageTokenLength);
|
||||||
|
remainingTokensToReduce -= propTokens;
|
||||||
|
payload.context.inputSchema.schema.value.splice(index, 1);
|
||||||
|
|
||||||
|
if (remainingTokensToReduce <= 0) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//if tokensToReduce is still remaining, remove all parent nodes
|
||||||
|
if (remainingTokensToReduce > 0) {
|
||||||
|
payload.context.schema = [];
|
||||||
|
remainingTokensToReduce -= parentNodesTokenCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
return remainingTokensToReduce;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attempts to reduce the size of the payload to fit within token limits or throws an error if unsuccessful,
|
||||||
|
* payload would be modified in place
|
||||||
|
*
|
||||||
|
* @param {AskAiRequest.RequestPayload} payload - The request payload to be trimmed,
|
||||||
|
* 'schema' and 'inputSchema.schema' will be modified.
|
||||||
|
* @param {Error} error - The error to throw if the token reduction fails.
|
||||||
|
* @param {number} [averageTokenLength=4] - The average token length used for estimation.
|
||||||
|
* @throws {Error} - Throws the provided error if the payload cannot be reduced sufficiently.
|
||||||
|
*/
|
||||||
|
export function reducePayloadSizeOrThrow(
|
||||||
|
payload: AskAiRequest.RequestPayload,
|
||||||
|
error: Error,
|
||||||
|
averageTokenLength = 4,
|
||||||
|
) {
|
||||||
|
try {
|
||||||
|
let remainingTokensToReduce = calculateRemainingTokens(error);
|
||||||
|
|
||||||
|
const [remaining, parentNodesTokenCount] = trimParentNodesSchema(
|
||||||
|
payload,
|
||||||
|
remainingTokensToReduce,
|
||||||
|
averageTokenLength,
|
||||||
|
);
|
||||||
|
|
||||||
|
remainingTokensToReduce = remaining;
|
||||||
|
|
||||||
|
remainingTokensToReduce = trimInputSchemaProperties(
|
||||||
|
payload,
|
||||||
|
remainingTokensToReduce,
|
||||||
|
averageTokenLength,
|
||||||
|
parentNodesTokenCount,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (remainingTokensToReduce > 0) throw error;
|
||||||
|
} catch (e) {
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export async function generateCodeForAiTransform(prompt: string, path: string, retries = 1) {
|
export async function generateCodeForAiTransform(prompt: string, path: string, retries = 1) {
|
||||||
const schemas = getSchemas();
|
const schemas = getSchemas();
|
||||||
|
|
||||||
|
@ -83,6 +211,11 @@ export async function generateCodeForAiTransform(prompt: string, path: string, r
|
||||||
code = generatedCode;
|
code = generatedCode;
|
||||||
break;
|
break;
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
|
if (e.message.includes('maximum context length')) {
|
||||||
|
reducePayloadSizeOrThrow(payload, e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
retries--;
|
retries--;
|
||||||
if (!retries) throw e;
|
if (!retries) throw e;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue