feat(AI Transform Node): Reduce payload size (#11965)

Michael Kret 2024-12-09 10:21:43 +02:00 committed by GitHub
parent aece4c497a
commit d8ca8de13a
2 changed files with 202 additions and 1 deletion


@@ -1,8 +1,10 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { generateCodeForAiTransform, reducePayloadSizeOrThrow } from './utils';
import { createPinia, setActivePinia } from 'pinia';
import { generateCodeForPrompt } from '@/api/ai';
import type { AskAiRequest } from '@/types/assistant.types';
import type { Schema } from '@/Interface';

vi.mock('./utils', async () => {
	const actual = await vi.importActual('./utils');
@@ -86,3 +88,69 @@ describe('generateCodeForAiTransform - Retry Tests', () => {
		expect(generateCodeForPrompt).toHaveBeenCalledTimes(1);
	});
});

const mockPayload = () =>
	({
		context: {
			schema: [
				{ nodeName: 'node1', data: 'some data' },
				{ nodeName: 'node2', data: 'other data' },
			],
			inputSchema: {
				schema: {
					value: [
						{ key: 'prop1', value: 'value1' },
						{ key: 'prop2', value: 'value2' },
					],
				},
			},
		},
		question: 'What is node1 and prop1?',
	}) as unknown as AskAiRequest.RequestPayload;

describe('reducePayloadSizeOrThrow', () => {
	it('reduces schema size when tokens exceed the limit', () => {
		const payload = mockPayload();
		const error = new Error('Limit is 100 tokens, but 104 were provided');

		reducePayloadSizeOrThrow(payload, error);

		expect(payload.context.schema.length).toBe(1);
		expect(payload.context.schema[0]).toEqual({ nodeName: 'node1', data: 'some data' });
	});

	it('removes unreferenced properties in input schema', () => {
		const payload = mockPayload();
		const error = new Error('Limit is 100 tokens, but 150 were provided');

		reducePayloadSizeOrThrow(payload, error);

		expect(payload.context.inputSchema.schema.value.length).toBe(1);
		expect((payload.context.inputSchema.schema.value as Schema[])[0].key).toBe('prop1');
	});

	it('removes all parent nodes if needed', () => {
		const payload = mockPayload();
		const error = new Error('Limit is 100 tokens, but 150 were provided');

		payload.question = '';

		reducePayloadSizeOrThrow(payload, error);

		expect(payload.context.schema.length).toBe(0);
	});

	it('throws error if tokens still exceed after reductions', () => {
		const payload = mockPayload();
		const error = new Error('Limit is 100 tokens, but 200 were provided');

		expect(() => reducePayloadSizeOrThrow(payload, error)).toThrowError(error);
	});

	it('throws error if message format is invalid', () => {
		const payload = mockPayload();
		const error = new Error('Invalid token message format');

		expect(() => reducePayloadSizeOrThrow(payload, error)).toThrowError(error);
	});
});

@@ -57,6 +57,134 @@ export function getSchemas() {
	};
}

//------ Reduce payload ------

const estimateNumberOfTokens = (item: unknown, averageTokenLength: number): number => {
	if (typeof item === 'object') {
		return Math.ceil(JSON.stringify(item).length / averageTokenLength);
	}

	return 0;
};
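// Example: estimateNumberOfTokens({ nodeName: 'node2', data: 'other data' }, 4)
// serializes to a 40-character JSON string, so the estimate is Math.ceil(40 / 4) = 10 tokens;
// non-object values are not counted and yield 0.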

const calculateRemainingTokens = (error: Error) => {
	// Expected message format:
	// 'This model's maximum context length is 8192 tokens. However, your messages resulted in 10514 tokens.'
	const tokens = error.message.match(/\d+/g);

	if (!tokens || tokens.length < 2) throw error;

	const maxTokens = parseInt(tokens[0], 10);
	const currentTokens = parseInt(tokens[1], 10);

	return currentTokens - maxTokens;
};
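// Example: for the message above the regex captures ['8192', '10514'] and the
// function returns 10514 - 8192 = 2322; for the test message
// 'Limit is 100 tokens, but 104 were provided' it returns 4.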

const trimParentNodesSchema = (
	payload: AskAiRequest.RequestPayload,
	remainingTokensToReduce: number,
	averageTokenLength: number,
) => {
	// check if parent nodes schema takes more tokens than available
	let parentNodesTokenCount = estimateNumberOfTokens(payload.context.schema, averageTokenLength);

	if (remainingTokensToReduce > parentNodesTokenCount) {
		remainingTokensToReduce -= parentNodesTokenCount;
		payload.context.schema = [];
	}

	// remove parent nodes not referenced in the prompt
	if (payload.context.schema.length) {
		const nodes = [...payload.context.schema];

		for (let nodeIndex = 0; nodeIndex < nodes.length; nodeIndex++) {
			if (payload.question.includes(nodes[nodeIndex].nodeName)) continue;

			const nodeTokens = estimateNumberOfTokens(nodes[nodeIndex], averageTokenLength);
			remainingTokensToReduce -= nodeTokens;
			parentNodesTokenCount -= nodeTokens;
			payload.context.schema.splice(nodeIndex, 1);

			if (remainingTokensToReduce <= 0) break;
		}
	}

	return [remainingTokensToReduce, parentNodesTokenCount];
};
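// Example: with the mock payload from the tests, a question that mentions 'node1'
// but not 'node2' keeps the node1 entry and splices the node2 entry out of
// payload.context.schema.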

const trimInputSchemaProperties = (
	payload: AskAiRequest.RequestPayload,
	remainingTokensToReduce: number,
	averageTokenLength: number,
	parentNodesTokenCount: number,
) => {
	if (remainingTokensToReduce <= 0) return remainingTokensToReduce;

	// remove properties not referenced in the prompt from the input schema
	if (Array.isArray(payload.context.inputSchema.schema.value)) {
		const props = [...payload.context.inputSchema.schema.value];

		for (let index = 0; index < props.length; index++) {
			const key = props[index].key;
			if (key && payload.question.includes(key)) continue;

			const propTokens = estimateNumberOfTokens(props[index], averageTokenLength);
			remainingTokensToReduce -= propTokens;
			payload.context.inputSchema.schema.value.splice(index, 1);

			if (remainingTokensToReduce <= 0) break;
		}
	}

	// if there are still tokens to reduce, remove all parent nodes
	if (remainingTokensToReduce > 0) {
		payload.context.schema = [];
		remainingTokensToReduce -= parentNodesTokenCount;
	}

	return remainingTokensToReduce;
};
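// Example: a question that references 'prop1' but not 'prop2' keeps prop1 in
// inputSchema.schema.value and removes prop2, as in the second test above.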

/**
 * Attempts to reduce the size of the payload to fit within token limits, or throws the given
 * error if that is not possible; the payload is modified in place.
 *
 * @param {AskAiRequest.RequestPayload} payload - The request payload to be trimmed;
 * 'schema' and 'inputSchema.schema' will be modified.
 * @param {Error} error - The error to throw if the token reduction fails.
 * @param {number} [averageTokenLength=4] - The average token length used for estimation.
 * @throws {Error} - Throws the provided error if the payload cannot be reduced sufficiently.
 */
export function reducePayloadSizeOrThrow(
	payload: AskAiRequest.RequestPayload,
	error: Error,
	averageTokenLength = 4,
) {
	try {
		let remainingTokensToReduce = calculateRemainingTokens(error);

		const [remaining, parentNodesTokenCount] = trimParentNodesSchema(
			payload,
			remainingTokensToReduce,
			averageTokenLength,
		);

		remainingTokensToReduce = remaining;

		remainingTokensToReduce = trimInputSchemaProperties(
			payload,
			remainingTokensToReduce,
			averageTokenLength,
			parentNodesTokenCount,
		);

		if (remainingTokensToReduce > 0) throw error;
	} catch (e) {
		throw e;
	}
}

export async function generateCodeForAiTransform(prompt: string, path: string, retries = 1) {
	const schemas = getSchemas();
@@ -83,6 +211,11 @@ export async function generateCodeForAiTransform(prompt: string, path: string, retries = 1) {
			code = generatedCode;
			break;
		} catch (e) {
			if (e.message.includes('maximum context length')) {
				reducePayloadSizeOrThrow(payload, e);
				continue;
			}

			retries--;
			if (!retries) throw e;
		}
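
For orientation, here is a minimal, self-contained sketch of the retry pattern this commit adds around reducePayloadSizeOrThrow: on a 'maximum context length' failure the payload is trimmed in place and the request is retried, otherwise the normal retry counter applies. The generateWithPayloadReduction and requestCode names are illustrative stand-ins (the real call site is generateCodeForAiTransform calling generateCodeForPrompt); only the error-handling flow mirrors the diff.

import { reducePayloadSizeOrThrow } from './utils';
import type { AskAiRequest } from '@/types/assistant.types';

// `requestCode` is an illustrative stand-in for the real API call
// (generateCodeForPrompt in the diff above).
export async function generateWithPayloadReduction(
	payload: AskAiRequest.RequestPayload,
	requestCode: (p: AskAiRequest.RequestPayload) => Promise<{ code: string }>,
	retries = 1,
) {
	while (true) {
		try {
			const { code } = await requestCode(payload);
			return code;
		} catch (e) {
			if (e instanceof Error && e.message.includes('maximum context length')) {
				// Trim the payload in place and retry; reducePayloadSizeOrThrow rethrows
				// the original error if the payload cannot be shrunk enough.
				reducePayloadSizeOrThrow(payload, e);
				continue;
			}

			retries--;
			if (!retries) throw e;
		}
	}
}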