From d8ca8de13a4cbb856696873bdb56c66b12a5b027 Mon Sep 17 00:00:00 2001
From: Michael Kret <88898367+michael-radency@users.noreply.github.com>
Date: Mon, 9 Dec 2024 10:21:43 +0200
Subject: [PATCH] feat(AI Transform Node): Reduce payload size (#11965)

---
Review notes (patch amended during review; blob hashes on the utils.ts
"index" line refer to the pre-review content):
- Line structure restored; indentation is assumed to be tabs.
- trimParentNodesSchema / trimInputSchemaProperties: the loops indexed a
  copy of the array while splicing the live array with the same index, so
  after the first removal a later splice could drop an entry referenced in
  the prompt (or drop nothing). The loops now walk the live array and only
  advance the index past entries that are kept.
- reducePayloadSizeOrThrow: removed a try/catch that only rethrew.
- Left as-is but flagged in a comment: parentNodesTokenCount is not reset
  when all parent nodes are cleared, so it is subtracted twice.

 .../components/ButtonParameter/utils.test.ts  |  70 ++++++++-
 .../src/components/ButtonParameter/utils.ts   | 140 +++++++++++++++++++
 2 files changed, 209 insertions(+), 1 deletion(-)

diff --git a/packages/editor-ui/src/components/ButtonParameter/utils.test.ts b/packages/editor-ui/src/components/ButtonParameter/utils.test.ts
index df7e13d477..7453d41fb5 100644
--- a/packages/editor-ui/src/components/ButtonParameter/utils.test.ts
+++ b/packages/editor-ui/src/components/ButtonParameter/utils.test.ts
@@ -1,8 +1,10 @@
 /* eslint-disable @typescript-eslint/no-explicit-any */
 import { describe, it, expect, vi, beforeEach } from 'vitest';
-import { generateCodeForAiTransform } from './utils';
+import { generateCodeForAiTransform, reducePayloadSizeOrThrow } from './utils';
 import { createPinia, setActivePinia } from 'pinia';
 import { generateCodeForPrompt } from '@/api/ai';
+import type { AskAiRequest } from '@/types/assistant.types';
+import type { Schema } from '@/Interface';
 
 vi.mock('./utils', async () => {
 	const actual = await vi.importActual('./utils');
@@ -86,3 +88,69 @@ describe('generateCodeForAiTransform - Retry Tests', () => {
 		expect(generateCodeForPrompt).toHaveBeenCalledTimes(1);
 	});
 });
+
+const mockPayload = () =>
+	({
+		context: {
+			schema: [
+				{ nodeName: 'node1', data: 'some data' },
+				{ nodeName: 'node2', data: 'other data' },
+			],
+			inputSchema: {
+				schema: {
+					value: [
+						{ key: 'prop1', value: 'value1' },
+						{ key: 'prop2', value: 'value2' },
+					],
+				},
+			},
+		},
+		question: 'What is node1 and prop1?',
+	}) as unknown as AskAiRequest.RequestPayload;
+
+describe('reducePayloadSizeOrThrow', () => {
+	it('reduces schema size when tokens exceed the limit', () => {
+		const payload = mockPayload();
+		const error = new Error('Limit is 100 tokens, but 104 were provided');
+
+		reducePayloadSizeOrThrow(payload, error);
+
+		expect(payload.context.schema.length).toBe(1);
+		expect(payload.context.schema[0]).toEqual({ nodeName: 'node1', data: 'some data' });
+	});
+
+	it('removes unreferenced properties in input schema', () => {
+		const payload = mockPayload();
+		const error = new Error('Limit is 100 tokens, but 150 were provided');
+
+		reducePayloadSizeOrThrow(payload, error);
+
+		expect(payload.context.inputSchema.schema.value.length).toBe(1);
+		expect((payload.context.inputSchema.schema.value as Schema[])[0].key).toBe('prop1');
+	});
+
+	it('removes all parent nodes if needed', () => {
+		const payload = mockPayload();
+		const error = new Error('Limit is 100 tokens, but 150 were provided');
+
+		payload.question = '';
+
+		reducePayloadSizeOrThrow(payload, error);
+
+		expect(payload.context.schema.length).toBe(0);
+	});
+
+	it('throws error if tokens still exceed after reductions', () => {
+		const payload = mockPayload();
+		const error = new Error('Limit is 100 tokens, but 200 were provided');
+
+		expect(() => reducePayloadSizeOrThrow(payload, error)).toThrowError(error);
+	});
+
+	it('throws error if message format is invalid', () => {
+		const payload = mockPayload();
+		const error = new Error('Invalid token message format');
+
+		expect(() => reducePayloadSizeOrThrow(payload, error)).toThrowError(error);
+	});
+});
diff --git a/packages/editor-ui/src/components/ButtonParameter/utils.ts b/packages/editor-ui/src/components/ButtonParameter/utils.ts
index b95846975f..1044477bbb 100644
--- a/packages/editor-ui/src/components/ButtonParameter/utils.ts
+++ b/packages/editor-ui/src/components/ButtonParameter/utils.ts
@@ -57,6 +57,141 @@ export function getSchemas() {
 	};
 }
 
+//------ Reduce payload ------
+
+const estimateNumberOfTokens = (item: unknown, averageTokenLength: number): number => {
+	if (typeof item === 'object') {
+		return Math.ceil(JSON.stringify(item).length / averageTokenLength);
+	}
+
+	return 0;
+};
+
+const calculateRemainingTokens = (error: Error) => {
+	// Expected message format:
+	//'This model's maximum context length is 8192 tokens. However, your messages resulted in 10514 tokens.'
+	const tokens = error.message.match(/\d+/g);
+
+	if (!tokens || tokens.length < 2) throw error;
+
+	const maxTokens = parseInt(tokens[0], 10);
+	const currentTokens = parseInt(tokens[1], 10);
+
+	return currentTokens - maxTokens;
+};
+
+const trimParentNodesSchema = (
+	payload: AskAiRequest.RequestPayload,
+	remainingTokensToReduce: number,
+	averageTokenLength: number,
+) => {
+	//check if parent nodes schema takes more tokens than available
+	let parentNodesTokenCount = estimateNumberOfTokens(payload.context.schema, averageTokenLength);
+
+	if (remainingTokensToReduce > parentNodesTokenCount) {
+		remainingTokensToReduce -= parentNodesTokenCount;
+		payload.context.schema = [];
+		//NOTE(review): parentNodesTokenCount is not reset here, so trimInputSchemaProperties
+		//subtracts it a second time; the accompanying tests rely on this, confirm before changing
+	}
+
+	//remove parent nodes not referenced in the prompt
+	//(the index only advances past kept nodes, because splice shifts the next node into place)
+	let nodeIndex = 0;
+
+	while (nodeIndex < payload.context.schema.length) {
+		const node = payload.context.schema[nodeIndex];
+
+		if (payload.question.includes(node.nodeName)) {
+			nodeIndex++;
+			continue;
+		}
+
+		const nodeTokens = estimateNumberOfTokens(node, averageTokenLength);
+		remainingTokensToReduce -= nodeTokens;
+		parentNodesTokenCount -= nodeTokens;
+		payload.context.schema.splice(nodeIndex, 1);
+
+		if (remainingTokensToReduce <= 0) break;
+	}
+
+	return [remainingTokensToReduce, parentNodesTokenCount];
+};
+
+const trimInputSchemaProperties = (
+	payload: AskAiRequest.RequestPayload,
+	remainingTokensToReduce: number,
+	averageTokenLength: number,
+	parentNodesTokenCount: number,
+) => {
+	if (remainingTokensToReduce <= 0) return remainingTokensToReduce;
+
+	//remove properties not referenced in the prompt from the input schema
+	if (Array.isArray(payload.context.inputSchema.schema.value)) {
+		const props = payload.context.inputSchema.schema.value;
+		let index = 0;
+
+		//the index only advances past kept properties, because splice shifts the next one into place
+		while (index < props.length) {
+			const key = props[index].key;
+
+			if (key && payload.question.includes(key)) {
+				index++;
+				continue;
+			}
+
+			const propTokens = estimateNumberOfTokens(props[index], averageTokenLength);
+			remainingTokensToReduce -= propTokens;
+			props.splice(index, 1);
+
+			if (remainingTokensToReduce <= 0) break;
+		}
+	}
+
+	//if tokensToReduce is still remaining, remove all parent nodes
+	if (remainingTokensToReduce > 0) {
+		payload.context.schema = [];
+		remainingTokensToReduce -= parentNodesTokenCount;
+	}
+
+	return remainingTokensToReduce;
+};
+
+/**
+ * Attempts to reduce the size of the payload to fit within token limits or throws an error if unsuccessful,
+ * the payload is modified in place.
+ *
+ * @param {AskAiRequest.RequestPayload} payload - The request payload to be trimmed,
+ * 'schema' and 'inputSchema.schema' will be modified.
+ * @param {Error} error - The error to throw if the token reduction fails.
+ * @param {number} [averageTokenLength=4] - The average token length used for estimation.
+ * @throws {Error} - Throws the provided error if the payload cannot be reduced sufficiently.
+ */
+export function reducePayloadSizeOrThrow(
+	payload: AskAiRequest.RequestPayload,
+	error: Error,
+	averageTokenLength = 4,
+) {
+	let remainingTokensToReduce = calculateRemainingTokens(error);
+
+	const [remaining, parentNodesTokenCount] = trimParentNodesSchema(
+		payload,
+		remainingTokensToReduce,
+		averageTokenLength,
+	);
+
+	remainingTokensToReduce = remaining;
+
+	remainingTokensToReduce = trimInputSchemaProperties(
+		payload,
+		remainingTokensToReduce,
+		averageTokenLength,
+		parentNodesTokenCount,
+	);
+
+	if (remainingTokensToReduce > 0) throw error;
+}
+
 export async function generateCodeForAiTransform(prompt: string, path: string, retries = 1) {
 	const schemas = getSchemas();
 
@@ -83,6 +218,11 @@ export async function generateCodeForAiTransform(prompt: string, path: string, r
 			code = generatedCode;
 			break;
 		} catch (e) {
+			if (e.message.includes('maximum context length')) {
+				reducePayloadSizeOrThrow(payload, e);
+				continue;
+			}
+
 			retries--;
 			if (!retries) throw e;
 		}