diff --git a/apps/memos-local-plugin/core/llm/client.ts b/apps/memos-local-plugin/core/llm/client.ts index 7749a572f..a53d0efdd 100644 --- a/apps/memos-local-plugin/core/llm/client.ts +++ b/apps/memos-local-plugin/core/llm/client.ts @@ -125,6 +125,21 @@ export function createLlmClientWithProvider( return [{ role: "system", content: systemInsert }, ...messages]; } + function ensureJsonWordInUserMessage(messages: LlmMessage[]): LlmMessage[] { + const lastUserIdx = messages.map((m) => m.role).lastIndexOf("user"); + if (lastUserIdx < 0) return [...messages, { role: "user", content: "Return valid json only." }]; + + const msg = messages[lastUserIdx]; + if (/\bjson\b/i.test(msg.content)) return messages; + + const out = messages.slice(); + out[lastUserIdx] = { + ...msg, + content: `${msg.content}\n\nReturn valid json only.`, + }; + return out; + } + function buildCallInput(opts: LlmCallOptions | undefined, jsonMode: boolean): ProviderCallInput { return { temperature: opts?.temperature ?? config.temperature, @@ -339,7 +354,7 @@ export function createLlmClientWithProvider( ): Promise { const messages = normalizeMessages(input); const msgsWithJsonHint = opts?.jsonMode - ? inject(messages, buildJsonSystemHint()) + ? ensureJsonWordInUserMessage(inject(messages, buildJsonSystemHint())) : messages; const call = buildCallInput(opts, opts?.jsonMode === true); const { completion } = await callWithFallback(msgsWithJsonHint, call, opts, opts?.op ?? "complete"); @@ -352,7 +367,7 @@ export function createLlmClientWithProvider( ): Promise> { const messages = normalizeMessages(input); const systemHint = buildJsonSystemHint(opts.schemaHint); - const msgs = inject(messages, systemHint); + const msgs = ensureJsonWordInUserMessage(inject(messages, systemHint)); const call = buildCallInput(opts, true); const op = opts.op ?? "complete.json"; const maxMalformedRetries = Math.max(0, opts.malformedRetries ?? 1); diff --git a/apps/memos-local-plugin/tests/unit/llm/client.test.ts b/apps/memos-local-plugin/tests/unit/llm/client.test.ts index cd5cf6106..d378e4bf6 100644 --- a/apps/memos-local-plugin/tests/unit/llm/client.test.ts +++ b/apps/memos-local-plugin/tests/unit/llm/client.test.ts @@ -95,12 +95,14 @@ describe("llm/client", () => { expect(fake.lastMessages).toEqual([{ role: "user", content: "hi there" }]); }); - it("injects a json system hint when jsonMode=true", async () => { + it("injects json hints into system and user messages when jsonMode=true", async () => { const fake = new FakeProvider("openai_compatible", () => ({ text: '{"ok":1}', durationMs: 1 })); const client = createLlmClientWithProvider(cfg(), fake); await client.complete("do it", { jsonMode: true }); expect(fake.lastMessages?.[0]?.role).toBe("system"); expect(fake.lastMessages?.[0]?.content).toMatch(/single valid JSON value/i); + expect(fake.lastMessages?.at(-1)?.role).toBe("user"); + expect(fake.lastMessages?.at(-1)?.content).toMatch(/valid json only/i); expect(fake.lastInput?.jsonMode).toBe(true); }); @@ -269,7 +271,9 @@ describe("llm/client", () => { expect(fake.lastMessages?.[0]?.role).toBe("system"); expect(fake.lastMessages?.[0]?.content).toMatch(/You are strict\./); expect(fake.lastMessages?.[0]?.content).toMatch(/single valid JSON value/); - expect(fake.lastMessages?.[1]).toEqual({ role: "user", content: "go" }); + expect(fake.lastMessages?.[1]?.role).toBe("user"); + expect(fake.lastMessages?.[1]?.content).toMatch(/^go/); + expect(fake.lastMessages?.[1]?.content).toMatch(/valid json only/i); }); it("rejects empty messages array", async () => {