mirror of
https://github.com/n8n-io/n8n.git
synced 2025-03-05 20:50:17 -08:00
fix(OpenAI Chat Model Node): Fix loading of custom models when using custom credential URL (#12634)
This commit is contained in:
parent
02d953db34
commit
7cc553e3b2
|
@ -24,7 +24,7 @@ const modelParameter: INodeProperties = {
|
||||||
routing: {
|
routing: {
|
||||||
request: {
|
request: {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
url: '={{ $parameter.options?.baseURL?.split("/").slice(-1).pop() || "v1" }}/models',
|
url: '={{ $parameter.options?.baseURL?.split("/").slice(-1).pop() || $credentials?.url?.split("/").slice(-1).pop() || "v1" }}/models',
|
||||||
},
|
},
|
||||||
output: {
|
output: {
|
||||||
postReceive: [
|
postReceive: [
|
||||||
|
|
|
@ -11,18 +11,25 @@ import {
|
||||||
|
|
||||||
import { getConnectionHintNoticeField } from '@utils/sharedFields';
|
import { getConnectionHintNoticeField } from '@utils/sharedFields';
|
||||||
|
|
||||||
|
import { searchModels } from './methods/loadModels';
|
||||||
import { openAiFailedAttemptHandler } from '../../vendors/OpenAi/helpers/error-handling';
|
import { openAiFailedAttemptHandler } from '../../vendors/OpenAi/helpers/error-handling';
|
||||||
import { makeN8nLlmFailedAttemptHandler } from '../n8nLlmFailedAttemptHandler';
|
import { makeN8nLlmFailedAttemptHandler } from '../n8nLlmFailedAttemptHandler';
|
||||||
import { N8nLlmTracing } from '../N8nLlmTracing';
|
import { N8nLlmTracing } from '../N8nLlmTracing';
|
||||||
|
|
||||||
export class LmChatOpenAi implements INodeType {
|
export class LmChatOpenAi implements INodeType {
|
||||||
|
methods = {
|
||||||
|
listSearch: {
|
||||||
|
searchModels,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
description: INodeTypeDescription = {
|
description: INodeTypeDescription = {
|
||||||
displayName: 'OpenAI Chat Model',
|
displayName: 'OpenAI Chat Model',
|
||||||
// eslint-disable-next-line n8n-nodes-base/node-class-description-name-miscased
|
// eslint-disable-next-line n8n-nodes-base/node-class-description-name-miscased
|
||||||
name: 'lmChatOpenAi',
|
name: 'lmChatOpenAi',
|
||||||
icon: { light: 'file:openAiLight.svg', dark: 'file:openAiLight.dark.svg' },
|
icon: { light: 'file:openAiLight.svg', dark: 'file:openAiLight.dark.svg' },
|
||||||
group: ['transform'],
|
group: ['transform'],
|
||||||
version: [1, 1.1],
|
version: [1, 1.1, 1.2],
|
||||||
description: 'For advanced usage with an AI chain',
|
description: 'For advanced usage with an AI chain',
|
||||||
defaults: {
|
defaults: {
|
||||||
name: 'OpenAI Chat Model',
|
name: 'OpenAI Chat Model',
|
||||||
|
@ -130,6 +137,42 @@ export class LmChatOpenAi implements INodeType {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
default: 'gpt-4o-mini',
|
default: 'gpt-4o-mini',
|
||||||
|
displayOptions: {
|
||||||
|
hide: {
|
||||||
|
'@version': [{ _cnd: { gte: 1.2 } }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
displayName: 'Model',
|
||||||
|
name: 'model',
|
||||||
|
type: 'resourceLocator',
|
||||||
|
default: { mode: 'list', value: '' },
|
||||||
|
required: true,
|
||||||
|
modes: [
|
||||||
|
{
|
||||||
|
displayName: 'From List',
|
||||||
|
name: 'list',
|
||||||
|
type: 'list',
|
||||||
|
placeholder: 'Select a model...',
|
||||||
|
typeOptions: {
|
||||||
|
searchListMethod: 'searchModels',
|
||||||
|
searchable: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
displayName: 'ID',
|
||||||
|
name: 'id',
|
||||||
|
type: 'string',
|
||||||
|
placeholder: '2302163813',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
description: 'The model. Choose from the list, or specify an ID.',
|
||||||
|
displayOptions: {
|
||||||
|
hide: {
|
||||||
|
'@version': [{ _cnd: { lte: 1.1 } }],
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
displayName:
|
displayName:
|
||||||
|
@ -251,7 +294,12 @@ export class LmChatOpenAi implements INodeType {
|
||||||
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
|
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
|
||||||
const credentials = await this.getCredentials('openAiApi');
|
const credentials = await this.getCredentials('openAiApi');
|
||||||
|
|
||||||
const modelName = this.getNodeParameter('model', itemIndex) as string;
|
const version = this.getNode().typeVersion;
|
||||||
|
const modelName =
|
||||||
|
version >= 1.2
|
||||||
|
? (this.getNodeParameter('model.value', itemIndex) as string)
|
||||||
|
: (this.getNodeParameter('model', itemIndex) as string);
|
||||||
|
|
||||||
const options = this.getNodeParameter('options', itemIndex, {}) as {
|
const options = this.getNodeParameter('options', itemIndex, {}) as {
|
||||||
baseURL?: string;
|
baseURL?: string;
|
||||||
frequencyPenalty?: number;
|
frequencyPenalty?: number;
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
import type { ILoadOptionsFunctions } from 'n8n-workflow';
|
||||||
|
import OpenAI from 'openai';
|
||||||
|
|
||||||
|
import { searchModels } from '../loadModels';
|
||||||
|
|
||||||
|
jest.mock('openai');
|
||||||
|
|
||||||
|
describe('searchModels', () => {
|
||||||
|
let mockContext: jest.Mocked<ILoadOptionsFunctions>;
|
||||||
|
let mockOpenAI: jest.Mocked<typeof OpenAI>;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockContext = {
|
||||||
|
getCredentials: jest.fn().mockResolvedValue({
|
||||||
|
apiKey: 'test-api-key',
|
||||||
|
}),
|
||||||
|
getNodeParameter: jest.fn().mockReturnValue(''),
|
||||||
|
} as unknown as jest.Mocked<ILoadOptionsFunctions>;
|
||||||
|
|
||||||
|
// Setup OpenAI mock with required properties
|
||||||
|
const mockOpenAIInstance = {
|
||||||
|
apiKey: 'test-api-key',
|
||||||
|
organization: null,
|
||||||
|
project: null,
|
||||||
|
_options: {},
|
||||||
|
models: {
|
||||||
|
list: jest.fn().mockResolvedValue({
|
||||||
|
data: [
|
||||||
|
{ id: 'gpt-4' },
|
||||||
|
{ id: 'gpt-3.5-turbo' },
|
||||||
|
{ id: 'gpt-3.5-turbo-instruct' },
|
||||||
|
{ id: 'ft:gpt-3.5-turbo' },
|
||||||
|
{ id: 'o1-model' },
|
||||||
|
{ id: 'other-model' },
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
} as unknown as OpenAI;
|
||||||
|
|
||||||
|
(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(() => mockOpenAIInstance);
|
||||||
|
|
||||||
|
mockOpenAI = OpenAI as jest.Mocked<typeof OpenAI>;
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return filtered models if custom API endpoint is not provided', async () => {
|
||||||
|
const result = await searchModels.call(mockContext);
|
||||||
|
|
||||||
|
expect(mockOpenAI).toHaveBeenCalledWith({
|
||||||
|
baseURL: 'https://api.openai.com/v1',
|
||||||
|
apiKey: 'test-api-key',
|
||||||
|
});
|
||||||
|
expect(result.results).toHaveLength(4);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should initialize OpenAI with correct credentials', async () => {
|
||||||
|
mockContext.getCredentials.mockResolvedValueOnce({
|
||||||
|
apiKey: 'test-api-key',
|
||||||
|
url: 'https://test-url.com',
|
||||||
|
});
|
||||||
|
await searchModels.call(mockContext);
|
||||||
|
|
||||||
|
expect(mockOpenAI).toHaveBeenCalledWith({
|
||||||
|
baseURL: 'https://test-url.com',
|
||||||
|
apiKey: 'test-api-key',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should use default OpenAI URL if no custom URL provided', async () => {
|
||||||
|
mockContext.getCredentials = jest.fn().mockResolvedValue({
|
||||||
|
apiKey: 'test-api-key',
|
||||||
|
});
|
||||||
|
|
||||||
|
await searchModels.call(mockContext);
|
||||||
|
|
||||||
|
expect(mockOpenAI).toHaveBeenCalledWith({
|
||||||
|
baseURL: 'https://api.openai.com/v1',
|
||||||
|
apiKey: 'test-api-key',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should include all models for custom API endpoints', async () => {
|
||||||
|
mockContext.getNodeParameter = jest.fn().mockReturnValue('https://custom-api.com');
|
||||||
|
|
||||||
|
const result = await searchModels.call(mockContext);
|
||||||
|
|
||||||
|
expect(result.results).toHaveLength(6);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should filter models based on search term', async () => {
|
||||||
|
const result = await searchModels.call(mockContext, 'gpt');
|
||||||
|
|
||||||
|
expect(result.results).toEqual([
|
||||||
|
{ name: 'gpt-4', value: 'gpt-4' },
|
||||||
|
{ name: 'gpt-3.5-turbo', value: 'gpt-3.5-turbo' },
|
||||||
|
{ name: 'ft:gpt-3.5-turbo', value: 'ft:gpt-3.5-turbo' },
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle case-insensitive search', async () => {
|
||||||
|
const result = await searchModels.call(mockContext, 'GPT');
|
||||||
|
|
||||||
|
expect(result.results).toEqual([
|
||||||
|
{ name: 'gpt-4', value: 'gpt-4' },
|
||||||
|
{ name: 'gpt-3.5-turbo', value: 'gpt-3.5-turbo' },
|
||||||
|
{ name: 'ft:gpt-3.5-turbo', value: 'ft:gpt-3.5-turbo' },
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
});
|
|
@ -0,0 +1,37 @@
|
||||||
|
import type { ILoadOptionsFunctions, INodeListSearchResult } from 'n8n-workflow';
|
||||||
|
import OpenAI from 'openai';
|
||||||
|
|
||||||
|
export async function searchModels(
|
||||||
|
this: ILoadOptionsFunctions,
|
||||||
|
filter?: string,
|
||||||
|
): Promise<INodeListSearchResult> {
|
||||||
|
const credentials = await this.getCredentials('openAiApi');
|
||||||
|
const baseURL =
|
||||||
|
(this.getNodeParameter('options.baseURL', '') as string) ||
|
||||||
|
(credentials.url as string) ||
|
||||||
|
'https://api.openai.com/v1';
|
||||||
|
|
||||||
|
const openai = new OpenAI({ baseURL, apiKey: credentials.apiKey as string });
|
||||||
|
const { data: models = [] } = await openai.models.list();
|
||||||
|
|
||||||
|
const filteredModels = models.filter((model: { id: string }) => {
|
||||||
|
const isValidModel =
|
||||||
|
(baseURL && !baseURL.includes('api.openai.com')) ||
|
||||||
|
model.id.startsWith('ft:') ||
|
||||||
|
model.id.startsWith('o1') ||
|
||||||
|
(model.id.startsWith('gpt-') && !model.id.includes('instruct'));
|
||||||
|
|
||||||
|
if (!filter) return isValidModel;
|
||||||
|
|
||||||
|
return isValidModel && model.id.toLowerCase().includes(filter.toLowerCase());
|
||||||
|
});
|
||||||
|
|
||||||
|
const results = {
|
||||||
|
results: filteredModels.map((model: { id: string }) => ({
|
||||||
|
name: model.id,
|
||||||
|
value: model.id,
|
||||||
|
})),
|
||||||
|
};
|
||||||
|
|
||||||
|
return results;
|
||||||
|
}
|
|
@ -76,9 +76,15 @@ export async function modelSearch(
|
||||||
this: ILoadOptionsFunctions,
|
this: ILoadOptionsFunctions,
|
||||||
filter?: string,
|
filter?: string,
|
||||||
): Promise<INodeListSearchResult> {
|
): Promise<INodeListSearchResult> {
|
||||||
|
const credentials = await this.getCredentials<{ url: string }>('openAiApi');
|
||||||
|
const isCustomAPI = credentials.url && !credentials.url.includes('api.openai.com');
|
||||||
|
|
||||||
return await getModelSearch(
|
return await getModelSearch(
|
||||||
(model) =>
|
(model) =>
|
||||||
model.id.startsWith('gpt-') || model.id.startsWith('ft:') || model.id.startsWith('o1'),
|
isCustomAPI ||
|
||||||
|
model.id.startsWith('gpt-') ||
|
||||||
|
model.id.startsWith('ft:') ||
|
||||||
|
model.id.startsWith('o1'),
|
||||||
)(this, filter);
|
)(this, filter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue