diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 3e88d460e55..8d4503460af 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1165,9 +1165,12 @@ "externalApiKeyPlaceholderSet": "API key configured", "externalApiKeyHelper": "Stored in api_keys.yaml in your InvokeAI root directory.", "externalBaseUrl": "Base URL (optional)", + "externalOverrideBaseUrl": "Override Base URL", "externalBaseUrlPlaceholder": "https://...", "externalBaseUrlHelper": "Override the default API base URL if needed.", "externalResetHelper": "Clear API key and base URL.", + "externalProviderSaveFailed": "Failed to save external provider configuration.", + "externalProviderResetFailed": "Failed to reset external provider configuration.", "height": "Height", "huggingFace": "HuggingFace", "huggingFacePlaceholder": "owner/model-name", diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm.tsx index 9e1b861cce4..b4bbe8d6335 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm.tsx @@ -1,3 +1,4 @@ +import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Badge, Button, @@ -8,18 +9,20 @@ import { FormLabel, Heading, Input, + Switch, Text, Tooltip, + useToast, } from '@invoke-ai/ui-library'; -import { useStore } from '@nanostores/react'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import { useBuildModelInstallArg } from 'features/modelManagerV2/hooks/useBuildModelsToInstall'; import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; -import { $installModelsTabIndex } from 'features/modelManagerV2/store/installModelsStore'; import type { ChangeEvent } from 'react'; import { memo, useCallback, useEffect, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import type { IconType } from 'react-icons'; import { PiCheckBold, PiWarningBold } from 'react-icons/pi'; +import { SiAlibabacloud, SiBytedance, SiGooglegemini, SiOpenai } from 'react-icons/si'; import { useGetExternalProviderConfigsQuery, useResetExternalProviderConfigMutation, @@ -30,15 +33,39 @@ import type { ExternalProviderConfig, StarterModel } from 'services/api/types'; const PROVIDER_SORT_ORDER = ['gemini', 'openai', 'seedream', 'alibabacloud']; +function resolveProviderIcon(providerId: string): IconType | null { + const provider = providerId.toLowerCase(); + + switch (provider) { + case 'openai': + return SiOpenai; + case 'gemini': + return SiGooglegemini; + case 'seedream': + return SiBytedance; + case 'alibabacloud': + return SiAlibabacloud; + default: + return null; + } +} + +const FORM_CONTROL_SX: SystemStyleObject = { + flexDir: 'column', + alignItems: 'flex-start', + gap: 2, +}; + type ProviderCardProps = { provider: ExternalProviderConfig; onInstallModels: (providerId: string) => void; + iconResolver: (providerId: string) => IconType | null; }; type UpdatePayload = { provider_id: string; api_key?: string; - base_url?: string; + base_url?: string | null; }; export const ExternalProvidersForm = memo(() => { @@ -47,7 +74,6 @@ export const ExternalProvidersForm = memo(() => { const { data: starterModels } = useGetStarterModelsQuery(); const [installModel] = useInstallModel(); const { getIsInstalled, buildModelInstallArg } = useBuildModelInstallArg(); - const tabIndex = useStore($installModelsTabIndex); const externalModelsByProvider = useMemo(() => { const groups = new Map(); @@ -55,7 +81,7 @@ export const ExternalProvidersForm = memo(() => { if (!model.source.startsWith('external://')) { continue; } - const providerId = model.source.replace('external://', '').split('/')[0]; + const providerId = model.source.slice('external://'.length).split('/')[0]; if (!providerId) { continue; } @@ -107,8 +133,9 @@ export const ExternalProvidersForm = memo(() => { return ( - {t('modelManager.externalSetupTitle')} - {t('modelManager.externalSetupDescription')} + {t('modelManager.externalSetupTitle')} + {t('modelManager.externalSetupDescription')} + {t('modelManager.externalSetupFooter')} @@ -120,32 +147,39 @@ export const ExternalProvidersForm = memo(() => { ))} - {tabIndex === 3 && ( - - {t('modelManager.externalSetupFooter')} - - )} ); }); ExternalProvidersForm.displayName = 'ExternalProvidersForm'; -const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) => { +const ProviderCard = memo(({ provider, onInstallModels, iconResolver }: ProviderCardProps) => { const { t } = useTranslation(); + const toast = useToast(); const [apiKey, setApiKey] = useState(''); const [baseUrl, setBaseUrl] = useState(provider.base_url ?? ''); const [saveConfig, { isLoading }] = useSetExternalProviderConfigMutation(); const [resetConfig, { isLoading: isResetting }] = useResetExternalProviderConfigMutation(); + const [overrideBaseUrl, setOverrideBaseUrl] = useState(!!provider.base_url); useEffect(() => { + setApiKey(''); setBaseUrl(provider.base_url ?? ''); - }, [provider.base_url]); + setOverrideBaseUrl(!!provider.base_url); + }, [provider.base_url, provider.provider_id]); + + const hasBaseUrlChange = useMemo(() => { + if (!overrideBaseUrl) { + return provider.base_url !== null; + } + return baseUrl.trim() !== (provider.base_url ?? ''); + }, [baseUrl, overrideBaseUrl, provider.base_url]); const handleSave = useCallback(() => { const trimmedApiKey = apiKey.trim(); @@ -156,7 +190,9 @@ const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) => if (trimmedApiKey) { updatePayload.api_key = trimmedApiKey; } - if (trimmedBaseUrl !== (provider.base_url ?? '')) { + if (!overrideBaseUrl && provider.base_url !== null) { + updatePayload.base_url = null; + } else if (overrideBaseUrl && trimmedBaseUrl !== (provider.base_url ?? '')) { updatePayload.base_url = trimmedBaseUrl; } @@ -167,15 +203,31 @@ const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) => saveConfig(updatePayload) .unwrap() .then((result) => { - if (result.api_key_configured) { + if (trimmedApiKey && result.api_key_configured) { setApiKey(''); onInstallModels(provider.provider_id); } - if (result.base_url !== undefined) { - setBaseUrl(result.base_url ?? ''); - } + setBaseUrl(result.base_url ?? ''); + setOverrideBaseUrl(!!result.base_url); + }) + .catch(() => { + toast({ + id: `EXTERNAL_PROVIDER_SAVE_FAILED_${provider.provider_id}`, + title: t('modelManager.externalProviderSaveFailed'), + status: 'error', + }); }); - }, [apiKey, baseUrl, onInstallModels, provider.base_url, provider.provider_id, saveConfig]); + }, [ + apiKey, + baseUrl, + onInstallModels, + overrideBaseUrl, + provider.base_url, + provider.provider_id, + saveConfig, + t, + toast, + ]); const handleReset = useCallback(() => { resetConfig(provider.provider_id) @@ -183,8 +235,16 @@ const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) => .then((result) => { setApiKey(''); setBaseUrl(result.base_url ?? ''); + setOverrideBaseUrl(!!result.base_url); + }) + .catch(() => { + toast({ + id: `EXTERNAL_PROVIDER_RESET_FAILED_${provider.provider_id}`, + title: t('modelManager.externalProviderResetFailed'), + status: 'error', + }); }); - }, [provider.provider_id, resetConfig]); + }, [provider.provider_id, resetConfig, t, toast]); const handleApiKeyChange = useCallback((event: ChangeEvent) => { setApiKey(event.target.value); @@ -206,21 +266,34 @@ const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) => ); + const handleOverrideBaseUrlChange = useCallback((event: ChangeEvent) => { + event.stopPropagation(); + setOverrideBaseUrl(event.target.checked); + if (!event.target.checked) { + setBaseUrl(''); + } + }, []); + + const ProviderIcon = iconResolver(provider.provider_id); + return ( - + - - - {provider.provider_id} - - - {t('modelManager.externalProviderCardDescription', { providerId: provider.provider_id })} - + + {ProviderIcon && } + + + {provider.provider_id} + + + {t('modelManager.externalProviderCardDescription', { providerId: provider.provider_id })} + + {statusBadge} - + {t('modelManager.externalApiKey')} /> {t('modelManager.externalApiKeyHelper')} - - {t('modelManager.externalBaseUrl')} - + - {t('modelManager.externalBaseUrlHelper')} + + {t('modelManager.externalOverrideBaseUrl')} + - + +