feat(Google Vertex Chat Model Node): Add support for Google Vertex AI Chat models (#9970)

Co-authored-by: oleg <oleg@n8n.io>
This commit is contained in:
Eugene 2024-07-11 14:41:10 +02:00 committed by GitHub
parent 519e57bda5
commit 071130a2dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 539 additions and 675 deletions

View file

@ -80,6 +80,7 @@ function getInputs(
'@n8n/n8n-nodes-langchain.lmChatOpenAi', '@n8n/n8n-nodes-langchain.lmChatOpenAi',
'@n8n/n8n-nodes-langchain.lmChatGooglePalm', '@n8n/n8n-nodes-langchain.lmChatGooglePalm',
'@n8n/n8n-nodes-langchain.lmChatGoogleGemini', '@n8n/n8n-nodes-langchain.lmChatGoogleGemini',
'@n8n/n8n-nodes-langchain.lmChatGoogleVertex',
'@n8n/n8n-nodes-langchain.lmChatMistralCloud', '@n8n/n8n-nodes-langchain.lmChatMistralCloud',
'@n8n/n8n-nodes-langchain.lmChatAzureOpenAi', '@n8n/n8n-nodes-langchain.lmChatAzureOpenAi',
], ],
@ -106,6 +107,7 @@ function getInputs(
'@n8n/n8n-nodes-langchain.lmChatMistralCloud', '@n8n/n8n-nodes-langchain.lmChatMistralCloud',
'@n8n/n8n-nodes-langchain.lmChatOpenAi', '@n8n/n8n-nodes-langchain.lmChatOpenAi',
'@n8n/n8n-nodes-langchain.lmChatGroq', '@n8n/n8n-nodes-langchain.lmChatGroq',
'@n8n/n8n-nodes-langchain.lmChatGoogleVertex',
], ],
}, },
}, },

View file

@ -226,7 +226,10 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
}); });
} catch (error) { } catch (error) {
if (this.continueOnFail(error)) { if (this.continueOnFail(error)) {
returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } }); returnData.push({
json: { error: error.message },
pairedItem: { item: itemIndex },
});
continue; continue;
} }

View file

@ -7,10 +7,10 @@ import {
type SupplyData, type SupplyData,
} from 'n8n-workflow'; } from 'n8n-workflow';
import { ChatGoogleGenerativeAI } from '@langchain/google-genai'; import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
import type { HarmBlockThreshold, HarmCategory, SafetySetting } from '@google/generative-ai'; import type { SafetySetting } from '@google/generative-ai';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields'; import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { N8nLlmTracing } from '../N8nLlmTracing'; import { N8nLlmTracing } from '../N8nLlmTracing';
import { harmCategories, harmThresholds } from './options'; import { additionalOptions } from '../gemini-common/additional-options';
export class LmChatGoogleGemini implements INodeType { export class LmChatGoogleGemini implements INodeType {
description: INodeTypeDescription = { description: INodeTypeDescription = {
@ -108,89 +108,7 @@ export class LmChatGoogleGemini implements INodeType {
}, },
default: 'models/gemini-1.0-pro', default: 'models/gemini-1.0-pro',
}, },
{ additionalOptions,
displayName: 'Options',
name: 'options',
placeholder: 'Add Option',
description: 'Additional options to add',
type: 'collection',
default: {},
options: [
{
displayName: 'Maximum Number of Tokens',
name: 'maxOutputTokens',
default: 2048,
description: 'The maximum number of tokens to generate in the completion',
type: 'number',
},
{
displayName: 'Sampling Temperature',
name: 'temperature',
default: 0.4,
typeOptions: { maxValue: 1, minValue: 0, numberPrecision: 1 },
description:
'Controls randomness: Lowering results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive.',
type: 'number',
},
{
displayName: 'Top K',
name: 'topK',
default: 32,
typeOptions: { maxValue: 40, minValue: -1, numberPrecision: 1 },
description:
'Used to remove "long tail" low probability responses. Defaults to -1, which disables it.',
type: 'number',
},
{
displayName: 'Top P',
name: 'topP',
default: 1,
typeOptions: { maxValue: 1, minValue: 0, numberPrecision: 1 },
description:
'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered. We generally recommend altering this or temperature but not both.',
type: 'number',
},
// Safety Settings
{
displayName: 'Safety Settings',
name: 'safetySettings',
type: 'fixedCollection',
typeOptions: { multipleValues: true },
default: {
values: {
category: harmCategories[0].name as HarmCategory,
threshold: harmThresholds[0].name as HarmBlockThreshold,
},
},
placeholder: 'Add Option',
options: [
{
name: 'values',
displayName: 'Values',
values: [
{
displayName: 'Safety Category',
name: 'category',
type: 'options',
description: 'The category of harmful content to block',
default: 'HARM_CATEGORY_UNSPECIFIED',
options: harmCategories,
},
{
displayName: 'Safety Threshold',
name: 'threshold',
type: 'options',
description: 'The threshold of harmful content to block',
default: 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
options: harmThresholds,
},
],
},
],
},
],
},
], ],
}; };

View file

@ -0,0 +1,200 @@
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
import {
NodeConnectionType,
type IExecuteFunctions,
type INodeType,
type INodeTypeDescription,
type SupplyData,
type ILoadOptionsFunctions,
type JsonObject,
NodeOperationError,
} from 'n8n-workflow';
import { ChatVertexAI } from '@langchain/google-vertexai';
import type { SafetySetting } from '@google/generative-ai';
import { ProjectsClient } from '@google-cloud/resource-manager';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { N8nLlmTracing } from '../N8nLlmTracing';
import { makeErrorFromStatus } from './error-handling';
import { additionalOptions } from '../gemini-common/additional-options';
export class LmChatGoogleVertex implements INodeType {
description: INodeTypeDescription = {
displayName: 'Google Vertex Chat Model',
// eslint-disable-next-line n8n-nodes-base/node-class-description-name-miscased
name: 'lmChatGoogleVertex',
icon: 'file:google.svg',
group: ['transform'],
version: 1,
description: 'Chat Model Google Vertex',
defaults: {
name: 'Google Vertex Chat Model',
},
codex: {
categories: ['AI'],
subcategories: {
AI: ['Language Models'],
},
resources: {
primaryDocumentation: [
{
url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/sub-nodes/n8n-nodes-langchain.lmchatgooglevertex/',
},
],
},
},
// eslint-disable-next-line n8n-nodes-base/node-class-description-inputs-wrong-regular-node
inputs: [],
// eslint-disable-next-line n8n-nodes-base/node-class-description-outputs-wrong
outputs: [NodeConnectionType.AiLanguageModel],
outputNames: ['Model'],
credentials: [
{
name: 'googleApi',
required: true,
},
],
properties: [
getConnectionHintNoticeField([NodeConnectionType.AiChain, NodeConnectionType.AiAgent]),
{
displayName: 'Project ID',
name: 'projectId',
type: 'resourceLocator',
default: { mode: 'list', value: '' },
required: true,
description: 'Select or enter your Google Cloud project ID',
modes: [
{
displayName: 'From List',
name: 'list',
type: 'list',
typeOptions: {
searchListMethod: 'gcpProjectsList',
},
},
{
displayName: 'ID',
name: 'id',
type: 'string',
},
],
},
{
displayName: 'Model Name',
name: 'modelName',
type: 'string',
description:
'The model which will generate the completion. <a href="https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models">Learn more</a>.',
default: 'gemini-1.5-flash',
},
additionalOptions,
],
};
methods = {
listSearch: {
async gcpProjectsList(this: ILoadOptionsFunctions) {
const results: Array<{ name: string; value: string }> = [];
const credentials = await this.getCredentials('googleApi');
const client = new ProjectsClient({
credentials: {
client_email: credentials.email as string,
private_key: credentials.privateKey as string,
},
});
const [projects] = await client.searchProjects();
for (const project of projects) {
if (project.projectId) {
results.push({
name: project.displayName ?? project.projectId,
value: project.projectId,
});
}
}
return { results };
},
},
};
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const credentials = await this.getCredentials('googleApi');
const modelName = this.getNodeParameter('modelName', itemIndex) as string;
const projectId = this.getNodeParameter('projectId', itemIndex, '', {
extractValue: true,
}) as string;
const options = this.getNodeParameter('options', itemIndex, {
maxOutputTokens: 2048,
temperature: 0.4,
topK: 40,
topP: 0.9,
}) as {
maxOutputTokens: number;
temperature: number;
topK: number;
topP: number;
};
const safetySettings = this.getNodeParameter(
'options.safetySettings.values',
itemIndex,
null,
) as SafetySetting[];
try {
const model = new ChatVertexAI({
authOptions: {
projectId,
credentials: {
client_email: credentials.email as string,
private_key: credentials.privateKey as string,
},
},
model: modelName,
topK: options.topK,
topP: options.topP,
temperature: options.temperature,
maxOutputTokens: options.maxOutputTokens,
safetySettings,
callbacks: [new N8nLlmTracing(this)],
// Handle ChatVertexAI invocation errors to provide better error messages
onFailedAttempt: (error: any) => {
const customError = makeErrorFromStatus(Number(error?.response?.status), {
modelName,
});
if (customError) {
throw new NodeOperationError(this.getNode(), error as JsonObject, customError);
}
throw error;
},
});
return {
response: model,
};
} catch (e) {
// Catch model name validation error from LangChain (https://github.com/langchain-ai/langchainjs/blob/ef201d0ee85ee4049078270a0cfd7a1767e624f8/libs/langchain-google-common/src/utils/common.ts#L124)
// to show more helpful error message
if (e?.message?.startsWith('Unable to verify model params')) {
throw new NodeOperationError(this.getNode(), e as JsonObject, {
message: 'Unsupported model',
description: "Only models starting with 'gemini' are supported.",
});
}
// Assume all other exceptions while creating a new ChatVertexAI instance are parameter validation errors
throw new NodeOperationError(this.getNode(), e as JsonObject, {
message: 'Invalid options',
description: e.message,
});
}
}
}

View file

@ -0,0 +1,25 @@
export interface ErrorLike {
message?: string;
description?: string;
}
export interface ErrorContext {
modelName?: string;
}
export function makeErrorFromStatus(statusCode: number, context?: ErrorContext): ErrorLike {
const errorMessages: Record<number, ErrorLike> = {
403: {
message: 'Unauthorized for this project',
description:
'Check your Google Cloud project ID, and that your credential has access to that project',
},
404: {
message: context?.modelName
? `No model found called '${context.modelName}'`
: 'No model found',
},
};
return errorMessages[statusCode];
}

View file

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 48 48"><defs><path id="a" d="M44.5 20H24v8.5h11.8C34.7 33.9 30.1 37 24 37c-7.2 0-13-5.8-13-13s5.8-13 13-13c3.1 0 5.9 1.1 8.1 2.9l6.4-6.4C34.6 4.1 29.6 2 24 2 11.8 2 2 11.8 2 24s9.8 22 22 22c11 0 21-8 21-22 0-1.3-.2-2.7-.5-4"/></defs><clipPath id="b"><use xlink:href="#a" overflow="visible"/></clipPath><path fill="#FBBC05" d="M0 37V11l17 13z" clip-path="url(#b)"/><path fill="#EA4335" d="m0 11 17 13 7-6.1L48 14V0H0z" clip-path="url(#b)"/><path fill="#34A853" d="m0 37 30-23 7.9 1L48 0v48H0z" clip-path="url(#b)"/><path fill="#4285F4" d="M48 48 17 24l-4-3 35-10z" clip-path="url(#b)"/></svg>

After

Width:  |  Height:  |  Size: 687 B

View file

@ -0,0 +1,87 @@
import type { HarmBlockThreshold, HarmCategory } from '@google/generative-ai';
import type { INodeProperties } from 'n8n-workflow';
import { harmCategories, harmThresholds } from './safety-options';
export const additionalOptions: INodeProperties = {
displayName: 'Options',
name: 'options',
placeholder: 'Add Option',
description: 'Additional options to add',
type: 'collection',
default: {},
options: [
{
displayName: 'Maximum Number of Tokens',
name: 'maxOutputTokens',
default: 2048,
description: 'The maximum number of tokens to generate in the completion',
type: 'number',
},
{
displayName: 'Sampling Temperature',
name: 'temperature',
default: 0.4,
typeOptions: { maxValue: 1, minValue: 0, numberPrecision: 1 },
description:
'Controls randomness: Lowering results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive.',
type: 'number',
},
{
displayName: 'Top K',
name: 'topK',
default: 32,
typeOptions: { maxValue: 40, minValue: -1, numberPrecision: 1 },
description:
'Used to remove "long tail" low probability responses. Defaults to -1, which disables it.',
type: 'number',
},
{
displayName: 'Top P',
name: 'topP',
default: 1,
typeOptions: { maxValue: 1, minValue: 0, numberPrecision: 1 },
description:
'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered. We generally recommend altering this or temperature but not both.',
type: 'number',
},
// Safety Settings
{
displayName: 'Safety Settings',
name: 'safetySettings',
type: 'fixedCollection',
typeOptions: { multipleValues: true },
default: {
values: {
category: harmCategories[0].name as HarmCategory,
threshold: harmThresholds[0].name as HarmBlockThreshold,
},
},
placeholder: 'Add Option',
options: [
{
name: 'values',
displayName: 'Values',
values: [
{
displayName: 'Safety Category',
name: 'category',
type: 'options',
description: 'The category of harmful content to block',
default: 'HARM_CATEGORY_UNSPECIFIED',
options: harmCategories,
},
{
displayName: 'Safety Threshold',
name: 'threshold',
type: 'options',
description: 'The threshold of harmful content to block',
default: 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
options: harmThresholds,
},
],
},
],
},
],
};

View file

@ -65,6 +65,7 @@
"dist/nodes/llms/LmChatAwsBedrock/LmChatAwsBedrock.node.js", "dist/nodes/llms/LmChatAwsBedrock/LmChatAwsBedrock.node.js",
"dist/nodes/llms/LmChatGooglePalm/LmChatGooglePalm.node.js", "dist/nodes/llms/LmChatGooglePalm/LmChatGooglePalm.node.js",
"dist/nodes/llms/LmChatGoogleGemini/LmChatGoogleGemini.node.js", "dist/nodes/llms/LmChatGoogleGemini/LmChatGoogleGemini.node.js",
"dist/nodes/llms/LmChatGoogleVertex/LmChatGoogleVertex.node.js",
"dist/nodes/llms/LmChatGroq/LmChatGroq.node.js", "dist/nodes/llms/LmChatGroq/LmChatGroq.node.js",
"dist/nodes/llms/LmChatMistralCloud/LmChatMistralCloud.node.js", "dist/nodes/llms/LmChatMistralCloud/LmChatMistralCloud.node.js",
"dist/nodes/llms/LMChatOllama/LmChatOllama.node.js", "dist/nodes/llms/LMChatOllama/LmChatOllama.node.js",
@ -129,6 +130,7 @@
"@getzep/zep-cloud": "1.0.6", "@getzep/zep-cloud": "1.0.6",
"@getzep/zep-js": "0.9.0", "@getzep/zep-js": "0.9.0",
"@google-ai/generativelanguage": "2.5.0", "@google-ai/generativelanguage": "2.5.0",
"@google-cloud/resource-manager": "5.3.0",
"@google/generative-ai": "0.11.4", "@google/generative-ai": "0.11.4",
"@huggingface/inference": "2.7.0", "@huggingface/inference": "2.7.0",
"@langchain/anthropic": "0.1.21", "@langchain/anthropic": "0.1.21",
@ -136,6 +138,7 @@
"@langchain/community": "0.2.13", "@langchain/community": "0.2.13",
"@langchain/core": "0.2.9", "@langchain/core": "0.2.9",
"@langchain/google-genai": "0.0.16", "@langchain/google-genai": "0.0.16",
"@langchain/google-vertexai": "0.0.19",
"@langchain/groq": "0.0.12", "@langchain/groq": "0.0.12",
"@langchain/mistralai": "0.0.22", "@langchain/mistralai": "0.0.22",
"@langchain/openai": "0.0.33", "@langchain/openai": "0.0.33",

View file

@ -44,6 +44,12 @@ export function isChatInstance(model: unknown): model is BaseChatModel {
return namespace.includes('chat_models'); return namespace.includes('chat_models');
} }
export function isToolsInstance(model: unknown): model is Tool {
const namespace = (model as Tool)?.lc_namespace ?? [];
return namespace.includes('tools');
}
export async function getOptionalOutputParsers( export async function getOptionalOutputParsers(
ctx: IExecuteFunctions, ctx: IExecuteFunctions,
): Promise<Array<BaseOutputParser<unknown>>> { ): Promise<Array<BaseOutputParser<unknown>>> {

View file

@ -1,7 +1,7 @@
import { NodeOperationError, NodeConnectionType } from 'n8n-workflow'; import { NodeOperationError, NodeConnectionType } from 'n8n-workflow';
import type { ConnectionTypes, IExecuteFunctions, INodeExecutionData } from 'n8n-workflow'; import type { ConnectionTypes, IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
import { Tool } from '@langchain/core/tools'; import type { Tool } from '@langchain/core/tools';
import type { BaseMessage } from '@langchain/core/messages'; import type { BaseMessage } from '@langchain/core/messages';
import type { InputValues, MemoryVariables, OutputValues } from '@langchain/core/memory'; import type { InputValues, MemoryVariables, OutputValues } from '@langchain/core/memory';
import { BaseChatMessageHistory } from '@langchain/core/chat_history'; import { BaseChatMessageHistory } from '@langchain/core/chat_history';
@ -18,7 +18,7 @@ import { isObject } from 'lodash';
import type { BaseDocumentLoader } from 'langchain/dist/document_loaders/base'; import type { BaseDocumentLoader } from 'langchain/dist/document_loaders/base';
import { N8nJsonLoader } from './N8nJsonLoader'; import { N8nJsonLoader } from './N8nJsonLoader';
import { N8nBinaryLoader } from './N8nBinaryLoader'; import { N8nBinaryLoader } from './N8nBinaryLoader';
import { logAiEvent } from './helpers'; import { logAiEvent, isToolsInstance } from './helpers';
const errorsMap: { [key: string]: { message: string; description: string } } = { const errorsMap: { [key: string]: { message: string; description: string } } = {
'You exceeded your current quota, please check your plan and billing details.': { 'You exceeded your current quota, please check your plan and billing details.': {
@ -401,7 +401,7 @@ export function logWrapper(
} }
// ========== Tool ========== // ========== Tool ==========
if (originalInstance instanceof Tool) { if (isToolsInstance(originalInstance)) {
if (prop === '_call' && '_call' in target) { if (prop === '_call' && '_call' in target) {
return async (query: string): Promise<string> => { return async (query: string): Promise<string> => {
connectionType = NodeConnectionType.AiTool; connectionType = NodeConnectionType.AiTool;

View file

@ -56,6 +56,7 @@ const googleServiceAccountScopes = {
'https://www.googleapis.com/auth/datastore', 'https://www.googleapis.com/auth/datastore',
'https://www.googleapis.com/auth/firebase', 'https://www.googleapis.com/auth/firebase',
], ],
vertex: ['https://www.googleapis.com/auth/cloud-platform'],
}; };
type GoogleServiceAccount = keyof typeof googleServiceAccountScopes; type GoogleServiceAccount = keyof typeof googleServiceAccountScopes;

File diff suppressed because it is too large Load diff