Mirror of https://github.com/n8n-io/n8n.git, last synced 2025-03-05 20:50:17 -08:00.
fix(OpenAI Chat Model Node): Fix loading of custom models when using custom credential URL (#12634)
This commit is contained in:
parent
02d953db34
commit
7cc553e3b2
|
@ -24,7 +24,7 @@ const modelParameter: INodeProperties = {
|
|||
routing: {
|
||||
request: {
|
||||
method: 'GET',
|
||||
url: '={{ $parameter.options?.baseURL?.split("/").slice(-1).pop() || "v1" }}/models',
|
||||
url: '={{ $parameter.options?.baseURL?.split("/").slice(-1).pop() || $credentials?.url?.split("/").slice(-1).pop() || "v1" }}/models',
|
||||
},
|
||||
output: {
|
||||
postReceive: [
|
||||
|
|
|
@ -11,18 +11,25 @@ import {
|
|||
|
||||
import { getConnectionHintNoticeField } from '@utils/sharedFields';
|
||||
|
||||
import { searchModels } from './methods/loadModels';
|
||||
import { openAiFailedAttemptHandler } from '../../vendors/OpenAi/helpers/error-handling';
|
||||
import { makeN8nLlmFailedAttemptHandler } from '../n8nLlmFailedAttemptHandler';
|
||||
import { N8nLlmTracing } from '../N8nLlmTracing';
|
||||
|
||||
export class LmChatOpenAi implements INodeType {
|
||||
methods = {
|
||||
listSearch: {
|
||||
searchModels,
|
||||
},
|
||||
};
|
||||
|
||||
description: INodeTypeDescription = {
|
||||
displayName: 'OpenAI Chat Model',
|
||||
// eslint-disable-next-line n8n-nodes-base/node-class-description-name-miscased
|
||||
name: 'lmChatOpenAi',
|
||||
icon: { light: 'file:openAiLight.svg', dark: 'file:openAiLight.dark.svg' },
|
||||
group: ['transform'],
|
||||
version: [1, 1.1],
|
||||
version: [1, 1.1, 1.2],
|
||||
description: 'For advanced usage with an AI chain',
|
||||
defaults: {
|
||||
name: 'OpenAI Chat Model',
|
||||
|
@ -130,6 +137,42 @@ export class LmChatOpenAi implements INodeType {
|
|||
},
|
||||
},
|
||||
default: 'gpt-4o-mini',
|
||||
displayOptions: {
|
||||
hide: {
|
||||
'@version': [{ _cnd: { gte: 1.2 } }],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
displayName: 'Model',
|
||||
name: 'model',
|
||||
type: 'resourceLocator',
|
||||
default: { mode: 'list', value: '' },
|
||||
required: true,
|
||||
modes: [
|
||||
{
|
||||
displayName: 'From List',
|
||||
name: 'list',
|
||||
type: 'list',
|
||||
placeholder: 'Select a model...',
|
||||
typeOptions: {
|
||||
searchListMethod: 'searchModels',
|
||||
searchable: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
displayName: 'ID',
|
||||
name: 'id',
|
||||
type: 'string',
|
||||
placeholder: '2302163813',
|
||||
},
|
||||
],
|
||||
description: 'The model. Choose from the list, or specify an ID.',
|
||||
displayOptions: {
|
||||
hide: {
|
||||
'@version': [{ _cnd: { lte: 1.1 } }],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
displayName:
|
||||
|
@ -251,7 +294,12 @@ export class LmChatOpenAi implements INodeType {
|
|||
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
|
||||
const credentials = await this.getCredentials('openAiApi');
|
||||
|
||||
const modelName = this.getNodeParameter('model', itemIndex) as string;
|
||||
const version = this.getNode().typeVersion;
|
||||
const modelName =
|
||||
version >= 1.2
|
||||
? (this.getNodeParameter('model.value', itemIndex) as string)
|
||||
: (this.getNodeParameter('model', itemIndex) as string);
|
||||
|
||||
const options = this.getNodeParameter('options', itemIndex, {}) as {
|
||||
baseURL?: string;
|
||||
frequencyPenalty?: number;
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
import type { ILoadOptionsFunctions } from 'n8n-workflow';
|
||||
import OpenAI from 'openai';
|
||||
|
||||
import { searchModels } from '../loadModels';
|
||||
|
||||
jest.mock('openai');
|
||||
|
||||
// Unit tests for the model list-search used by the OpenAI Chat Model node.
// The OpenAI SDK is fully mocked; each test asserts either which baseURL the
// client was constructed with, or which of the six mock model ids survive
// the filtering in searchModels.
describe('searchModels', () => {
	let mockContext: jest.Mocked<ILoadOptionsFunctions>;
	let mockOpenAI: jest.Mocked<typeof OpenAI>;

	beforeEach(() => {
		// Minimal load-options context: credentials with an API key only (no
		// custom url) and an empty 'options.baseURL' node parameter.
		mockContext = {
			getCredentials: jest.fn().mockResolvedValue({
				apiKey: 'test-api-key',
			}),
			getNodeParameter: jest.fn().mockReturnValue(''),
		} as unknown as jest.Mocked<ILoadOptionsFunctions>;

		// Setup OpenAI mock with required properties
		const mockOpenAIInstance = {
			apiKey: 'test-api-key',
			organization: null,
			project: null,
			_options: {},
			models: {
				// Fixed catalogue of six ids covering every filter branch:
				// plain gpt, instruct (excluded), fine-tune 'ft:', 'o1', and
				// an unrecognized id (excluded unless the endpoint is custom).
				list: jest.fn().mockResolvedValue({
					data: [
						{ id: 'gpt-4' },
						{ id: 'gpt-3.5-turbo' },
						{ id: 'gpt-3.5-turbo-instruct' },
						{ id: 'ft:gpt-3.5-turbo' },
						{ id: 'o1-model' },
						{ id: 'other-model' },
					],
				}),
			},
		} as unknown as OpenAI;

		// Every `new OpenAI(...)` inside searchModels returns the stub above.
		(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(() => mockOpenAIInstance);

		mockOpenAI = OpenAI as jest.Mocked<typeof OpenAI>;
	});

	afterEach(() => {
		jest.clearAllMocks();
	});

	it('should return filtered models if custom API endpoint is not provided', async () => {
		const result = await searchModels.call(mockContext);

		expect(mockOpenAI).toHaveBeenCalledWith({
			baseURL: 'https://api.openai.com/v1',
			apiKey: 'test-api-key',
		});
		// 4 of 6: 'gpt-3.5-turbo-instruct' and 'other-model' are filtered out
		// against the official endpoint.
		expect(result.results).toHaveLength(4);
	});

	it('should initialize OpenAI with correct credentials', async () => {
		// A credential-level url must be used as the client baseURL when no
		// node-parameter baseURL is set (the fix under test).
		mockContext.getCredentials.mockResolvedValueOnce({
			apiKey: 'test-api-key',
			url: 'https://test-url.com',
		});
		await searchModels.call(mockContext);

		expect(mockOpenAI).toHaveBeenCalledWith({
			baseURL: 'https://test-url.com',
			apiKey: 'test-api-key',
		});
	});

	it('should use default OpenAI URL if no custom URL provided', async () => {
		mockContext.getCredentials = jest.fn().mockResolvedValue({
			apiKey: 'test-api-key',
		});

		await searchModels.call(mockContext);

		expect(mockOpenAI).toHaveBeenCalledWith({
			baseURL: 'https://api.openai.com/v1',
			apiKey: 'test-api-key',
		});
	});

	it('should include all models for custom API endpoints', async () => {
		// 'options.baseURL' pointing away from api.openai.com disables the
		// model-family allow-list, so all 6 mock ids are returned.
		mockContext.getNodeParameter = jest.fn().mockReturnValue('https://custom-api.com');

		const result = await searchModels.call(mockContext);

		expect(result.results).toHaveLength(6);
	});

	it('should filter models based on search term', async () => {
		const result = await searchModels.call(mockContext, 'gpt');

		// 'o1-model' passes the family filter but does not contain 'gpt';
		// the instruct model contains 'gpt' but fails the family filter.
		expect(result.results).toEqual([
			{ name: 'gpt-4', value: 'gpt-4' },
			{ name: 'gpt-3.5-turbo', value: 'gpt-3.5-turbo' },
			{ name: 'ft:gpt-3.5-turbo', value: 'ft:gpt-3.5-turbo' },
		]);
	});

	it('should handle case-insensitive search', async () => {
		// Same expectation as the lowercase search: matching lowercases both sides.
		const result = await searchModels.call(mockContext, 'GPT');

		expect(result.results).toEqual([
			{ name: 'gpt-4', value: 'gpt-4' },
			{ name: 'gpt-3.5-turbo', value: 'gpt-3.5-turbo' },
			{ name: 'ft:gpt-3.5-turbo', value: 'ft:gpt-3.5-turbo' },
		]);
	});
});
|
|
@ -0,0 +1,37 @@
|
|||
import type { ILoadOptionsFunctions, INodeListSearchResult } from 'n8n-workflow';
|
||||
import OpenAI from 'openai';
|
||||
|
||||
export async function searchModels(
|
||||
this: ILoadOptionsFunctions,
|
||||
filter?: string,
|
||||
): Promise<INodeListSearchResult> {
|
||||
const credentials = await this.getCredentials('openAiApi');
|
||||
const baseURL =
|
||||
(this.getNodeParameter('options.baseURL', '') as string) ||
|
||||
(credentials.url as string) ||
|
||||
'https://api.openai.com/v1';
|
||||
|
||||
const openai = new OpenAI({ baseURL, apiKey: credentials.apiKey as string });
|
||||
const { data: models = [] } = await openai.models.list();
|
||||
|
||||
const filteredModels = models.filter((model: { id: string }) => {
|
||||
const isValidModel =
|
||||
(baseURL && !baseURL.includes('api.openai.com')) ||
|
||||
model.id.startsWith('ft:') ||
|
||||
model.id.startsWith('o1') ||
|
||||
(model.id.startsWith('gpt-') && !model.id.includes('instruct'));
|
||||
|
||||
if (!filter) return isValidModel;
|
||||
|
||||
return isValidModel && model.id.toLowerCase().includes(filter.toLowerCase());
|
||||
});
|
||||
|
||||
const results = {
|
||||
results: filteredModels.map((model: { id: string }) => ({
|
||||
name: model.id,
|
||||
value: model.id,
|
||||
})),
|
||||
};
|
||||
|
||||
return results;
|
||||
}
|
|
@ -76,9 +76,15 @@ export async function modelSearch(
|
|||
this: ILoadOptionsFunctions,
|
||||
filter?: string,
|
||||
): Promise<INodeListSearchResult> {
|
||||
const credentials = await this.getCredentials<{ url: string }>('openAiApi');
|
||||
const isCustomAPI = credentials.url && !credentials.url.includes('api.openai.com');
|
||||
|
||||
return await getModelSearch(
|
||||
(model) =>
|
||||
model.id.startsWith('gpt-') || model.id.startsWith('ft:') || model.id.startsWith('o1'),
|
||||
isCustomAPI ||
|
||||
model.id.startsWith('gpt-') ||
|
||||
model.id.startsWith('ft:') ||
|
||||
model.id.startsWith('o1'),
|
||||
)(this, filter);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue