Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1165,9 +1165,12 @@
"externalApiKeyPlaceholderSet": "API key configured",
"externalApiKeyHelper": "Stored in api_keys.yaml in your InvokeAI root directory.",
"externalBaseUrl": "Base URL (optional)",
"externalOverrideBaseUrl": "Override Base URL",
"externalBaseUrlPlaceholder": "https://...",
"externalBaseUrlHelper": "Override the default API base URL if needed.",
"externalResetHelper": "Clear API key and base URL.",
"externalProviderSaveFailed": "Failed to save external provider configuration.",
"externalProviderResetFailed": "Failed to reset external provider configuration.",
"height": "Height",
"huggingFace": "HuggingFace",
"huggingFacePlaceholder": "owner/model-name",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
Badge,
Button,
Expand All @@ -8,18 +9,20 @@ import {
FormLabel,
Heading,
Input,
Switch,
Text,
Tooltip,
useToast,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { useBuildModelInstallArg } from 'features/modelManagerV2/hooks/useBuildModelsToInstall';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import { $installModelsTabIndex } from 'features/modelManagerV2/store/installModelsStore';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import type { IconType } from 'react-icons';
import { PiCheckBold, PiWarningBold } from 'react-icons/pi';
import { SiAlibabacloud, SiBytedance, SiGooglegemini, SiOpenai } from 'react-icons/si';
import {
useGetExternalProviderConfigsQuery,
useResetExternalProviderConfigMutation,
Expand All @@ -30,15 +33,39 @@ import type { ExternalProviderConfig, StarterModel } from 'services/api/types';

const PROVIDER_SORT_ORDER = ['gemini', 'openai', 'seedream', 'alibabacloud'];

function resolveProviderIcon(providerId: string): IconType | null {
const provider = providerId.toLowerCase();

switch (provider) {
case 'openai':
return SiOpenai;
case 'gemini':
return SiGooglegemini;
case 'seedream':
return SiBytedance;
case 'alibabacloud':
return SiAlibabacloud;
default:
return null;
}
}

const FORM_CONTROL_SX: SystemStyleObject = {
flexDir: 'column',
alignItems: 'flex-start',
gap: 2,
};

type ProviderCardProps = {
provider: ExternalProviderConfig;
onInstallModels: (providerId: string) => void;
iconResolver: (providerId: string) => IconType | null;
};

type UpdatePayload = {
provider_id: string;
api_key?: string;
base_url?: string;
base_url?: string | null;
};

export const ExternalProvidersForm = memo(() => {
Expand All @@ -47,15 +74,14 @@ export const ExternalProvidersForm = memo(() => {
const { data: starterModels } = useGetStarterModelsQuery();
const [installModel] = useInstallModel();
const { getIsInstalled, buildModelInstallArg } = useBuildModelInstallArg();
const tabIndex = useStore($installModelsTabIndex);

const externalModelsByProvider = useMemo(() => {
const groups = new Map<string, StarterModel[]>();
for (const model of starterModels?.starter_models ?? []) {
if (!model.source.startsWith('external://')) {
continue;
}
const providerId = model.source.replace('external://', '').split('/')[0];
const providerId = model.source.slice('external://'.length).split('/')[0];
if (!providerId) {
continue;
}
Expand Down Expand Up @@ -107,8 +133,9 @@ export const ExternalProvidersForm = memo(() => {
return (
<Flex flexDir="column" height="100%" gap={4}>
<Flex flexDir="column" gap={1}>
<Heading size="sm">{t('modelManager.externalSetupTitle')}</Heading>
<Text color="base.300">{t('modelManager.externalSetupDescription')}</Text>
<Heading size="md">{t('modelManager.externalSetupTitle')}</Heading>
<Text variant="subtext">{t('modelManager.externalSetupDescription')}</Text>
<Text variant="subtext">{t('modelManager.externalSetupFooter')}</Text>
</Flex>
<ScrollableContent>
<Flex flexDir="column" gap={4}>
Expand All @@ -120,32 +147,39 @@ export const ExternalProvidersForm = memo(() => {
<ProviderCard
key={provider.provider_id}
provider={provider}
iconResolver={resolveProviderIcon}
onInstallModels={handleInstallProviderModels}
/>
))}
</Flex>
</ScrollableContent>
{tabIndex === 3 && (
<Text variant="subtext" color="base.400">
{t('modelManager.externalSetupFooter')}
</Text>
)}
</Flex>
);
});

ExternalProvidersForm.displayName = 'ExternalProvidersForm';

const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) => {
const ProviderCard = memo(({ provider, onInstallModels, iconResolver }: ProviderCardProps) => {
const { t } = useTranslation();
const toast = useToast();
const [apiKey, setApiKey] = useState('');
const [baseUrl, setBaseUrl] = useState(provider.base_url ?? '');
const [saveConfig, { isLoading }] = useSetExternalProviderConfigMutation();
const [resetConfig, { isLoading: isResetting }] = useResetExternalProviderConfigMutation();
const [overrideBaseUrl, setOverrideBaseUrl] = useState(!!provider.base_url);

useEffect(() => {
setApiKey('');
setBaseUrl(provider.base_url ?? '');
}, [provider.base_url]);
setOverrideBaseUrl(!!provider.base_url);
}, [provider.base_url, provider.provider_id]);

const hasBaseUrlChange = useMemo(() => {
if (!overrideBaseUrl) {
return provider.base_url !== null;
}
return baseUrl.trim() !== (provider.base_url ?? '');
}, [baseUrl, overrideBaseUrl, provider.base_url]);

const handleSave = useCallback(() => {
const trimmedApiKey = apiKey.trim();
Expand All @@ -156,7 +190,9 @@ const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) =>
if (trimmedApiKey) {
updatePayload.api_key = trimmedApiKey;
}
if (trimmedBaseUrl !== (provider.base_url ?? '')) {
if (!overrideBaseUrl && provider.base_url !== null) {
updatePayload.base_url = null;
} else if (overrideBaseUrl && trimmedBaseUrl !== (provider.base_url ?? '')) {
updatePayload.base_url = trimmedBaseUrl;
}

Expand All @@ -167,24 +203,48 @@ const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) =>
saveConfig(updatePayload)
.unwrap()
.then((result) => {
if (result.api_key_configured) {
if (trimmedApiKey && result.api_key_configured) {
setApiKey('');
onInstallModels(provider.provider_id);
}
if (result.base_url !== undefined) {
setBaseUrl(result.base_url ?? '');
}
setBaseUrl(result.base_url ?? '');
setOverrideBaseUrl(!!result.base_url);
})
.catch(() => {
toast({
id: `EXTERNAL_PROVIDER_SAVE_FAILED_${provider.provider_id}`,
title: t('modelManager.externalProviderSaveFailed'),
status: 'error',
});
});
}, [apiKey, baseUrl, onInstallModels, provider.base_url, provider.provider_id, saveConfig]);
}, [
apiKey,
baseUrl,
onInstallModels,
overrideBaseUrl,
provider.base_url,
provider.provider_id,
saveConfig,
t,
toast,
]);

const handleReset = useCallback(() => {
resetConfig(provider.provider_id)
.unwrap()
.then((result) => {
setApiKey('');
setBaseUrl(result.base_url ?? '');
setOverrideBaseUrl(!!result.base_url);
})
.catch(() => {
toast({
id: `EXTERNAL_PROVIDER_RESET_FAILED_${provider.provider_id}`,
title: t('modelManager.externalProviderResetFailed'),
status: 'error',
});
});
}, [provider.provider_id, resetConfig]);
}, [provider.provider_id, resetConfig, t, toast]);

const handleApiKeyChange = useCallback((event: ChangeEvent<HTMLInputElement>) => {
setApiKey(event.target.value);
Expand All @@ -206,21 +266,34 @@ const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) =>
</Badge>
);

const handleOverrideBaseUrlChange = useCallback((event: ChangeEvent<HTMLInputElement>) => {
event.stopPropagation();
setOverrideBaseUrl(event.target.checked);
if (!event.target.checked) {
setBaseUrl('');
}
}, []);

const ProviderIcon = iconResolver(provider.provider_id);

return (
<Card p={4} gap={4} variant="outline">
<Card p={4} gap={2} layerStyle="second">
<Flex justifyContent="space-between" alignItems="center" flexWrap="wrap" gap={3}>
<Flex flexDir="column" gap={1}>
<Heading size="xs" textTransform="capitalize">
{provider.provider_id}
</Heading>
<Text variant="subtext">
{t('modelManager.externalProviderCardDescription', { providerId: provider.provider_id })}
</Text>
<Flex alignItems="start" gap="4">
{ProviderIcon && <ProviderIcon />}
<Flex flexDir="column" gap={1} mt="-0.5">
<Heading size="xs" textTransform="capitalize" display="flex" alignItems="center" gap={2}>
{provider.provider_id}
</Heading>
<Text variant="subtext">
{t('modelManager.externalProviderCardDescription', { providerId: provider.provider_id })}
</Text>
</Flex>
</Flex>
{statusBadge}
</Flex>
<Flex flexDir="column" gap={4}>
<FormControl>
<FormControl sx={FORM_CONTROL_SX}>
<FormLabel>{t('modelManager.externalApiKey')}</FormLabel>
<Input
type="password"
Expand All @@ -235,16 +308,28 @@ const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) =>
/>
<FormHelperText>{t('modelManager.externalApiKeyHelper')}</FormHelperText>
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.externalBaseUrl')}</FormLabel>
<Input
placeholder={t('modelManager.externalBaseUrlPlaceholder')}
value={baseUrl}
onChange={handleBaseUrlChange}
<FormControl display="flex" alignItems="center">
<Switch
id={`${provider.provider_id}-override-baseurl`}
isChecked={overrideBaseUrl}
onChange={handleOverrideBaseUrlChange}
/>
<FormHelperText>{t('modelManager.externalBaseUrlHelper')}</FormHelperText>
<FormLabel htmlFor={`${provider.provider_id}-override-baseurl`}>
{t('modelManager.externalOverrideBaseUrl')}
</FormLabel>
</FormControl>
<Flex gap={2} justifyContent="flex-end" flexWrap="wrap">
<Flex hidden={!overrideBaseUrl}>
<FormControl sx={FORM_CONTROL_SX}>
<FormLabel>{t('modelManager.externalBaseUrl')}</FormLabel>
<Input
placeholder={t('modelManager.externalBaseUrlPlaceholder')}
value={baseUrl}
onChange={handleBaseUrlChange}
/>
<FormHelperText>{t('modelManager.externalBaseUrlHelper')}</FormHelperText>
</FormControl>
</Flex>
<Flex gap={2} justifyContent="flex-end" flexWrap="wrap" borderTopWidth="1px" pt="4">
<Tooltip label={t('modelManager.externalResetHelper')}>
<Button variant="ghost" onClick={handleReset} isLoading={isResetting}>
{t('common.reset')}
Expand All @@ -254,7 +339,7 @@ const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) =>
colorScheme="invokeYellow"
onClick={handleSave}
isLoading={isLoading}
isDisabled={!apiKey.trim() && baseUrl.trim() === (provider.base_url ?? '')}
isDisabled={!apiKey.trim() && !hasBaseUrlChange}
>
{t('common.save')}
</Button>
Expand Down
Loading