diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index e13315ad918..75dbb56379c 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -227,6 +227,7 @@ "direction": "Direction", "ipAdapter": "IP Adapter", "t2iAdapter": "T2I Adapter", + "prompt": "Prompt", "positivePrompt": "Positive Prompt", "negativePrompt": "Negative Prompt", "removeNegativePrompt": "Remove Negative Prompt", diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts index 24dee85a66a..d210d2fd2ac 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts @@ -9,6 +9,8 @@ import { describe, expect, it } from 'vitest'; import { paramsSliceConfig, + positivePromptAddedToHistory, + promptRemovedFromHistory, selectModelSupportsDimensions, selectModelSupportsGuidance, selectModelSupportsNegativePrompt, @@ -165,4 +167,52 @@ describe('paramsSliceConfig persisted state migration', () => { expect(result.dimensions.width).toBe(768); expect(result.dimensions.height).toBe(768); }); + + it('migrates old positive prompt history entries to prompt pairs', () => { + expect(migrate).toBeDefined(); + + const initial = getInitialParamsState(); + const v3State: Record = { + ...initial, + positivePromptHistory: ['a fluffy cat'], + }; + + const result = migrate?.(v3State) as ReturnType; + + expect(result.positivePromptHistory).toEqual([{ positivePrompt: 'a fluffy cat', negativePrompt: null }]); + }); +}); + +describe('paramsSlice prompt history', () => { + it('stores positive and negative prompts in the same history item', () => { + const initial = getInitialParamsState(); + const state = paramsSliceConfig.slice.reducer( + initial, + positivePromptAddedToHistory({ positivePrompt: ' a fluffy cat ', negativePrompt: ' blurry ' }) + ); + + expect(state.positivePromptHistory).toEqual([{ positivePrompt: 'a fluffy cat', negativePrompt: 'blurry' }]); + }); + + it('deduplicates and removes prompt history by positive and negative prompt pair', () => { + const initial = getInitialParamsState(); + const withFirstPrompt = paramsSliceConfig.slice.reducer( + initial, + positivePromptAddedToHistory({ positivePrompt: 'a cat', negativePrompt: 'blurry' }) + ); + const withSecondPrompt = paramsSliceConfig.slice.reducer( + withFirstPrompt, + positivePromptAddedToHistory({ positivePrompt: 'a cat', negativePrompt: 'low quality' }) + ); + const removed = paramsSliceConfig.slice.reducer( + withSecondPrompt, + promptRemovedFromHistory({ positivePrompt: 'a cat', negativePrompt: 'blurry' }) + ); + + expect(withSecondPrompt.positivePromptHistory).toEqual([ + { positivePrompt: 'a cat', negativePrompt: 'low quality' }, + { positivePrompt: 'a cat', negativePrompt: 'blurry' }, + ]); + expect(removed.positivePromptHistory).toEqual([{ positivePrompt: 'a cat', negativePrompt: 'low quality' }]); + }); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index a5200ef1ff8..752d7dd2837 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -7,7 +7,13 @@ import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMul import { isPlainObject } from 'es-toolkit'; import { clamp } from 'es-toolkit/compat'; import { logout } from 'features/auth/store/authSlice'; -import type { AspectRatioID, InfillMethod, ParamsState, RgbaColor } from 'features/controlLayers/store/types'; +import type { + AspectRatioID, + InfillMethod, + ParamsState, + PromptHistoryItem, + RgbaColor, +} from 'features/controlLayers/store/types'; import { ASPECT_RATIO_MAP, DEFAULT_ASPECT_RATIO_CONFIG, @@ -306,20 +312,33 @@ const slice = createSlice({ positivePromptChanged: (state, action: PayloadAction) => { state.positivePrompt = action.payload; }, - positivePromptAddedToHistory: (state, action: PayloadAction) => { - const prompt = action.payload.trim(); - if (prompt.length === 0) { + positivePromptAddedToHistory: (state, action: PayloadAction) => { + const prompt: PromptHistoryItem = { + positivePrompt: action.payload.positivePrompt.trim(), + negativePrompt: action.payload.negativePrompt?.trim() || null, + }; + if (prompt.positivePrompt.length === 0 && !prompt.negativePrompt) { return; } - state.positivePromptHistory = [prompt, ...state.positivePromptHistory.filter((p) => p !== prompt)]; + state.positivePromptHistory = [ + prompt, + ...state.positivePromptHistory.filter( + (p) => + p.positivePrompt !== prompt.positivePrompt || (p.negativePrompt ?? null) !== (prompt.negativePrompt ?? null) + ), + ]; if (state.positivePromptHistory.length > MAX_POSITIVE_PROMPT_HISTORY) { state.positivePromptHistory = state.positivePromptHistory.slice(0, MAX_POSITIVE_PROMPT_HISTORY); } }, - promptRemovedFromHistory: (state, action: PayloadAction) => { - state.positivePromptHistory = state.positivePromptHistory.filter((p) => p !== action.payload); + promptRemovedFromHistory: (state, action: PayloadAction) => { + state.positivePromptHistory = state.positivePromptHistory.filter( + (p) => + p.positivePrompt !== action.payload.positivePrompt || + (p.negativePrompt ?? null) !== (action.payload.negativePrompt ?? null) + ); }, promptHistoryCleared: (state) => { state.positivePromptHistory = []; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index cbeccdfa930..92b7478c0e4 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -765,8 +765,16 @@ const zDimensionsState = z.object({ }); export const MAX_POSITIVE_PROMPT_HISTORY = 100; +const zPromptHistoryItem = z.union([ + zParameterPositivePrompt.transform((positivePrompt) => ({ positivePrompt, negativePrompt: null })), + z.object({ + positivePrompt: zParameterPositivePrompt, + negativePrompt: zParameterNegativePrompt, + }), +]); +export type PromptHistoryItem = z.infer; const zPositivePromptHistory = z - .array(zParameterPositivePrompt) + .array(zPromptHistoryItem) .transform((arr) => arr.slice(0, MAX_POSITIVE_PROMPT_HISTORY)); export const zInfillMethod = z.enum(['patchmatch', 'lama', 'cv2', 'color', 'tile']); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts index 36b6ad4b738..3f15a8ee9c1 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts @@ -4,6 +4,7 @@ import { range } from 'es-toolkit/compat'; import type { SeedBehaviour } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import type { BaseModelType } from 'features/nodes/types/common'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import { selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils'; import type { components } from 'services/api/schema'; import type { Batch, EnqueueBatchArg, Invocation } from 'services/api/types'; @@ -26,11 +27,12 @@ export const prepareLinearUIBatch = (arg: { prepend: boolean; base: BaseModelType; positivePromptNode: Invocation<'string'>; + negativePromptNode?: Invocation<'string'>; seedNode?: Invocation<'integer'>; origin: string; destination: string; }): EnqueueBatchArg => { - const { state, g, base, prepend, positivePromptNode, seedNode, origin, destination } = arg; + const { state, g, base, prepend, positivePromptNode, negativePromptNode, seedNode, origin, destination } = arg; const { iterations, shouldRandomizeSeed, seed } = state.params; const { prompts, seedBehaviour } = state.dynamicPrompts; @@ -74,6 +76,15 @@ export const prepareLinearUIBatch = (arg: { items: extendedPrompts, }); + if (negativePromptNode) { + const negativePrompt = selectPresetModifiedPrompts(state).negative; + firstBatchDatumList.push({ + node_path: negativePromptNode.id, + field_name: 'value', + items: extendedPrompts.map(() => negativePrompt), + }); + } + data.push(firstBatchDatumList); // Models without a seed node (e.g. external API models without seed support) can't express diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildMultidiffusionUpscaleGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildMultidiffusionUpscaleGraph.ts index 26db510599c..8d0e38a9081 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildMultidiffusionUpscaleGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildMultidiffusionUpscaleGraph.ts @@ -7,7 +7,7 @@ import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from ' import { assert } from 'tsafe'; import { addLoRAs } from './generation/addLoRAs'; -import { getBoardField, selectPresetModifiedPrompts } from './graphBuilderUtils'; +import { getBoardField } from './graphBuilderUtils'; import type { GraphBuilderReturn } from './types'; export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise => { @@ -35,6 +35,10 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise id: getPrefixedId('positive_prompt'), type: 'string', }); + const negativePrompt = g.addNode({ + id: getPrefixedId('negative_prompt'), + type: 'string', + }); const spandrelAutoscale = g.addNode({ type: 'spandrel_image_to_image_autoscale', @@ -105,8 +109,6 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise let modelLoader; if (model.base === 'sdxl') { - const prompts = selectPresetModifiedPrompts(state); - posCond = g.addNode({ type: 'sdxl_compel_prompt', id: getPrefixedId('pos_cond'), @@ -114,8 +116,6 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise negCond = g.addNode({ type: 'sdxl_compel_prompt', id: getPrefixedId('neg_cond'), - prompt: prompts.negative, - style: prompts.negative, }); modelLoader = g.addNode({ type: 'sdxl_model_loader', @@ -131,16 +131,14 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise g.addEdge(positivePrompt, 'value', posCond, 'prompt'); g.addEdge(positivePrompt, 'value', posCond, 'style'); + g.addEdge(negativePrompt, 'value', negCond, 'prompt'); + g.addEdge(negativePrompt, 'value', negCond, 'style'); addSDXLLoRAs(state, g, tiledMultidiffusion, modelLoader, null, posCond, negCond); - g.upsertMetadata({ - negative_prompt: prompts.negative, - }); g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt'); + g.addEdgeToMetadata(negativePrompt, 'value', 'negative_prompt'); } else { - const prompts = selectPresetModifiedPrompts(state); - posCond = g.addNode({ type: 'compel', id: getPrefixedId('pos_cond'), @@ -148,7 +146,6 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise negCond = g.addNode({ type: 'compel', id: getPrefixedId('neg_cond'), - prompt: prompts.negative, }); modelLoader = g.addNode({ type: 'main_model_loader', @@ -166,14 +163,12 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise g.addEdge(modelLoader, 'unet', tiledMultidiffusion, 'unet'); g.addEdge(positivePrompt, 'value', posCond, 'prompt'); + g.addEdge(negativePrompt, 'value', negCond, 'prompt'); addLoRAs(state, g, tiledMultidiffusion, modelLoader, null, clipSkipNode, posCond, negCond); - g.upsertMetadata({ - negative_prompt: prompts.negative, - }); - g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt'); + g.addEdgeToMetadata(negativePrompt, 'value', 'negative_prompt'); } const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); @@ -261,5 +256,6 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise g, seed, positivePrompt, + negativePrompt, }; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index eae07532011..1dead116ddb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -15,7 +15,7 @@ import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless'; import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; -import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils'; +import { selectCanvasOutputFields } from 'features/nodes/util/graph/graphBuilderUtils'; import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import type { Invocation } from 'services/api/types'; @@ -51,8 +51,6 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise = l2i; @@ -188,5 +190,6 @@ export const buildSD3Graph = async (arg: GraphBuilderArg): Promise ({ + logger: () => ({ + debug: vi.fn(), + }), +})); + +let nextId = 0; +vi.mock('features/controlLayers/konva/util', () => ({ + getPrefixedId: (prefix: string) => `${prefix}:${nextId++}`, +})); + +const models = { + sd1: { + key: 'sd1-model', + hash: 'sd1-hash', + name: 'SD 1.5', + base: 'sd-1', + type: 'main', + }, + sd3: { + key: 'sd3-model', + hash: 'sd3-hash', + name: 'SD3', + base: 'sd-3', + type: 'main', + }, + sdxl: { + key: 'sdxl-model', + hash: 'sdxl-hash', + name: 'SDXL', + base: 'sdxl', + type: 'main', + }, +} as const; + +const upscaleModel = { + key: 'upscale-model', + hash: 'upscale-hash', + name: 'Upscale', + base: 'any', + type: 'spandrel_image_to_image', +} as const; + +const controlnetModel = { + key: 'controlnet-model', + hash: 'controlnet-hash', + name: 'Tile ControlNet', + base: 'sd-1', + type: 'controlnet', +} as const; + +vi.mock('features/controlLayers/store/paramsSlice', () => ({ + selectMainModelConfig: vi.fn((state: RootState) => state.params.model), + selectParamsSlice: vi.fn((state: RootState) => state.params), +})); + +vi.mock('features/controlLayers/store/refImagesSlice', () => ({ + selectRefImagesSlice: vi.fn(() => ({ entities: [] })), +})); + +vi.mock('features/controlLayers/store/selectors', () => ({ + selectCanvasMetadata: vi.fn(() => ({})), + selectCanvasSlice: vi.fn(() => ({ + bbox: { rect: { x: 0, y: 0, width: 1024, height: 1024 } }, + controlLayers: { entities: [] }, + regionalGuidance: { entities: [] }, + })), +})); + +vi.mock('features/metadata/util/modelFetchingHelpers', () => ({ + fetchModelConfigWithTypeGuard: vi.fn((key: string) => { + if (key === upscaleModel.key) { + return Promise.resolve(upscaleModel); + } + return Promise.resolve(models.sdxl); + }), +})); + +vi.mock('features/nodes/util/graph/generation/addControlAdapters', () => ({ + addControlNets: vi.fn(() => Promise.resolve({ addedControlNets: 0 })), + addT2IAdapters: vi.fn(() => Promise.resolve({ addedT2IAdapters: 0 })), +})); + +vi.mock('features/nodes/util/graph/generation/addImageToImage', () => ({ + addImageToImage: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addInpaint', () => ({ + addInpaint: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addIPAdapters', () => ({ + addIPAdapters: vi.fn(() => ({ addedIPAdapters: 0 })), +})); + +vi.mock('features/nodes/util/graph/generation/addLoRAs', () => ({ + addLoRAs: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addNSFWChecker', () => ({ + addNSFWChecker: vi.fn((_g, node) => node), +})); + +vi.mock('features/nodes/util/graph/generation/addOutpaint', () => ({ + addOutpaint: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addRegions', () => ({ + addRegions: vi.fn(() => []), +})); + +vi.mock('features/nodes/util/graph/generation/addSDXLLoRAs', () => ({ + addSDXLLoRAs: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addSDXLRefiner', () => ({ + addSDXLRefiner: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addSeamless', () => ({ + addSeamless: vi.fn(() => null), +})); + +vi.mock('features/nodes/util/graph/generation/addTextToImage', () => ({ + addTextToImage: vi.fn(({ l2i }) => l2i), +})); + +vi.mock('features/nodes/util/graph/generation/addWatermarker', () => ({ + addWatermarker: vi.fn((_g, node) => node), +})); + +vi.mock('features/nodes/util/graph/graphBuilderUtils', () => ({ + getBoardField: vi.fn(() => undefined), + selectCanvasOutputFields: vi.fn(() => ({})), + selectPresetModifiedPrompts: vi.fn(() => ({ + positive: 'preset positive prompt', + negative: 'preset negative prompt', + })), +})); + +vi.mock('features/ui/store/uiSelectors', () => ({ + selectActiveTab: vi.fn(() => 'generate'), +})); + +import type { BaseModelType } from 'features/nodes/types/common'; +import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; +import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph'; +import type { GraphBuilderReturn } from 'features/nodes/util/graph/types'; + +import { buildSD1Graph } from './buildSD1Graph'; +import { buildSD3Graph } from './buildSD3Graph'; +import { buildSDXLGraph } from './buildSDXLGraph'; +import type { GraphType } from './Graph'; + +type TestNode = { id: string; type: string; [key: string]: unknown }; + +const buildState = (model: (typeof models)[keyof typeof models]): RootState => + ({ + dynamicPrompts: { + prompts: ['positive prompt 1', 'positive prompt 2'], + seedBehaviour: 'PER_PROMPT', + }, + gallery: { + autoAddBoardId: 'none', + }, + params: { + cfgRescaleMultiplier: 0, + cfgScale: 7, + clipGEmbedModel: null, + clipLEmbedModel: null, + clipSkip: 0, + colorCompensation: false, + iterations: 2, + model, + negativePrompt: 'raw negative prompt', + positivePrompt: 'raw positive prompt', + refinerModel: null, + scheduler: 'euler', + seed: 123, + shouldRandomizeSeed: false, + shouldUseCpuNoise: true, + steps: 20, + t5EncoderModel: null, + upscaleCfgScale: 2, + upscaleScheduler: 'euler', + vae: null, + vaePrecision: 'fp32', + }, + system: { + shouldUseNSFWChecker: false, + shouldUseWatermarker: false, + }, + upscale: { + creativity: 4, + scale: 2, + structure: 5, + tileControlnetModel: controlnetModel, + tileOverlap: 64, + tileSize: 512, + upscaleInitialImage: { + height: 512, + image_name: 'initial.png', + width: 512, + }, + upscaleModel, + }, + }) as unknown as RootState; + +const getNodeByPrefix = (graph: GraphType, prefix: string) => { + const nodes = graph.nodes as Record; + return Object.entries(nodes).find(([id]) => id.startsWith(prefix))?.[1]; +}; + +const expectNegativePromptWiring = (graph: GraphType, negativePromptId: string, conditioningFields: string[]) => { + const negCond = getNodeByPrefix(graph, 'neg_cond:'); + expect(negCond).toBeDefined(); + if (!negCond) { + throw new Error('Expected negative conditioning node to exist'); + } + expect(negCond).not.toHaveProperty('prompt'); + expect(negCond).not.toHaveProperty('style'); + + for (const field of conditioningFields) { + expect(graph.edges).toContainEqual({ + destination: { field, node_id: negCond?.id }, + source: { field: 'value', node_id: negativePromptId }, + }); + } + + const nodes = graph.nodes as Record; + const metadataNode = Object.values(nodes).find((node) => node.type === 'core_metadata'); + expect(metadataNode).toBeDefined(); + if (!metadataNode) { + throw new Error('Expected metadata node to exist'); + } + expect(metadataNode).not.toHaveProperty('negative_prompt'); + expect(graph.edges).toContainEqual({ + destination: { field: 'negative_prompt', node_id: metadataNode?.id }, + source: { field: 'value', node_id: negativePromptId }, + }); +}; + +const expectNegativePromptBatching = (state: RootState, graphBuilderReturn: GraphBuilderReturn) => { + expect(graphBuilderReturn.negativePrompt).toBeDefined(); + + const batchConfig = prepareLinearUIBatch({ + state, + g: graphBuilderReturn.g, + base: (state.params.model?.base ?? 'sdxl') as BaseModelType, + prepend: false, + seedNode: graphBuilderReturn.seed, + positivePromptNode: graphBuilderReturn.positivePrompt, + negativePromptNode: graphBuilderReturn.negativePrompt, + origin: 'test', + destination: 'test', + }); + + const negativePromptBatchDatum = batchConfig.batch.data + ?.flat() + .find((datum) => datum.node_path === graphBuilderReturn.negativePrompt?.id); + + expect(negativePromptBatchDatum).toEqual({ + field_name: 'value', + items: ['preset negative prompt', 'preset negative prompt', 'preset negative prompt', 'preset negative prompt'], + node_path: graphBuilderReturn.negativePrompt?.id, + }); +}; + +beforeEach(() => { + nextId = 0; +}); + +describe('SD negative prompt graph wiring', () => { + it('wires SD1 negative prompt through a string node into conditioning, metadata, and batch data', async () => { + const state = buildState(models.sd1); + const result = await buildSD1Graph({ generationMode: 'txt2img', manager: null, state }); + const graph = result.g.getGraph(); + + expect(result.negativePrompt).toBeDefined(); + expectNegativePromptWiring(graph, result.negativePrompt?.id ?? '', ['prompt']); + expectNegativePromptBatching(state, result); + }); + + it('wires SDXL negative prompt through a string node into conditioning, metadata, and batch data', async () => { + const state = buildState(models.sdxl); + const result = await buildSDXLGraph({ generationMode: 'txt2img', manager: null, state }); + const graph = result.g.getGraph(); + + expect(result.negativePrompt).toBeDefined(); + expectNegativePromptWiring(graph, result.negativePrompt?.id ?? '', ['prompt', 'style']); + expectNegativePromptBatching(state, result); + }); + + it('wires SD3 negative prompt through a string node into conditioning, metadata, and batch data', async () => { + const state = buildState(models.sd3); + const result = await buildSD3Graph({ generationMode: 'txt2img', manager: null, state }); + const graph = result.g.getGraph(); + + expect(result.negativePrompt).toBeDefined(); + expectNegativePromptWiring(graph, result.negativePrompt?.id ?? '', ['prompt']); + expectNegativePromptBatching(state, result); + }); + + it.each([ + { conditioningFields: ['prompt'], model: models.sd1 }, + { conditioningFields: ['prompt', 'style'], model: models.sdxl }, + ])( + 'wires $model.base multidiffusion upscale negative prompt through a string node into conditioning, metadata, and batch data', + async ({ conditioningFields, model }) => { + const state = buildState(model); + const result = await buildMultidiffusionUpscaleGraph(state); + const graph = result.g.getGraph(); + + expect(result.negativePrompt).toBeDefined(); + expectNegativePromptWiring(graph, result.negativePrompt?.id ?? '', conditioningFields); + expectNegativePromptBatching(state, result); + } + ); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 9d65076a70d..f31c42ee561 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -15,7 +15,7 @@ import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless'; import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; -import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils'; +import { selectCanvasOutputFields } from 'features/nodes/util/graph/graphBuilderUtils'; import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import type { Invocation } from 'services/api/types'; @@ -53,8 +53,6 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise; positivePrompt: Invocation<'string'>; + negativePrompt?: Invocation<'string'>; }; export class UnsupportedGenerationModeError extends Error { diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx index 955eeb4abf0..3c36ab1e791 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx @@ -4,8 +4,10 @@ import { Box, Flex, Textarea } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks'; import { usePersistedTextAreaSize } from 'common/hooks/usePersistedTextareaSize'; import { + negativePromptChanged, positivePromptChanged, selectModelSupportsNegativePrompt, + selectNegativePrompt, selectPositivePrompt, selectPositivePromptHistory, } from 'features/controlLayers/store/paramsSlice'; @@ -62,7 +64,10 @@ const usePromptHistory = () => { * When we are moving thru history, we will always have a stashedPrompt (the prompt before we started browsing) * and a historyIdx which is an index into the history array (0 = most recent, 1 = previous, etc). */ - const stateRef = useRef<{ stashedPrompt: string; historyIdx: number } | null>(null); + const stateRef = useRef<{ + stashedPrompts: { positivePrompt: string; negativePrompt: string | null }; + historyIdx: number; + } | null>(null); const prev = useCallback(() => { if (history.length === 0) { @@ -72,8 +77,15 @@ const usePromptHistory = () => { let state = stateRef.current; if (!state) { // First time going "back" in history, init state - state = { stashedPrompt: selectPositivePrompt(store.getState()), historyIdx: 0 }; - stateRef.current = state; + const currentState = store.getState(); + stateRef.current = { + stashedPrompts: { + positivePrompt: selectPositivePrompt(currentState), + negativePrompt: selectNegativePrompt(currentState), + }, + historyIdx: 0, + }; + state = stateRef.current; } else { // Subsequent "back" in history, increment index if (state.historyIdx === history.length - 1) { @@ -83,12 +95,13 @@ const usePromptHistory = () => { state.historyIdx = state.historyIdx + 1; } // We should go "back" in history - const newPrompt = history[state.historyIdx]; - if (newPrompt === undefined) { + const newPrompts = history[state.historyIdx]; + if (newPrompts === undefined) { // Shouldn't happen return; } - store.dispatch(positivePromptChanged(newPrompt)); + store.dispatch(positivePromptChanged(newPrompts.positivePrompt)); + store.dispatch(negativePromptChanged(newPrompts.negativePrompt)); }, [history, store]); const next = useCallback(() => { if (history.length === 0) { @@ -103,18 +116,20 @@ const usePromptHistory = () => { state.historyIdx = state.historyIdx - 1; if (state.historyIdx < 0) { // Overshot to the "current" stashed prompt - store.dispatch(positivePromptChanged(state.stashedPrompt)); + store.dispatch(positivePromptChanged(state.stashedPrompts.positivePrompt)); + store.dispatch(negativePromptChanged(state.stashedPrompts.negativePrompt)); // Clear state bc we're back to current prompt stateRef.current = null; return; } // We should go "forward" in history - const newPrompt = history[state.historyIdx]; - if (newPrompt === undefined) { + const newPrompts = history[state.historyIdx]; + if (newPrompts === undefined) { // Shouldn't happen return; } - store.dispatch(positivePromptChanged(newPrompt)); + store.dispatch(positivePromptChanged(newPrompts.positivePrompt)); + store.dispatch(negativePromptChanged(newPrompts.negativePrompt)); }, [history, store]); const reset = useCallback(() => { // Clear stashed state - used when user clicks away or types in the prompt box diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/PositivePromptHistory.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/PositivePromptHistory.tsx index eb0f533e0a3..a7e2b4ce407 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/PositivePromptHistory.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/PositivePromptHistory.tsx @@ -16,11 +16,13 @@ import { import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import { + negativePromptChanged, positivePromptChanged, promptHistoryCleared, promptRemovedFromHistory, selectPositivePromptHistory, } from 'features/controlLayers/store/paramsSlice'; +import type { PromptHistoryItem } from 'features/controlLayers/store/types'; import type { ChangeEvent } from 'react'; import { memo, useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; @@ -67,7 +69,12 @@ const PromptHistoryContent = memo(() => { if (!trimmedSearchTerm) { return positivePromptHistory; } - return positivePromptHistory.filter((prompt) => prompt.toLowerCase().includes(trimmedSearchTerm.toLowerCase())); + const searchTermLower = trimmedSearchTerm.toLowerCase(); + return positivePromptHistory.filter( + (prompt) => + prompt.positivePrompt.toLowerCase().includes(searchTermLower) || + (prompt.negativePrompt ?? '').toLowerCase().includes(searchTermLower) + ); }, [positivePromptHistory, searchTerm]); const onChangeSearchTerm = useCallback((e: ChangeEvent) => { @@ -115,7 +122,7 @@ const PromptHistoryContent = memo(() => { {filteredPrompts.map((prompt, index) => ( - + ))} @@ -131,13 +138,14 @@ const PromptHistoryContent = memo(() => { }); PromptHistoryContent.displayName = 'PromptHistoryContent'; -const PromptItem = memo(({ prompt }: { prompt: string }) => { +const PromptItem = memo(({ prompt }: { prompt: PromptHistoryItem }) => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const shiftKey = useShiftModifier(); const onClickUse = useCallback(() => { - dispatch(positivePromptChanged(prompt)); + dispatch(positivePromptChanged(prompt.positivePrompt)); + dispatch(negativePromptChanged(prompt.negativePrompt)); }, [dispatch, prompt]); const onClickDelete = useCallback(() => { @@ -165,7 +173,24 @@ const PromptItem = memo(({ prompt }: { prompt: string }) => { colorScheme="error" /> )} - {prompt} + + {prompt.positivePrompt && ( + + + {t('common.prompt')}: + {' '} + {prompt.positivePrompt} + + )} + {prompt.negativePrompt && ( + + + {t('common.negativePrompt')}: + {' '} + {prompt.negativePrompt} + + )} + ); }); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts index 68e1e9a382e..1229371b6e8 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts @@ -6,7 +6,11 @@ import { extractMessageFromAssertionError } from 'common/util/extractMessageFrom import { withResult, withResultAsync } from 'common/util/result'; import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; -import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice'; +import { + positivePromptAddedToHistory, + selectNegativePrompt, + selectPositivePrompt, +} from 'features/controlLayers/store/paramsSlice'; import type { BaseModelType } from 'features/nodes/types/common'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { buildAnimaGraph } from 'features/nodes/util/graph/generation/buildAnimaGraph'; @@ -95,7 +99,7 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep return; } - const { g, seed, positivePrompt } = buildGraphResult.value; + const { g, seed, positivePrompt, negativePrompt } = buildGraphResult.value; const prepareBatchResult = withResult(() => prepareLinearUIBatch({ @@ -105,6 +109,7 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep prepend, seedNode: seed, positivePromptNode: positivePrompt, + negativePromptNode: negativePrompt, origin: 'canvas', destination, }) @@ -127,7 +132,12 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep const enqueueResult = await req.unwrap(); // Push to prompt history on successful enqueue - dispatch(positivePromptAddedToHistory(selectPositivePrompt(state))); + dispatch( + positivePromptAddedToHistory({ + positivePrompt: selectPositivePrompt(state), + negativePrompt: selectNegativePrompt(state), + }) + ); return { batchConfig, enqueueResult }; }; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts index 54b37e1b95e..8b0c30d924f 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts @@ -4,7 +4,11 @@ import type { AppStore } from 'app/store/store'; import { useAppStore } from 'app/store/storeHooks'; import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError'; import { withResult, withResultAsync } from 'common/util/result'; -import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice'; +import { + positivePromptAddedToHistory, + selectNegativePrompt, + selectPositivePrompt, +} from 'features/controlLayers/store/paramsSlice'; import type { BaseModelType } from 'features/nodes/types/common'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { buildAnimaGraph } from 'features/nodes/util/graph/generation/buildAnimaGraph'; @@ -88,7 +92,7 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => { return; } - const { g, seed, positivePrompt } = buildGraphResult.value; + const { g, seed, positivePrompt, negativePrompt } = buildGraphResult.value; const prepareBatchResult = withResult(() => prepareLinearUIBatch({ @@ -98,6 +102,7 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => { prepend, seedNode: seed, positivePromptNode: positivePrompt, + negativePromptNode: negativePrompt, origin: 'generate', destination: 'generate', }) @@ -120,7 +125,12 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => { const enqueueResult = await req.unwrap(); // Push to prompt history on successful enqueue - dispatch(positivePromptAddedToHistory(selectPositivePrompt(state))); + dispatch( + positivePromptAddedToHistory({ + positivePrompt: selectPositivePrompt(state), + negativePrompt: selectNegativePrompt(state), + }) + ); return { batchConfig, enqueueResult }; }; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueUpscaling.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueUpscaling.ts index 318c1fd5ff0..89820005cd1 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueUpscaling.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueUpscaling.ts @@ -1,7 +1,11 @@ import { logger } from 'app/logging/logger'; import type { AppStore } from 'app/store/store'; import { useAppStore } from 'app/store/storeHooks'; -import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice'; +import { + positivePromptAddedToHistory, + selectNegativePrompt, + selectPositivePrompt, +} from 'features/controlLayers/store/paramsSlice'; import type { BaseModelType } from 'features/nodes/types/common'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph'; @@ -22,7 +26,7 @@ const enqueueUpscaling = async (store: AppStore, prepend: boolean) => { } const base = model.base; - const { g, seed, positivePrompt } = await buildMultidiffusionUpscaleGraph(state); + const { g, seed, positivePrompt, negativePrompt } = await buildMultidiffusionUpscaleGraph(state); const batchConfig = prepareLinearUIBatch({ state, @@ -31,6 +35,7 @@ const enqueueUpscaling = async (store: AppStore, prepend: boolean) => { prepend, seedNode: seed, positivePromptNode: positivePrompt, + negativePromptNode: negativePrompt, origin: 'upscaling', destination: 'gallery', }); @@ -41,7 +46,12 @@ const enqueueUpscaling = async (store: AppStore, prepend: boolean) => { const enqueueResult = await req.unwrap(); // Push to prompt history on successful enqueue - dispatch(positivePromptAddedToHistory(selectPositivePrompt(state))); + dispatch( + positivePromptAddedToHistory({ + positivePrompt: selectPositivePrompt(state), + negativePrompt: selectNegativePrompt(state), + }) + ); return { batchConfig, enqueueResult }; };