fix(OpenAI Chat Model Node): Fix loading of custom models when using custom credential URL (#12634)

This commit is contained in:
oleg 2025-01-17 09:30:02 +01:00 committed by GitHub
parent 02d953db34
commit 7cc553e3b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 207 additions and 4 deletions

View file

@ -24,7 +24,7 @@ const modelParameter: INodeProperties = {
routing: { routing: {
request: { request: {
method: 'GET', method: 'GET',
url: '={{ $parameter.options?.baseURL?.split("/").slice(-1).pop() || "v1" }}/models', url: '={{ $parameter.options?.baseURL?.split("/").slice(-1).pop() || $credentials?.url?.split("/").slice(-1).pop() || "v1" }}/models',
}, },
output: { output: {
postReceive: [ postReceive: [

View file

@ -11,18 +11,25 @@ import {
import { getConnectionHintNoticeField } from '@utils/sharedFields'; import { getConnectionHintNoticeField } from '@utils/sharedFields';
import { searchModels } from './methods/loadModels';
import { openAiFailedAttemptHandler } from '../../vendors/OpenAi/helpers/error-handling'; import { openAiFailedAttemptHandler } from '../../vendors/OpenAi/helpers/error-handling';
import { makeN8nLlmFailedAttemptHandler } from '../n8nLlmFailedAttemptHandler'; import { makeN8nLlmFailedAttemptHandler } from '../n8nLlmFailedAttemptHandler';
import { N8nLlmTracing } from '../N8nLlmTracing'; import { N8nLlmTracing } from '../N8nLlmTracing';
export class LmChatOpenAi implements INodeType { export class LmChatOpenAi implements INodeType {
methods = {
listSearch: {
searchModels,
},
};
description: INodeTypeDescription = { description: INodeTypeDescription = {
displayName: 'OpenAI Chat Model', displayName: 'OpenAI Chat Model',
// eslint-disable-next-line n8n-nodes-base/node-class-description-name-miscased // eslint-disable-next-line n8n-nodes-base/node-class-description-name-miscased
name: 'lmChatOpenAi', name: 'lmChatOpenAi',
icon: { light: 'file:openAiLight.svg', dark: 'file:openAiLight.dark.svg' }, icon: { light: 'file:openAiLight.svg', dark: 'file:openAiLight.dark.svg' },
group: ['transform'], group: ['transform'],
version: [1, 1.1], version: [1, 1.1, 1.2],
description: 'For advanced usage with an AI chain', description: 'For advanced usage with an AI chain',
defaults: { defaults: {
name: 'OpenAI Chat Model', name: 'OpenAI Chat Model',
@ -130,6 +137,42 @@ export class LmChatOpenAi implements INodeType {
}, },
}, },
default: 'gpt-4o-mini', default: 'gpt-4o-mini',
displayOptions: {
hide: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
{
displayName: 'Model',
name: 'model',
type: 'resourceLocator',
default: { mode: 'list', value: '' },
required: true,
modes: [
{
displayName: 'From List',
name: 'list',
type: 'list',
placeholder: 'Select a model...',
typeOptions: {
searchListMethod: 'searchModels',
searchable: true,
},
},
{
displayName: 'ID',
name: 'id',
type: 'string',
placeholder: '2302163813',
},
],
description: 'The model. Choose from the list, or specify an ID.',
displayOptions: {
hide: {
'@version': [{ _cnd: { lte: 1.1 } }],
},
},
}, },
{ {
displayName: displayName:
@ -251,7 +294,12 @@ export class LmChatOpenAi implements INodeType {
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> { async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
const credentials = await this.getCredentials('openAiApi'); const credentials = await this.getCredentials('openAiApi');
const modelName = this.getNodeParameter('model', itemIndex) as string; const version = this.getNode().typeVersion;
const modelName =
version >= 1.2
? (this.getNodeParameter('model.value', itemIndex) as string)
: (this.getNodeParameter('model', itemIndex) as string);
const options = this.getNodeParameter('options', itemIndex, {}) as { const options = this.getNodeParameter('options', itemIndex, {}) as {
baseURL?: string; baseURL?: string;
frequencyPenalty?: number; frequencyPenalty?: number;

View file

@ -0,0 +1,112 @@
import type { ILoadOptionsFunctions } from 'n8n-workflow';
import OpenAI from 'openai';
import { searchModels } from '../loadModels';
jest.mock('openai');
describe('searchModels', () => {
let mockContext: jest.Mocked<ILoadOptionsFunctions>;
let mockOpenAI: jest.Mocked<typeof OpenAI>;
beforeEach(() => {
mockContext = {
getCredentials: jest.fn().mockResolvedValue({
apiKey: 'test-api-key',
}),
getNodeParameter: jest.fn().mockReturnValue(''),
} as unknown as jest.Mocked<ILoadOptionsFunctions>;
// Setup OpenAI mock with required properties
const mockOpenAIInstance = {
apiKey: 'test-api-key',
organization: null,
project: null,
_options: {},
models: {
list: jest.fn().mockResolvedValue({
data: [
{ id: 'gpt-4' },
{ id: 'gpt-3.5-turbo' },
{ id: 'gpt-3.5-turbo-instruct' },
{ id: 'ft:gpt-3.5-turbo' },
{ id: 'o1-model' },
{ id: 'other-model' },
],
}),
},
} as unknown as OpenAI;
(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(() => mockOpenAIInstance);
mockOpenAI = OpenAI as jest.Mocked<typeof OpenAI>;
});
afterEach(() => {
jest.clearAllMocks();
});
it('should return filtered models if custom API endpoint is not provided', async () => {
const result = await searchModels.call(mockContext);
expect(mockOpenAI).toHaveBeenCalledWith({
baseURL: 'https://api.openai.com/v1',
apiKey: 'test-api-key',
});
expect(result.results).toHaveLength(4);
});
it('should initialize OpenAI with correct credentials', async () => {
mockContext.getCredentials.mockResolvedValueOnce({
apiKey: 'test-api-key',
url: 'https://test-url.com',
});
await searchModels.call(mockContext);
expect(mockOpenAI).toHaveBeenCalledWith({
baseURL: 'https://test-url.com',
apiKey: 'test-api-key',
});
});
it('should use default OpenAI URL if no custom URL provided', async () => {
mockContext.getCredentials = jest.fn().mockResolvedValue({
apiKey: 'test-api-key',
});
await searchModels.call(mockContext);
expect(mockOpenAI).toHaveBeenCalledWith({
baseURL: 'https://api.openai.com/v1',
apiKey: 'test-api-key',
});
});
it('should include all models for custom API endpoints', async () => {
mockContext.getNodeParameter = jest.fn().mockReturnValue('https://custom-api.com');
const result = await searchModels.call(mockContext);
expect(result.results).toHaveLength(6);
});
it('should filter models based on search term', async () => {
const result = await searchModels.call(mockContext, 'gpt');
expect(result.results).toEqual([
{ name: 'gpt-4', value: 'gpt-4' },
{ name: 'gpt-3.5-turbo', value: 'gpt-3.5-turbo' },
{ name: 'ft:gpt-3.5-turbo', value: 'ft:gpt-3.5-turbo' },
]);
});
it('should handle case-insensitive search', async () => {
const result = await searchModels.call(mockContext, 'GPT');
expect(result.results).toEqual([
{ name: 'gpt-4', value: 'gpt-4' },
{ name: 'gpt-3.5-turbo', value: 'gpt-3.5-turbo' },
{ name: 'ft:gpt-3.5-turbo', value: 'ft:gpt-3.5-turbo' },
]);
});
});

View file

@ -0,0 +1,37 @@
import type { ILoadOptionsFunctions, INodeListSearchResult } from 'n8n-workflow';
import OpenAI from 'openai';
export async function searchModels(
this: ILoadOptionsFunctions,
filter?: string,
): Promise<INodeListSearchResult> {
const credentials = await this.getCredentials('openAiApi');
const baseURL =
(this.getNodeParameter('options.baseURL', '') as string) ||
(credentials.url as string) ||
'https://api.openai.com/v1';
const openai = new OpenAI({ baseURL, apiKey: credentials.apiKey as string });
const { data: models = [] } = await openai.models.list();
const filteredModels = models.filter((model: { id: string }) => {
const isValidModel =
(baseURL && !baseURL.includes('api.openai.com')) ||
model.id.startsWith('ft:') ||
model.id.startsWith('o1') ||
(model.id.startsWith('gpt-') && !model.id.includes('instruct'));
if (!filter) return isValidModel;
return isValidModel && model.id.toLowerCase().includes(filter.toLowerCase());
});
const results = {
results: filteredModels.map((model: { id: string }) => ({
name: model.id,
value: model.id,
})),
};
return results;
}

View file

@ -76,9 +76,15 @@ export async function modelSearch(
this: ILoadOptionsFunctions, this: ILoadOptionsFunctions,
filter?: string, filter?: string,
): Promise<INodeListSearchResult> { ): Promise<INodeListSearchResult> {
const credentials = await this.getCredentials<{ url: string }>('openAiApi');
const isCustomAPI = credentials.url && !credentials.url.includes('api.openai.com');
return await getModelSearch( return await getModelSearch(
(model) => (model) =>
model.id.startsWith('gpt-') || model.id.startsWith('ft:') || model.id.startsWith('o1'), isCustomAPI ||
model.id.startsWith('gpt-') ||
model.id.startsWith('ft:') ||
model.id.startsWith('o1'),
)(this, filter); )(this, filter);
} }