refactor(core): Update AI-Assistant backend code to use DTOs and injectable config (no-changelog) (#12373)

This commit is contained in:
कारतोफ्फेलस्क्रिप्ट™ 2024-12-26 15:31:19 +01:00 committed by GitHub
parent f754b22a3f
commit 1d5e891a0d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 679 additions and 40 deletions

View file

@ -0,0 +1,36 @@
import { AiApplySuggestionRequestDto } from '../ai-apply-suggestion-request.dto';
describe('AiApplySuggestionRequestDto', () => {
it('should validate a valid suggestion application request', () => {
const validRequest = {
sessionId: 'session-123',
suggestionId: 'suggestion-456',
};
const result = AiApplySuggestionRequestDto.safeParse(validRequest);
expect(result.success).toBe(true);
});
it('should fail if sessionId is missing', () => {
const invalidRequest = {
suggestionId: 'suggestion-456',
};
const result = AiApplySuggestionRequestDto.safeParse(invalidRequest);
expect(result.success).toBe(false);
expect(result.error?.issues[0].path).toEqual(['sessionId']);
});
it('should fail if suggestionId is missing', () => {
const invalidRequest = {
sessionId: 'session-123',
};
const result = AiApplySuggestionRequestDto.safeParse(invalidRequest);
expect(result.success).toBe(false);
expect(result.error?.issues[0].path).toEqual(['suggestionId']);
});
});

View file

@ -0,0 +1,252 @@
import { AiAskRequestDto } from '../ai-ask-request.dto';
describe('AiAskRequestDto', () => {
const validRequest = {
question: 'How can I improve this workflow?',
context: {
schema: [
{
nodeName: 'TestNode',
schema: {
type: 'string',
key: 'testKey',
value: 'testValue',
path: '/test/path',
},
},
],
inputSchema: {
nodeName: 'InputNode',
schema: {
type: 'object',
key: 'inputKey',
value: [
{
type: 'string',
key: 'nestedKey',
value: 'nestedValue',
path: '/nested/path',
},
],
path: '/input/path',
},
},
pushRef: 'push-123',
ndvPushRef: 'ndv-push-456',
},
forNode: 'TestWorkflowNode',
};
it('should validate a valid AI ask request', () => {
const result = AiAskRequestDto.safeParse(validRequest);
expect(result.success).toBe(true);
});
it('should fail if question is missing', () => {
const invalidRequest = {
...validRequest,
question: undefined,
};
const result = AiAskRequestDto.safeParse(invalidRequest);
expect(result.success).toBe(false);
expect(result.error?.issues[0].path).toEqual(['question']);
});
it('should fail if context is invalid', () => {
const invalidRequest = {
...validRequest,
context: {
...validRequest.context,
schema: [
{
nodeName: 'TestNode',
schema: {
type: 'invalid-type', // Invalid type
value: 'testValue',
path: '/test/path',
},
},
],
},
};
const result = AiAskRequestDto.safeParse(invalidRequest);
expect(result.success).toBe(false);
});
it('should fail if forNode is missing', () => {
const invalidRequest = {
...validRequest,
forNode: undefined,
};
const result = AiAskRequestDto.safeParse(invalidRequest);
expect(result.success).toBe(false);
expect(result.error?.issues[0].path).toEqual(['forNode']);
});
it('should validate all possible schema types', () => {
const allTypesRequest = {
question: 'Test all possible types',
context: {
schema: [
{
nodeName: 'AllTypesNode',
schema: {
type: 'object',
key: 'typesRoot',
value: [
{ type: 'string', key: 'stringType', value: 'string', path: '/types/string' },
{ type: 'number', key: 'numberType', value: 'number', path: '/types/number' },
{ type: 'boolean', key: 'booleanType', value: 'boolean', path: '/types/boolean' },
{ type: 'bigint', key: 'bigintType', value: 'bigint', path: '/types/bigint' },
{ type: 'symbol', key: 'symbolType', value: 'symbol', path: '/types/symbol' },
{ type: 'array', key: 'arrayType', value: [], path: '/types/array' },
{ type: 'object', key: 'objectType', value: [], path: '/types/object' },
{
type: 'function',
key: 'functionType',
value: 'function',
path: '/types/function',
},
{ type: 'null', key: 'nullType', value: 'null', path: '/types/null' },
{
type: 'undefined',
key: 'undefinedType',
value: 'undefined',
path: '/types/undefined',
},
],
path: '/types/root',
},
},
],
inputSchema: {
nodeName: 'InputNode',
schema: {
type: 'object',
key: 'simpleInput',
value: [
{
type: 'string',
key: 'simpleKey',
value: 'simpleValue',
path: '/simple/path',
},
],
path: '/simple/input/path',
},
},
pushRef: 'push-types-123',
ndvPushRef: 'ndv-push-types-456',
},
forNode: 'TypeCheckNode',
};
const result = AiAskRequestDto.safeParse(allTypesRequest);
expect(result.success).toBe(true);
});
it('should fail with invalid type', () => {
const invalidTypeRequest = {
question: 'Test invalid type',
context: {
schema: [
{
nodeName: 'InvalidTypeNode',
schema: {
type: 'invalid-type', // This should fail
key: 'invalidKey',
value: 'invalidValue',
path: '/invalid/path',
},
},
],
inputSchema: {
nodeName: 'InputNode',
schema: {
type: 'object',
key: 'simpleInput',
value: [
{
type: 'string',
key: 'simpleKey',
value: 'simpleValue',
path: '/simple/path',
},
],
path: '/simple/input/path',
},
},
pushRef: 'push-invalid-123',
ndvPushRef: 'ndv-push-invalid-456',
},
forNode: 'InvalidTypeNode',
};
const result = AiAskRequestDto.safeParse(invalidTypeRequest);
expect(result.success).toBe(false);
});
it('should validate multiple schema entries', () => {
const multiSchemaRequest = {
question: 'Multiple schema test',
context: {
schema: [
{
nodeName: 'FirstNode',
schema: {
type: 'string',
key: 'firstKey',
value: 'firstValue',
path: '/first/path',
},
},
{
nodeName: 'SecondNode',
schema: {
type: 'object',
key: 'secondKey',
value: [
{
type: 'number',
key: 'nestedKey',
value: 'nestedValue',
path: '/second/nested/path',
},
],
path: '/second/path',
},
},
],
inputSchema: {
nodeName: 'InputNode',
schema: {
type: 'object',
key: 'simpleInput',
value: [
{
type: 'string',
key: 'simpleKey',
value: 'simpleValue',
path: '/simple/path',
},
],
path: '/simple/input/path',
},
},
pushRef: 'push-multi-123',
ndvPushRef: 'ndv-push-multi-456',
},
forNode: 'MultiSchemaNode',
};
const result = AiAskRequestDto.safeParse(multiSchemaRequest);
expect(result.success).toBe(true);
});
});

View file

@ -0,0 +1,34 @@
import { AiChatRequestDto } from '../ai-chat-request.dto';
describe('AiChatRequestDto', () => {
it('should validate a request with a payload and session ID', () => {
const validRequest = {
payload: { someKey: 'someValue' },
sessionId: 'session-123',
};
const result = AiChatRequestDto.safeParse(validRequest);
expect(result.success).toBe(true);
});
it('should validate a request with only a payload', () => {
const validRequest = {
payload: { complexObject: { nested: 'value' } },
};
const result = AiChatRequestDto.safeParse(validRequest);
expect(result.success).toBe(true);
});
it('should fail if payload is missing', () => {
const invalidRequest = {
sessionId: 'session-123',
};
const result = AiChatRequestDto.safeParse(invalidRequest);
expect(result.success).toBe(false);
});
});

View file

@ -0,0 +1,7 @@
import { z } from 'zod';
import { Z } from 'zod-class';
export class AiApplySuggestionRequestDto extends Z.class({
sessionId: z.string(),
suggestionId: z.string(),
}) {}

View file

@ -0,0 +1,53 @@
import type { AiAssistantSDK, SchemaType } from '@n8n_io/ai-assistant-sdk';
import { z } from 'zod';
import { Z } from 'zod-class';
// Note: This is copied from the sdk, since this type is not exported
type Schema = {
type: SchemaType;
key?: string;
value: string | Schema[];
path: string;
};
// Create a lazy validator to handle the recursive type
const schemaValidator: z.ZodType<Schema> = z.lazy(() =>
z.object({
type: z.enum([
'string',
'number',
'boolean',
'bigint',
'symbol',
'array',
'object',
'function',
'null',
'undefined',
]),
key: z.string().optional(),
value: z.union([z.string(), z.lazy(() => schemaValidator.array())]),
path: z.string(),
}),
);
export class AiAskRequestDto
extends Z.class({
question: z.string(),
context: z.object({
schema: z.array(
z.object({
nodeName: z.string(),
schema: schemaValidator,
}),
),
inputSchema: z.object({
nodeName: z.string(),
schema: schemaValidator,
}),
pushRef: z.string(),
ndvPushRef: z.string(),
}),
forNode: z.string(),
})
implements AiAssistantSDK.AskAiRequestPayload {}

View file

@ -0,0 +1,10 @@
import type { AiAssistantSDK } from '@n8n_io/ai-assistant-sdk';
import { z } from 'zod';
import { Z } from 'zod-class';
export class AiChatRequestDto
extends Z.class({
payload: z.object({}).passthrough(), // Allow any object shape
sessionId: z.string().optional(),
})
implements AiAssistantSDK.ChatRequestPayload {}

View file

@ -1,3 +1,6 @@
export { AiAskRequestDto } from './ai/ai-ask-request.dto';
export { AiChatRequestDto } from './ai/ai-chat-request.dto';
export { AiApplySuggestionRequestDto } from './ai/ai-apply-suggestion-request.dto';
export { PasswordUpdateRequestDto } from './user/password-update-request.dto';
export { RoleChangeRequestDto } from './user/role-change-request.dto';
export { SettingsUpdateRequestDto } from './user/settings-update-request.dto';

View file

@ -0,0 +1,8 @@
import { Config, Env } from '../decorators';
@Config
export class AiAssistantConfig {
/** Base URL of the AI assistant service */
@Env('N8N_AI_ASSISTANT_BASE_URL')
baseUrl: string = '';
}

View file

@ -1,3 +1,4 @@
import { AiAssistantConfig } from './configs/aiAssistant.config';
import { CacheConfig } from './configs/cache.config';
import { CredentialsConfig } from './configs/credentials.config';
import { DatabaseConfig } from './configs/database.config';
@ -121,4 +122,7 @@ export class GlobalConfig {
@Nested
diagnostics: DiagnosticsConfig;
@Nested
aiAssistant: AiAssistantConfig;
}

View file

@ -289,6 +289,9 @@ describe('GlobalConfig', () => {
apiHost: 'https://ph.n8n.io',
},
},
aiAssistant: {
baseUrl: '',
},
};
it('should use all default values when no env variables are defined', () => {

View file

@ -341,15 +341,6 @@ export const schema = {
},
},
aiAssistant: {
baseUrl: {
doc: 'Base URL of the AI assistant service',
format: String,
default: '',
env: 'N8N_AI_ASSISTANT_BASE_URL',
},
},
expression: {
evaluator: {
doc: 'Expression evaluator to use',

View file

@ -0,0 +1,111 @@
import type {
AiAskRequestDto,
AiApplySuggestionRequestDto,
AiChatRequestDto,
} from '@n8n/api-types';
import type { AiAssistantSDK } from '@n8n_io/ai-assistant-sdk';
import { mock } from 'jest-mock-extended';
import { InternalServerError } from '@/errors/response-errors/internal-server.error';
import type { AuthenticatedRequest } from '@/requests';
import type { AiService } from '@/services/ai.service';
import { AiController, type FlushableResponse } from '../ai.controller';
describe('AiController', () => {
const aiService = mock<AiService>();
const controller = new AiController(aiService);
const request = mock<AuthenticatedRequest>({
user: { id: 'user123' },
});
const response = mock<FlushableResponse>();
beforeEach(() => {
jest.clearAllMocks();
response.header.mockReturnThis();
});
describe('chat', () => {
const payload = mock<AiChatRequestDto>();
it('should handle chat request successfully', async () => {
aiService.chat.mockResolvedValue(
mock<Response>({
body: mock({
pipeTo: jest.fn().mockImplementation(async (writableStream) => {
// Simulate stream writing
const writer = writableStream.getWriter();
await writer.write(JSON.stringify({ message: 'test response' }));
await writer.close();
}),
}),
}),
);
await controller.chat(request, response, payload);
expect(aiService.chat).toHaveBeenCalledWith(payload, request.user);
expect(response.header).toHaveBeenCalledWith('Content-type', 'application/json-lines');
expect(response.flush).toHaveBeenCalled();
expect(response.end).toHaveBeenCalled();
});
it('should throw InternalServerError if chat fails', async () => {
const mockError = new Error('Chat failed');
aiService.chat.mockRejectedValue(mockError);
await expect(controller.chat(request, response, payload)).rejects.toThrow(
InternalServerError,
);
});
});
describe('applySuggestion', () => {
const payload = mock<AiApplySuggestionRequestDto>();
it('should apply suggestion successfully', async () => {
const clientResponse = mock<AiAssistantSDK.ApplySuggestionResponse>();
aiService.applySuggestion.mockResolvedValue(clientResponse);
const result = await controller.applySuggestion(request, response, payload);
expect(aiService.applySuggestion).toHaveBeenCalledWith(payload, request.user);
expect(result).toEqual(clientResponse);
});
it('should throw InternalServerError if applying suggestion fails', async () => {
const mockError = new Error('Apply suggestion failed');
aiService.applySuggestion.mockRejectedValue(mockError);
await expect(controller.applySuggestion(request, response, payload)).rejects.toThrow(
InternalServerError,
);
});
});
describe('askAi method', () => {
const payload = mock<AiAskRequestDto>();
it('should ask AI successfully', async () => {
const clientResponse = mock<AiAssistantSDK.AskAiResponsePayload>();
aiService.askAi.mockResolvedValue(clientResponse);
const result = await controller.askAi(request, response, payload);
expect(aiService.askAi).toHaveBeenCalledWith(payload, request.user);
expect(result).toEqual(clientResponse);
});
it('should throw InternalServerError if asking AI fails', async () => {
const mockError = new Error('Ask AI failed');
aiService.askAi.mockRejectedValue(mockError);
await expect(controller.askAi(request, response, payload)).rejects.toThrow(
InternalServerError,
);
});
});
});

View file

@ -1,23 +1,24 @@
import { AiChatRequestDto, AiApplySuggestionRequestDto, AiAskRequestDto } from '@n8n/api-types';
import type { AiAssistantSDK } from '@n8n_io/ai-assistant-sdk';
import type { Response } from 'express';
import { Response } from 'express';
import { strict as assert } from 'node:assert';
import { WritableStream } from 'node:stream/web';
import { Post, RestController } from '@/decorators';
import { Body, Post, RestController } from '@/decorators';
import { InternalServerError } from '@/errors/response-errors/internal-server.error';
import { AiAssistantRequest } from '@/requests';
import { AuthenticatedRequest } from '@/requests';
import { AiService } from '@/services/ai.service';
type FlushableResponse = Response & { flush: () => void };
export type FlushableResponse = Response & { flush: () => void };
@RestController('/ai')
export class AiController {
constructor(private readonly aiService: AiService) {}
@Post('/chat', { rateLimit: { limit: 100 } })
async chat(req: AiAssistantRequest.Chat, res: FlushableResponse) {
async chat(req: AuthenticatedRequest, res: FlushableResponse, @Body payload: AiChatRequestDto) {
try {
const aiResponse = await this.aiService.chat(req.body, req.user);
const aiResponse = await this.aiService.chat(payload, req.user);
if (aiResponse.body) {
res.header('Content-type', 'application/json-lines').flush();
await aiResponse.body.pipeTo(
@ -38,10 +39,12 @@ export class AiController {
@Post('/chat/apply-suggestion')
async applySuggestion(
req: AiAssistantRequest.ApplySuggestionPayload,
req: AuthenticatedRequest,
_: Response,
@Body payload: AiApplySuggestionRequestDto,
): Promise<AiAssistantSDK.ApplySuggestionResponse> {
try {
return await this.aiService.applySuggestion(req.body, req.user);
return await this.aiService.applySuggestion(payload, req.user);
} catch (e) {
assert(e instanceof Error);
throw new InternalServerError(e.message, e);
@ -49,9 +52,13 @@ export class AiController {
}
@Post('/ask-ai')
async askAi(req: AiAssistantRequest.AskAiPayload): Promise<AiAssistantSDK.AskAiResponsePayload> {
async askAi(
req: AuthenticatedRequest,
_: Response,
@Body payload: AiAskRequestDto,
): Promise<AiAssistantSDK.AskAiResponsePayload> {
try {
return await this.aiService.askAi(req.body, req.user);
return await this.aiService.askAi(payload, req.user);
} catch (e) {
assert(e instanceof Error);
throw new InternalServerError(e.message, e);

View file

@ -1,5 +1,4 @@
import type { Scope } from '@n8n/permissions';
import type { AiAssistantSDK } from '@n8n_io/ai-assistant-sdk';
import type express from 'express';
import type {
BannerName,
@ -574,15 +573,3 @@ export declare namespace NpsSurveyRequest {
// once some schema validation is added
type NpsSurveyUpdate = AuthenticatedRequest<{}, {}, unknown>;
}
// ----------------------------------
// /ai-assistant
// ----------------------------------
export declare namespace AiAssistantRequest {
type Chat = AuthenticatedRequest<{}, {}, AiAssistantSDK.ChatRequestPayload>;
type SuggestionPayload = { sessionId: string; suggestionId: string };
type ApplySuggestionPayload = AuthenticatedRequest<{}, {}, SuggestionPayload>;
type AskAiPayload = AuthenticatedRequest<{}, {}, AiAssistantSDK.AskAiRequestPayload>;
}

View file

@ -0,0 +1,132 @@
import type {
AiAskRequestDto,
AiApplySuggestionRequestDto,
AiChatRequestDto,
} from '@n8n/api-types';
import type { GlobalConfig } from '@n8n/config';
import { AiAssistantClient, type AiAssistantSDK } from '@n8n_io/ai-assistant-sdk';
import { mock } from 'jest-mock-extended';
import type { IUser } from 'n8n-workflow';
import { N8N_VERSION } from '@/constants';
import type { License } from '@/license';
import { AiService } from '../ai.service';
jest.mock('@n8n_io/ai-assistant-sdk', () => ({
AiAssistantClient: jest.fn(),
}));
describe('AiService', () => {
let aiService: AiService;
const baseUrl = 'https://ai-assistant-url.com';
const user = mock<IUser>({ id: 'user123' });
const client = mock<AiAssistantClient>();
const license = mock<License>();
const globalConfig = mock<GlobalConfig>({
logging: { level: 'info' },
aiAssistant: { baseUrl },
});
beforeEach(() => {
jest.clearAllMocks();
(AiAssistantClient as jest.Mock).mockImplementation(() => client);
aiService = new AiService(license, globalConfig);
});
afterEach(() => {
jest.clearAllMocks();
});
describe('init', () => {
it('should not initialize client if AI assistant is not enabled', async () => {
license.isAiAssistantEnabled.mockReturnValue(false);
await aiService.init();
expect(AiAssistantClient).not.toHaveBeenCalled();
});
it('should initialize client when AI assistant is enabled', async () => {
license.isAiAssistantEnabled.mockReturnValue(true);
license.loadCertStr.mockResolvedValue('mock-license-cert');
license.getConsumerId.mockReturnValue('mock-consumer-id');
await aiService.init();
expect(AiAssistantClient).toHaveBeenCalledWith({
licenseCert: 'mock-license-cert',
consumerId: 'mock-consumer-id',
n8nVersion: N8N_VERSION,
baseUrl,
logLevel: 'info',
});
});
});
describe('chat', () => {
const payload = mock<AiChatRequestDto>();
it('should call client chat method after initialization', async () => {
license.isAiAssistantEnabled.mockReturnValue(true);
const clientResponse = mock<Response>();
client.chat.mockResolvedValue(clientResponse);
const result = await aiService.chat(payload, user);
expect(client.chat).toHaveBeenCalledWith(payload, { id: user.id });
expect(result).toEqual(clientResponse);
});
it('should throw error if client is not initialized', async () => {
license.isAiAssistantEnabled.mockReturnValue(false);
await expect(aiService.chat(payload, user)).rejects.toThrow('Assistant client not setup');
});
});
describe('applySuggestion', () => {
const payload = mock<AiApplySuggestionRequestDto>();
it('should call client applySuggestion', async () => {
license.isAiAssistantEnabled.mockReturnValue(true);
const clientResponse = mock<AiAssistantSDK.ApplySuggestionResponse>();
client.applySuggestion.mockResolvedValue(clientResponse);
const result = await aiService.applySuggestion(payload, user);
expect(client.applySuggestion).toHaveBeenCalledWith(payload, { id: user.id });
expect(result).toEqual(clientResponse);
});
it('should throw error if client is not initialized', async () => {
license.isAiAssistantEnabled.mockReturnValue(false);
await expect(aiService.applySuggestion(payload, user)).rejects.toThrow(
'Assistant client not setup',
);
});
});
describe('askAi', () => {
const payload = mock<AiAskRequestDto>();
it('should call client askAi method after initialization', async () => {
license.isAiAssistantEnabled.mockReturnValue(true);
const clientResponse = mock<AiAssistantSDK.AskAiResponsePayload>();
client.askAi.mockResolvedValue(clientResponse);
const result = await aiService.askAi(payload, user);
expect(client.askAi).toHaveBeenCalledWith(payload, { id: user.id });
expect(result).toEqual(clientResponse);
});
it('should throw error if client is not initialized', async () => {
license.isAiAssistantEnabled.mockReturnValue(false);
await expect(aiService.askAi(payload, user)).rejects.toThrow('Assistant client not setup');
});
});
});

View file

@ -1,12 +1,13 @@
import type {
AiApplySuggestionRequestDto,
AiAskRequestDto,
AiChatRequestDto,
} from '@n8n/api-types';
import { GlobalConfig } from '@n8n/config';
import type { AiAssistantSDK } from '@n8n_io/ai-assistant-sdk';
import { AiAssistantClient } from '@n8n_io/ai-assistant-sdk';
import { assert, type IUser } from 'n8n-workflow';
import { Service } from 'typedi';
import config from '@/config';
import type { AiAssistantRequest } from '@/requests';
import { N8N_VERSION } from '../constants';
import { License } from '../license';
@ -27,7 +28,7 @@ export class AiService {
const licenseCert = await this.licenseService.loadCertStr();
const consumerId = this.licenseService.getConsumerId();
const baseUrl = config.get('aiAssistant.baseUrl');
const baseUrl = this.globalConfig.aiAssistant.baseUrl;
const logLevel = this.globalConfig.logging.level;
this.client = new AiAssistantClient({
@ -39,7 +40,7 @@ export class AiService {
});
}
async chat(payload: AiAssistantSDK.ChatRequestPayload, user: IUser) {
async chat(payload: AiChatRequestDto, user: IUser) {
if (!this.client) {
await this.init();
}
@ -48,7 +49,7 @@ export class AiService {
return await this.client.chat(payload, { id: user.id });
}
async applySuggestion(payload: AiAssistantRequest.SuggestionPayload, user: IUser) {
async applySuggestion(payload: AiApplySuggestionRequestDto, user: IUser) {
if (!this.client) {
await this.init();
}
@ -57,7 +58,7 @@ export class AiService {
return await this.client.applySuggestion(payload, { id: user.id });
}
async askAi(payload: AiAssistantSDK.AskAiRequestPayload, user: IUser) {
async askAi(payload: AiAskRequestDto, user: IUser) {
if (!this.client) {
await this.init();
}