From 1589fab30b555f1334598424b33e91137211d218 Mon Sep 17 00:00:00 2001 From: Peter Kirkham Date: Mon, 25 May 2026 12:11:09 +0100 Subject: [PATCH] feat(billing): threshold notification service with persisted dedupe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Moves usage-limit detection out of the renderer into a main-process `UsageMonitorService` that polls /v1/usage every 30s, detects when a bucket newly crosses 50/75/90/100%, and emits an event through a tRPC subscription. Dedupe state lives in a persistent electron-store keyed by `${userId}:${product}:${bucket}:${anchor}:${threshold}` so crossings don't re-fire after an app relaunch. Anchors are `reset_at` rounded to the hour for burst (jitter-tolerant), and `billing_period_end` (Pro) or the date of `reset_at` (Free) for sustained. Stale entries are pruned on startup. The renderer subscribes via `initializeUsageThresholdToast` (modelled on `connectivityToast`) and shows a warning toast at 50/75/90% with a "View usage" action that opens the Plan & Usage settings. At 100% the existing `UsageLimitModal` opens if a session is active, else the user gets a blocking error toast. `useUsageLimitDetection` is deleted — the 100% path is now driven from the same subscription. The renderer holds no detection state. `toast.warning` is extended to forward an action button (the wiring already exists in `ToastComponent`). Generated-By: PostHog Code Task-Id: bac06178-1ab1-4000-9a56-1901215bd4af Generated-By: PostHog Code Task-Id: bac06178-1ab1-4000-9a56-1901215bd4af --- apps/code/src/main/di/container.ts | 2 + apps/code/src/main/di/tokens.ts | 1 + .../main/services/usage-monitor/schemas.ts | 28 +++ .../services/usage-monitor/service.test.ts | 182 +++++++++++++++ .../main/services/usage-monitor/service.ts | 215 ++++++++++++++++++ .../src/main/services/usage-monitor/store.ts | 17 ++ apps/code/src/main/trpc/router.ts | 2 + .../src/main/trpc/routers/usage-monitor.ts | 25 ++ apps/code/src/renderer/App.tsx | 7 + .../src/renderer/components/MainLayout.tsx | 2 - .../billing/hooks/useUsageLimitDetection.ts | 38 ---- .../features/billing/usageThresholdToast.ts | 65 ++++++ apps/code/src/renderer/utils/toast.tsx | 8 +- 13 files changed, 551 insertions(+), 41 deletions(-) create mode 100644 apps/code/src/main/services/usage-monitor/schemas.ts create mode 100644 apps/code/src/main/services/usage-monitor/service.test.ts create mode 100644 apps/code/src/main/services/usage-monitor/service.ts create mode 100644 apps/code/src/main/services/usage-monitor/store.ts create mode 100644 apps/code/src/main/trpc/routers/usage-monitor.ts delete mode 100644 apps/code/src/renderer/features/billing/hooks/useUsageLimitDetection.ts create mode 100644 apps/code/src/renderer/features/billing/usageThresholdToast.ts diff --git a/apps/code/src/main/di/container.ts b/apps/code/src/main/di/container.ts index 5d6a8d508..796e8486b 100644 --- a/apps/code/src/main/di/container.ts +++ b/apps/code/src/main/di/container.ts @@ -66,6 +66,7 @@ import { SuspensionService } from "../services/suspension/service"; import { TaskLinkService } from "../services/task-link/service"; import { UIService } from "../services/ui/service"; import { UpdatesService } from "../services/updates/service"; +import { UsageMonitorService } from "../services/usage-monitor/service"; import { WatcherRegistryService } from "../services/watcher-registry/service"; import { WorkspaceService } from "../services/workspace/service"; import { MAIN_TOKENS } from "./tokens"; @@ -146,6 +147,7 @@ container.bind(MAIN_TOKENS.ShellService).to(ShellService); container.bind(MAIN_TOKENS.SlackIntegrationService).to(SlackIntegrationService); container.bind(MAIN_TOKENS.UIService).to(UIService); container.bind(MAIN_TOKENS.UpdatesService).to(UpdatesService); +container.bind(MAIN_TOKENS.UsageMonitorService).to(UsageMonitorService); container.bind(MAIN_TOKENS.TaskLinkService).to(TaskLinkService); container.bind(MAIN_TOKENS.InboxLinkService).to(InboxLinkService); container.bind(MAIN_TOKENS.WatcherRegistryService).to(WatcherRegistryService); diff --git a/apps/code/src/main/di/tokens.ts b/apps/code/src/main/di/tokens.ts index aeade0e77..f0697d70e 100644 --- a/apps/code/src/main/di/tokens.ts +++ b/apps/code/src/main/di/tokens.ts @@ -82,4 +82,5 @@ export const MAIN_TOKENS = Object.freeze({ ProvisioningService: Symbol.for("Main.ProvisioningService"), WorkspaceService: Symbol.for("Main.WorkspaceService"), EnrichmentService: Symbol.for("Main.EnrichmentService"), + UsageMonitorService: Symbol.for("Main.UsageMonitorService"), }); diff --git a/apps/code/src/main/services/usage-monitor/schemas.ts b/apps/code/src/main/services/usage-monitor/schemas.ts new file mode 100644 index 000000000..2923b1cc4 --- /dev/null +++ b/apps/code/src/main/services/usage-monitor/schemas.ts @@ -0,0 +1,28 @@ +import { z } from "zod"; + +export const USAGE_THRESHOLDS = [50, 75, 90, 100] as const; +export type UsageThreshold = (typeof USAGE_THRESHOLDS)[number]; + +export const thresholdCrossedEvent = z.object({ + bucket: z.enum(["burst", "sustained"]), + threshold: z.union([ + z.literal(50), + z.literal(75), + z.literal(90), + z.literal(100), + ]), + usedPercent: z.number(), + resetAt: z.string().datetime().nullable(), + resetsInSeconds: z.number(), + isPro: z.boolean(), +}); + +export type ThresholdCrossedEvent = z.infer; + +export const UsageMonitorEvent = { + ThresholdCrossed: "threshold-crossed", +} as const; + +export interface UsageMonitorEvents { + [UsageMonitorEvent.ThresholdCrossed]: ThresholdCrossedEvent; +} diff --git a/apps/code/src/main/services/usage-monitor/service.test.ts b/apps/code/src/main/services/usage-monitor/service.test.ts new file mode 100644 index 000000000..0e9fcfcbd --- /dev/null +++ b/apps/code/src/main/services/usage-monitor/service.test.ts @@ -0,0 +1,182 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import type { UsageOutput } from "../llm-gateway/schemas"; +import { UsageMonitorEvent } from "./schemas"; + +const mockStoreGet = vi.hoisted(() => vi.fn()); +const mockStoreSet = vi.hoisted(() => vi.fn()); + +vi.mock("./store", () => ({ + usageMonitorStore: { + get: mockStoreGet, + set: mockStoreSet, + }, +})); + +vi.mock("../../utils/logger.js", () => ({ + logger: { + scope: () => ({ + info: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + debug: vi.fn(), + }), + }, +})); + +import { LlmGatewayService } from "../llm-gateway/service"; +import { UsageMonitorService } from "./service"; + +function makeUsage(overrides?: { + burstPercent?: number; + sustainedPercent?: number; + billingPeriodEnd?: string | null; + burstResetAt?: string; + sustainedResetAt?: string; +}): UsageOutput { + return { + product: "posthog_code", + user_id: 42, + is_rate_limited: false, + billing_period_end: + overrides?.billingPeriodEnd === undefined + ? null + : overrides.billingPeriodEnd, + burst: { + used_percent: overrides?.burstPercent ?? 0, + resets_in_seconds: 3600, + reset_at: overrides?.burstResetAt ?? "2026-05-25T16:00:00.000Z", + exceeded: false, + }, + sustained: { + used_percent: overrides?.sustainedPercent ?? 0, + resets_in_seconds: 86400, + reset_at: overrides?.sustainedResetAt ?? "2026-06-01T00:00:00.000Z", + exceeded: false, + }, + }; +} + +function mockGateway(usage: UsageOutput | null): LlmGatewayService { + return { + fetchUsage: vi.fn().mockResolvedValue(usage), + } as unknown as LlmGatewayService; +} + +describe("UsageMonitorService", () => { + let service: UsageMonitorService; + let persisted: Record; + + beforeEach(() => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-05-25T12:00:00.000Z")); + persisted = {}; + mockStoreGet.mockImplementation((_key: string, fallback: unknown) => ({ + ...persisted, + ...(fallback as Record), + })); + mockStoreSet.mockImplementation( + (_key: string, value: Record) => { + persisted = { ...value }; + }, + ); + }); + + afterEach(() => { + service?.stop(); + vi.useRealTimers(); + }); + + it("emits at 75% but not again on the next poll for the same anchor", async () => { + const events: unknown[] = []; + const gateway = mockGateway(makeUsage({ burstPercent: 78 })); + service = new UsageMonitorService(gateway); + service.on(UsageMonitorEvent.ThresholdCrossed, (e) => events.push(e)); + + await service.pollOnce(); + expect(events).toHaveLength(1); + expect(events[0]).toMatchObject({ + bucket: "burst", + threshold: 75, + usedPercent: 78, + }); + + await service.pollOnce(); + expect(events).toHaveLength(1); + }); + + it("only emits the highest threshold a bucket has crossed", async () => { + const events: unknown[] = []; + const gateway = mockGateway(makeUsage({ burstPercent: 95 })); + service = new UsageMonitorService(gateway); + service.on(UsageMonitorEvent.ThresholdCrossed, (e) => events.push(e)); + + await service.pollOnce(); + expect(events).toHaveLength(1); + expect(events[0]).toMatchObject({ threshold: 90 }); + }); + + it("doesn't re-emit after a relaunch with persisted dedupe", async () => { + const events: unknown[] = []; + const gateway = mockGateway(makeUsage({ burstPercent: 55 })); + service = new UsageMonitorService(gateway); + service.on(UsageMonitorEvent.ThresholdCrossed, (e) => events.push(e)); + await service.pollOnce(); + expect(events).toHaveLength(1); + service.stop(); + + // Simulate relaunch + service = new UsageMonitorService(gateway); + service.on(UsageMonitorEvent.ThresholdCrossed, (e) => events.push(e)); + await service.pollOnce(); + expect(events).toHaveLength(1); + }); + + it("tracks burst and sustained as independent buckets", async () => { + const events: unknown[] = []; + const gateway = mockGateway( + makeUsage({ + burstPercent: 55, + sustainedPercent: 80, + billingPeriodEnd: "2026-06-01T00:00:00.000Z", + }), + ); + service = new UsageMonitorService(gateway); + service.on(UsageMonitorEvent.ThresholdCrossed, (e) => events.push(e)); + + await service.pollOnce(); + expect(events).toHaveLength(2); + expect(events.map((e) => (e as { bucket: string }).bucket).sort()).toEqual([ + "burst", + "sustained", + ]); + }); + + it("marks events with isPro when billing_period_end is set", async () => { + const events: { isPro: boolean }[] = []; + const gateway = mockGateway( + makeUsage({ + sustainedPercent: 60, + billingPeriodEnd: "2026-06-01T00:00:00.000Z", + }), + ); + service = new UsageMonitorService(gateway); + service.on(UsageMonitorEvent.ThresholdCrossed, (e) => + events.push(e as { isPro: boolean }), + ); + + await service.pollOnce(); + expect(events[0]?.isPro).toBe(true); + }); + + it("silently skips polls when the gateway throws", async () => { + const events: unknown[] = []; + const gateway = { + fetchUsage: vi.fn().mockRejectedValue(new Error("not authenticated")), + } as unknown as LlmGatewayService; + service = new UsageMonitorService(gateway); + service.on(UsageMonitorEvent.ThresholdCrossed, (e) => events.push(e)); + + await expect(service.pollOnce()).resolves.toBeNull(); + expect(events).toHaveLength(0); + }); +}); diff --git a/apps/code/src/main/services/usage-monitor/service.ts b/apps/code/src/main/services/usage-monitor/service.ts new file mode 100644 index 000000000..de6611851 --- /dev/null +++ b/apps/code/src/main/services/usage-monitor/service.ts @@ -0,0 +1,215 @@ +import { inject, injectable, postConstruct, preDestroy } from "inversify"; +import { MAIN_TOKENS } from "../../di/tokens"; +import { logger } from "../../utils/logger"; +import { TypedEventEmitter } from "../../utils/typed-event-emitter"; +import { LlmGatewayService } from "../llm-gateway/service"; +import type { UsageBucket, UsageOutput } from "../llm-gateway/schemas"; +import { + USAGE_THRESHOLDS, + UsageMonitorEvent, + type UsageMonitorEvents, + type UsageThreshold, +} from "./schemas"; +import { usageMonitorStore } from "./store"; + +const log = logger.scope("usage-monitor"); + +const POLL_INTERVAL_MS = 30_000; + +type BucketName = "burst" | "sustained"; + +@injectable() +export class UsageMonitorService extends TypedEventEmitter { + private pollTimeoutId: ReturnType | null = null; + private isPolling = false; + // Snapshot of the most recent thresholdsSeen map so we hit electron-store + // only when we actually persist a new threshold. + private thresholdsSeen: Record; + + constructor( + @inject(MAIN_TOKENS.LlmGatewayService) + private readonly llmGateway: LlmGatewayService, + ) { + super(); + this.thresholdsSeen = { ...usageMonitorStore.get("thresholdsSeen", {}) }; + } + + @postConstruct() + init(): void { + this.pruneStaleEntries(); + this.schedulePoll(POLL_INTERVAL_MS); + } + + @preDestroy() + stop(): void { + if (this.pollTimeoutId) { + clearTimeout(this.pollTimeoutId); + this.pollTimeoutId = null; + } + } + + // Exposed so tests can drive the loop deterministically. + async pollOnce(): Promise { + if (this.isPolling) return null; + this.isPolling = true; + try { + const usage = await this.fetchUsageQuietly(); + if (usage) this.processUsage(usage); + return usage; + } finally { + this.isPolling = false; + } + } + + private async fetchUsageQuietly(): Promise { + try { + return await this.llmGateway.fetchUsage(); + } catch (err) { + log.debug("Usage poll skipped", { + error: err instanceof Error ? err.message : String(err), + }); + return null; + } + } + + private schedulePoll(delayMs: number): void { + this.pollTimeoutId = setTimeout(async () => { + this.pollTimeoutId = null; + await this.pollOnce(); + this.schedulePoll(POLL_INTERVAL_MS); + }, delayMs); + } + + private processUsage(usage: UsageOutput): void { + const userId = usage.user_id.toString(); + const product = usage.product; + // Plan-key isn't on UsageOutput; the only signal we have client-side is + // whether limits are at the Pro tier — but fetchUsage doesn't return that + // either. Best-effort: assume Pro if billing_period_end is present + // (free users never have it). + const isPro = !!usage.billing_period_end; + + this.maybeEmit(usage, "burst", usage.burst, userId, product, isPro); + this.maybeEmit( + usage, + "sustained", + usage.sustained, + userId, + product, + isPro, + ); + } + + private maybeEmit( + usage: UsageOutput, + bucket: BucketName, + status: UsageBucket, + userId: string, + product: string, + isPro: boolean, + ): void { + const anchor = this.anchorFor(bucket, status, usage); + if (!anchor) return; + + const threshold = highestThresholdCrossed(status.used_percent); + if (threshold === null) return; + + const key = makeKey(userId, product, bucket, anchor, threshold); + if (this.thresholdsSeen[key]) return; + + this.thresholdsSeen[key] = anchor; + usageMonitorStore.set("thresholdsSeen", this.thresholdsSeen); + + log.info("Usage threshold crossed", { + bucket, + threshold, + usedPercent: status.used_percent, + }); + + this.emit(UsageMonitorEvent.ThresholdCrossed, { + bucket, + threshold, + usedPercent: status.used_percent, + resetAt: status.reset_at ?? null, + resetsInSeconds: status.resets_in_seconds, + isPro, + }); + } + + // Burst anchor rounds reset_at to the hour so transient TTL jitter doesn't + // make every poll look like a new window. Sustained anchor is the billing + // period end (Pro) or the reset_at ISO date (free). + private anchorFor( + bucket: BucketName, + status: UsageBucket, + usage: UsageOutput, + ): string | null { + if (bucket === "sustained") { + return ( + usage.billing_period_end ?? + sustainedFreeAnchor(status) ?? + null + ); + } + return burstAnchor(status); + } + + private pruneStaleEntries(): void { + const now = Date.now(); + let dirty = false; + for (const [key, anchor] of Object.entries(this.thresholdsSeen)) { + const parsed = Date.parse(anchor); + if (Number.isNaN(parsed) || parsed < now) { + delete this.thresholdsSeen[key]; + dirty = true; + } + } + if (dirty) { + usageMonitorStore.set("thresholdsSeen", this.thresholdsSeen); + } + } +} + +function highestThresholdCrossed(usedPercent: number): UsageThreshold | null { + for (let i = USAGE_THRESHOLDS.length - 1; i >= 0; i--) { + const t = USAGE_THRESHOLDS[i]; + if (usedPercent >= t) return t; + } + return null; +} + +function burstAnchor(status: UsageBucket): string | null { + const resetMs = resetMillis(status); + if (resetMs === null) return null; + // Round to the nearest hour so 30s polling doesn't churn the anchor. + const rounded = Math.round(resetMs / 3_600_000) * 3_600_000; + return new Date(rounded).toISOString(); +} + +function sustainedFreeAnchor(status: UsageBucket): string | null { + const resetMs = resetMillis(status); + if (resetMs === null) return null; + return new Date(resetMs).toISOString().slice(0, 10); +} + +function resetMillis(status: UsageBucket): number | null { + if (status.reset_at) { + const parsed = Date.parse(status.reset_at); + if (!Number.isNaN(parsed)) return parsed; + } + if (status.resets_in_seconds > 0) { + return Date.now() + status.resets_in_seconds * 1000; + } + return null; +} + +function makeKey( + userId: string, + product: string, + bucket: BucketName, + anchor: string, + threshold: UsageThreshold, +): string { + return `${userId}:${product}:${bucket}:${anchor}:${threshold}`; +} + diff --git a/apps/code/src/main/services/usage-monitor/store.ts b/apps/code/src/main/services/usage-monitor/store.ts new file mode 100644 index 000000000..95cc9a486 --- /dev/null +++ b/apps/code/src/main/services/usage-monitor/store.ts @@ -0,0 +1,17 @@ +import Store from "electron-store"; +import { getUserDataDir } from "../../utils/env"; + +interface UsageMonitorSchema { + // Map of dedupe-keys ⇒ ISO timestamp anchor at which the threshold was + // first fired. Stored so we don't re-toast after relaunch within the same + // billing window. Anchored entries with a past anchor are pruned on boot. + thresholdsSeen: Record; +} + +export const usageMonitorStore = new Store({ + name: "usage-monitor", + cwd: getUserDataDir(), + defaults: { + thresholdsSeen: {}, + }, +}); diff --git a/apps/code/src/main/trpc/router.ts b/apps/code/src/main/trpc/router.ts index 81fd00d4e..f0f8dd9eb 100644 --- a/apps/code/src/main/trpc/router.ts +++ b/apps/code/src/main/trpc/router.ts @@ -36,6 +36,7 @@ import { sleepRouter } from "./routers/sleep"; import { suspensionRouter } from "./routers/suspension.js"; import { uiRouter } from "./routers/ui"; import { updatesRouter } from "./routers/updates"; +import { usageMonitorRouter } from "./routers/usage-monitor"; import { workspaceRouter } from "./routers/workspace"; import { router } from "./trpc"; @@ -78,6 +79,7 @@ export const trpcRouter = router({ slackIntegration: slackIntegrationRouter, ui: uiRouter, updates: updatesRouter, + usageMonitor: usageMonitorRouter, deepLink: deepLinkRouter, workspace: workspaceRouter, }); diff --git a/apps/code/src/main/trpc/routers/usage-monitor.ts b/apps/code/src/main/trpc/routers/usage-monitor.ts new file mode 100644 index 000000000..d103612db --- /dev/null +++ b/apps/code/src/main/trpc/routers/usage-monitor.ts @@ -0,0 +1,25 @@ +import { container } from "../../di/container"; +import { MAIN_TOKENS } from "../../di/tokens"; +import { + UsageMonitorEvent, + type UsageMonitorEvents, +} from "../../services/usage-monitor/schemas"; +import type { UsageMonitorService } from "../../services/usage-monitor/service"; +import { publicProcedure, router } from "../trpc"; + +const getService = () => + container.get(MAIN_TOKENS.UsageMonitorService); + +function subscribe(event: K) { + return publicProcedure.subscription(async function* (opts) { + const service = getService(); + const iterable = service.toIterable(event, { signal: opts.signal }); + for await (const data of iterable) { + yield data; + } + }); +} + +export const usageMonitorRouter = router({ + onThresholdCrossed: subscribe(UsageMonitorEvent.ThresholdCrossed), +}); diff --git a/apps/code/src/renderer/App.tsx b/apps/code/src/renderer/App.tsx index c6595e9e2..05de56b05 100644 --- a/apps/code/src/renderer/App.tsx +++ b/apps/code/src/renderer/App.tsx @@ -12,6 +12,7 @@ import { } from "@features/auth/hooks/authQueries"; import { useAuthSession } from "@features/auth/hooks/useAuthSession"; import { useIsOrgAdmin } from "@features/auth/hooks/useOrgRole"; +import { initializeUsageThresholdToast } from "@features/billing/usageThresholdToast"; import { AddDirectoryDialog } from "@features/folder-picker/components/AddDirectoryDialog"; import { OnboardingFlow } from "@features/onboarding/components/OnboardingFlow"; import { useOnboardingStore } from "@features/onboarding/stores/onboardingStore"; @@ -63,6 +64,12 @@ function App() { }; }, []); + // Initialize usage threshold notifications (50/75/90/100%) + useEffect(() => { + if (!isAuthenticated) return; + return initializeUsageThresholdToast(); + }, [isAuthenticated]); + // Initialize update store useEffect(() => { return initializeUpdateStore(); diff --git a/apps/code/src/renderer/components/MainLayout.tsx b/apps/code/src/renderer/components/MainLayout.tsx index ce3690e68..75bec800c 100644 --- a/apps/code/src/renderer/components/MainLayout.tsx +++ b/apps/code/src/renderer/components/MainLayout.tsx @@ -5,7 +5,6 @@ import { SpaceSwitcher } from "@components/SpaceSwitcher"; import { ArchivedTasksView } from "@features/archive/components/ArchivedTasksView"; import { UsageLimitModal } from "@features/billing/components/UsageLimitModal"; -import { useUsageLimitDetection } from "@features/billing/hooks/useUsageLimitDetection"; import { CommandMenu } from "@features/command/components/CommandMenu"; import { CommandCenterView } from "@features/command-center/components/CommandCenterView"; import { InboxView } from "@features/inbox/components/InboxView"; @@ -76,7 +75,6 @@ export function MainLayout() { const activeTaskId = view.type === "task-detail" && view.data ? view.data.id : null; - useUsageLimitDetection(billingEnabled); useIntegrations(); useTaskDeepLink(); useInboxDeepLink(); diff --git a/apps/code/src/renderer/features/billing/hooks/useUsageLimitDetection.ts b/apps/code/src/renderer/features/billing/hooks/useUsageLimitDetection.ts deleted file mode 100644 index 8f7f685c8..000000000 --- a/apps/code/src/renderer/features/billing/hooks/useUsageLimitDetection.ts +++ /dev/null @@ -1,38 +0,0 @@ -import { useUsageLimitStore } from "@features/billing/stores/usageLimitStore"; -import { isUsageExceeded } from "@features/billing/utils"; -import { useSessionStore } from "@features/sessions/stores/sessionStore"; -import { useEffect, useRef } from "react"; -import { useFreeUsage } from "./useFreeUsage"; - -export function useUsageLimitDetection(billingEnabled: boolean) { - const { usage } = useFreeUsage(billingEnabled); - const hasAlertedRef = useRef(false); - - useEffect(() => { - if (!billingEnabled) { - hasAlertedRef.current = false; - } - }, [billingEnabled]); - - useEffect(() => { - if (!usage) return; - - const exceeded = isUsageExceeded(usage); - - if (exceeded && !hasAlertedRef.current) { - const sessions = useSessionStore.getState().sessions; - const hasActiveSession = Object.values(sessions).some( - (s) => s.status === "connected" && s.isPromptPending, - ); - - if (hasActiveSession) { - hasAlertedRef.current = true; - useUsageLimitStore.getState().show(); - } - } - - if (!exceeded) { - hasAlertedRef.current = false; - } - }, [usage]); -} diff --git a/apps/code/src/renderer/features/billing/usageThresholdToast.ts b/apps/code/src/renderer/features/billing/usageThresholdToast.ts new file mode 100644 index 000000000..d6b3a005c --- /dev/null +++ b/apps/code/src/renderer/features/billing/usageThresholdToast.ts @@ -0,0 +1,65 @@ +import { useUsageLimitStore } from "@features/billing/stores/usageLimitStore"; +import { formatResetTime } from "@features/billing/utils"; +import { useSessionStore } from "@features/sessions/stores/sessionStore"; +import { useSettingsDialogStore } from "@features/settings/stores/settingsDialogStore"; +import { trpcClient } from "@renderer/trpc/client"; +import { logger } from "@utils/logger"; +import { toast } from "@utils/toast"; + +const log = logger.scope("usage-threshold-toast"); + +const openPlanUsage = () => { + useSettingsDialogStore.getState().open("plan-usage"); +}; + +function hasActiveSession(): boolean { + const sessions = useSessionStore.getState().sessions; + return Object.values(sessions).some( + (s) => s.status === "connected" && s.isPromptPending, + ); +} + +export function initializeUsageThresholdToast() { + const subscription = trpcClient.usageMonitor.onThresholdCrossed.subscribe( + undefined, + { + onData: (event) => { + const resetLabel = formatResetTime( + event.resetAt ?? undefined, + event.resetsInSeconds, + ); + + if (event.threshold === 100) { + if (hasActiveSession()) { + useUsageLimitStore.getState().show(); + return; + } + toast.error("Usage limit reached", { + id: `usage-threshold-${event.bucket}-100`, + description: resetLabel, + }); + return; + } + + const limitName = + event.bucket === "burst" ? "daily limit" : "monthly limit"; + toast.warning( + `You've used ${Math.round(event.usedPercent)}% of your ${limitName}`, + { + id: `usage-threshold-${event.bucket}-${event.threshold}`, + description: resetLabel, + action: { label: "View usage", onClick: openPlanUsage }, + duration: 10_000, + }, + ); + }, + onError: (error) => { + log.error("Usage threshold subscription error", { error }); + }, + }, + ); + + return () => { + subscription.unsubscribe(); + }; +} diff --git a/apps/code/src/renderer/utils/toast.tsx b/apps/code/src/renderer/utils/toast.tsx index e1d610b58..611016d61 100644 --- a/apps/code/src/renderer/utils/toast.tsx +++ b/apps/code/src/renderer/utils/toast.tsx @@ -149,7 +149,12 @@ export const toast = { warning: ( title: ReactNode, - options?: { description?: string; id?: string | number; duration?: number }, + options?: { + description?: string; + id?: string | number; + duration?: number; + action?: ToastAction; + }, ) => { return sonnerToast.custom( (id) => ( @@ -158,6 +163,7 @@ export const toast = { type="warning" title={title} description={options?.description} + action={options?.action} /> ), { id: options?.id, duration: options?.duration },