From c91f665605fe3168c619df2a1209d5edfb07a6bc Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Mon, 23 Feb 2026 09:02:27 -0600 Subject: [PATCH 1/3] Add MLX KV cache reuse for incremental prefill Persist KV caches across respond()/streamResponse() calls within the same LanguageModelSession. On subsequent turns only the new tokens are prefilled instead of re-encoding the entire conversation history, dramatically reducing time to first token. - Add maxKVSize, kvBits, kvGroupSize to GenerationOptions - Add SessionCacheEntry store with NSMapTable weak keys - Implement incremental prefill in streamResponse() and respond() - Enhance prewarm() to prefill system prompt into KV cache Co-Authored-By: Claude Opus 4.6 --- .../AnyLanguageModel/GenerationOptions.swift | 30 +++- .../Models/MLXLanguageModel.swift | 150 +++++++++++++++++- 2 files changed, 172 insertions(+), 8 deletions(-) diff --git a/Sources/AnyLanguageModel/GenerationOptions.swift b/Sources/AnyLanguageModel/GenerationOptions.swift index 5d0f6338..17da525c 100644 --- a/Sources/AnyLanguageModel/GenerationOptions.swift +++ b/Sources/AnyLanguageModel/GenerationOptions.swift @@ -121,6 +121,25 @@ public struct GenerationOptions: Sendable, Equatable, Codable { /// an error will be thrown. public var maximumResponseTokens: Int? + /// Maximum number of tokens to retain in the KV cache. + /// + /// When set, uses a rotating cache that evicts oldest tokens beyond this limit. + /// When `nil` (default), the cache grows unbounded. + /// + /// Recommended values: 2048–4096 for iPhone, `nil` for Mac. + public var maxKVSize: Int? + + /// Bit width for KV cache quantization (for example, 4 or 8). + /// + /// Reduces cache memory usage at slight quality cost. + /// When `nil` (default), the cache uses full precision. + public var kvBits: Int? + + /// Group size for KV cache quantization. + /// + /// Only meaningful when ``kvBits`` is set. Default is 64. + public var kvGroupSize: Int + /// Storage for model-specific custom options. private var customOptionsStorage: CustomOptionsStorage = .init() @@ -157,14 +176,23 @@ public struct GenerationOptions: Sendable, Equatable, Codable { /// responses. Must be between `0` and `1`, inclusive. /// - maximumResponseTokens: The maximum number of tokens the model is allowed /// to produce before being artificially halted. Must be positive. + /// - maxKVSize: Maximum tokens in the KV cache. When set, enables a rotating cache. + /// - kvBits: Bit width for KV cache quantization. + /// - kvGroupSize: Group size for KV cache quantization. Default is 64. public init( sampling: SamplingMode? = nil, temperature: Double? = nil, - maximumResponseTokens: Int? = nil + maximumResponseTokens: Int? = nil, + maxKVSize: Int? = nil, + kvBits: Int? = nil, + kvGroupSize: Int = 64 ) { self.sampling = sampling self.temperature = temperature self.maximumResponseTokens = maximumResponseTokens + self.maxKVSize = maxKVSize + self.kvBits = kvBits + self.kvGroupSize = kvGroupSize } } diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 4ffb877a..e84a6aee 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -120,6 +120,42 @@ import Foundation /// Shared cache across MLXLanguageModel instances. private nonisolated(unsafe) let modelCache = ModelContextCache(countLimit: 3) + // MARK: - Session KV Cache Store + + /// Stores a KV cache and its prefill token count for a session. + private final class SessionCacheEntry: NSObject, @unchecked Sendable { + var kvCache: [MLXLMCommon.KVCache] + var prefillTokenCount: Int + + init(kvCache: [MLXLMCommon.KVCache], prefillTokenCount: Int) { + self.kvCache = kvCache + self.prefillTokenCount = prefillTokenCount + } + } + + /// Maps LanguageModelSession (weak key) → SessionCacheEntry. + /// When a session is deallocated, its cache entry is automatically released. + private nonisolated(unsafe) let sessionKVCache = NSMapTable.weakToStrongObjects() + private let sessionKVCacheLock = NSLock() + + private func getSessionCache(_ session: LanguageModelSession) -> SessionCacheEntry? { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + return sessionKVCache.object(forKey: session) + } + + private func setSessionCache(_ entry: SessionCacheEntry, for session: LanguageModelSession) { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + sessionKVCache.setObject(entry, forKey: session) + } + + private func removeSessionCache(for session: LanguageModelSession) { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + sessionKVCache.removeObject(forKey: session) + } + // MARK: - MLXLanguageModel /// A language model that runs locally using MLX. @@ -228,6 +264,11 @@ import Foundation var allTextChunks: [String] = [] var allEntries: [Transcript.Entry] = [] + // Track the KV cache across the tool-calling loop. + // On the first iteration we try to reuse the session's cached KV state; + // on subsequent iterations (tool results added) we must rebuild. + var isFirstIteration = true + // Loop until no more tool calls while true { // Build user input with current chat history and tools @@ -238,9 +279,45 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) + // Determine cache and input for generation + let cache: [MLXLMCommon.KVCache] + let inputForGeneration: MLXLMCommon.LMInput + + if isFirstIteration { + let existingEntry = getSessionCache(session) + let fullTokenCount = lmInput.text.tokens.dim(0) + + if let existingEntry, + existingEntry.prefillTokenCount > 0, + fullTokenCount > existingEntry.prefillTokenCount, + lmInput.image == nil + { + // Cache HIT: only prefill new tokens + let cachedCount = existingEntry.prefillTokenCount + let newTokens = lmInput.text.tokens[cachedCount...] + let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) + inputForGeneration = MLXLMCommon.LMInput(text: partialText) + cache = existingEntry.kvCache + } else { + // Cache MISS: create fresh cache + if existingEntry != nil { + removeSessionCache(for: session) + } + cache = context.model.newCache(parameters: generateParameters) + inputForGeneration = lmInput + } + } else { + // Tool-calling iterations: fresh cache (chat has been mutated) + cache = context.model.newCache(parameters: generateParameters) + inputForGeneration = lmInput + } + + isFirstIteration = false + // Generate let stream = try MLXLMCommon.generate( - input: lmInput, + input: inputForGeneration, + cache: cache, parameters: generateParameters, context: context ) @@ -259,6 +336,11 @@ import Foundation } } + // Update session cache with current offset after generation + let currentOffset = cache.first?.offset ?? 0 + let cacheEntry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) + setSessionCache(cacheEntry, for: session) + let assistantText = chunks.joined() allTextChunks.append(assistantText) @@ -344,8 +426,37 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) + // Check for existing KV cache for this session + let existingEntry = getSessionCache(session) + let cache: [MLXLMCommon.KVCache] + let inputForGeneration: MLXLMCommon.LMInput + + let fullTokenCount = lmInput.text.tokens.dim(0) + + if let existingEntry, + existingEntry.prefillTokenCount > 0, + fullTokenCount > existingEntry.prefillTokenCount, + lmInput.image == nil + { + // Cache HIT: only prefill new tokens + let cachedCount = existingEntry.prefillTokenCount + let newTokens = lmInput.text.tokens[cachedCount...] + let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) + inputForGeneration = MLXLMCommon.LMInput(text: partialText) + cache = existingEntry.kvCache + } else { + // Cache MISS: create fresh cache, prefill everything + if existingEntry != nil { + removeSessionCache(for: session) + } + let newCache = context.model.newCache(parameters: generateParameters) + cache = newCache + inputForGeneration = lmInput + } + let mlxStream = try MLXLMCommon.generate( - input: lmInput, + input: inputForGeneration, + cache: cache, parameters: generateParameters, context: context ) @@ -366,6 +477,11 @@ import Foundation } } + // Update the session cache with current offset + let currentOffset = cache.first?.offset ?? 0 + let entry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) + setSessionCache(entry, for: session) + continuation.finish() } catch { continuation.finish(throwing: error) @@ -377,7 +493,7 @@ import Foundation return LanguageModelSession.ResponseStream(stream: stream) } - /// Prewarms the model + /// Prewarms the model by loading it and optionally prefilling the system prompt into a KV cache. public func prewarm( for session: LanguageModelSession, promptPrefix: Prompt? @@ -388,7 +504,27 @@ import Foundation Task { do { - _ = try await loadContext(modelId: modelId, hub: hub, directory: directory) + let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + + // Prefill the system prompt into a KV cache so the first turn is faster + if let instructions = session.instructions?.description, !instructions.isEmpty { + let params = MLXLMCommon.GenerateParameters() + let newCache = context.model.newCache(parameters: params) + let chat: [MLXLMCommon.Chat.Message] = [.init(role: .system, content: instructions)] + let userInput = MLXLMCommon.UserInput( + chat: chat, + processing: .init(resize: .init(width: 512, height: 512)), + tools: nil + ) + let lmInput = try await context.processor.prepare(input: userInput) + _ = try context.model.prepare(lmInput, cache: newCache, windowSize: nil) + + let entry = SessionCacheEntry( + kvCache: newCache, + prefillTokenCount: newCache.first?.offset ?? 0 + ) + setSessionCache(entry, for: session) + } } catch { // Ignore errors during prewarm } @@ -401,9 +537,9 @@ import Foundation private func toGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, - maxKVSize: nil, - kvBits: nil, - kvGroupSize: 64, + maxKVSize: options.maxKVSize, + kvBits: options.kvBits, + kvGroupSize: options.kvGroupSize, quantizedKVStart: 0, temperature: Float(options.temperature ?? 0.6), topP: 1.0, From 2d2269e21f4a341c59cdf0c92964e376f8aecc99 Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Mon, 23 Feb 2026 14:39:17 -0600 Subject: [PATCH 2/3] Add GPU memory management and upgrade mlx-swift-lm to 2.30.6 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add GPUMemoryConfiguration struct with .automatic (RAM-scaled) and .unconstrained presets for controlling Metal buffer pool limits - Add GPUMemoryManager singleton with reference-counted active/idle toggling — cache stays high during concurrent generations, drops to idle limit only when all sessions complete - Wrap respond(), streamResponse(), and prewarm() with markActive/markIdle - Call evict() on removeFromCache/removeAllFromCache to reclaim GPU buffers - Upgrade mlx-swift from 0.29.1 to 0.30.6 (fast SDPA, cache race fix, Memory API, wired memory, iPhone 16 Pro NAX fix) - Upgrade mlx-swift-lm from 2.29.3 to 2.30.6 (Gemma3n per-layer intermediate_size, model loading perf, chat rehydration, tool calling) - Migrate deprecated GPU.set(cacheLimit:)/GPU.clearCache() to Memory.* Co-Authored-By: Claude Opus 4.6 --- Package.resolved | 58 ++++--- Package.swift | 5 +- .../Models/MLXLanguageModel.swift | 161 +++++++++++++++++- 3 files changed, 200 insertions(+), 24 deletions(-) diff --git a/Package.resolved b/Package.resolved index 837d7768..1495dc88 100644 --- a/Package.resolved +++ b/Package.resolved @@ -4,10 +4,10 @@ { "identity" : "eventsource", "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/EventSource.git", + "location" : "https://github.com/mattt/EventSource", "state" : { - "revision" : "ca2a9d90cbe49e09b92f4b6ebd922c03ebea51d0", - "version" : "1.3.0" + "revision" : "bd64824505da71a1a403adb221f6e25413c0bc7f", + "version" : "1.4.0" } }, { @@ -19,22 +19,13 @@ "version" : "1.3.1" } }, - { - "identity" : "llama.swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/llama.swift", - "state" : { - "revision" : "4d57cff84ba85914baa39850157e7c27684db9c8", - "version" : "2.7966.0" - } - }, { "identity" : "mlx-swift", "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "revision" : "072b684acaae80b6a463abab3a103732f33774bf", - "version" : "0.29.1" + "revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d", + "version" : "0.30.6" } }, { @@ -42,8 +33,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift-lm", "state" : { - "revision" : "5064b8c5d8ed3b0bbb71385c4124f0fc102e74a2", - "version" : "2.29.3" + "revision" : "7e19e09027923d89ac47dd087d9627f610e5a91a", + "version" : "2.30.6" } }, { @@ -55,6 +46,15 @@ "version" : "1.0.0" } }, + { + "identity" : "swift-asn1", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-asn1.git", + "state" : { + "revision" : "810496cf121e525d660cd0ea89a758740476b85f", + "version" : "1.5.1" + } + }, { "identity" : "swift-collections", "kind" : "remoteSourceControl", @@ -64,13 +64,22 @@ "version" : "1.3.0" } }, + { + "identity" : "swift-crypto", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-crypto.git", + "state" : { + "revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095", + "version" : "4.2.0" + } + }, { "identity" : "swift-jinja", "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-jinja.git", "state" : { - "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0", - "version" : "2.3.1" + "revision" : "f731f03bf746481d4fda07f817c3774390c4d5b9", + "version" : "2.3.2" } }, { @@ -96,8 +105,17 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers", "state" : { - "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0", - "version" : "1.1.6" + "revision" : "3aecdf18e62303fb5a5543f8e87502b13474573f", + "version" : "1.1.7" + } + }, + { + "identity" : "yyjson", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ibireme/yyjson.git", + "state" : { + "revision" : "8b4a38dc994a110abaec8a400615567bd996105f", + "version" : "0.12.0" } } ], diff --git a/Package.swift b/Package.swift index 3916bf01..d3a6893e 100644 --- a/Package.swift +++ b/Package.swift @@ -33,8 +33,9 @@ let package = Package( .package(url: "https://github.com/mattt/JSONSchema", from: "1.3.0"), .package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")), .package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"), - // mlx-swift-lm must be >= 2.25.5 for ToolSpec/tool calls and UserInput(chat:processing:tools:). - .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.25.5"), + // mlx-swift-lm >= 2.30.3 for fast SDPA, Gemma3n per-layer intermediate_size, + // cache race fix, Memory API, and chat rehydration. >= 2.25.5 for ToolSpec/tool calls. + .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.30.3"), .package(url: "https://github.com/swiftlang/swift-syntax", from: "600.0.0"), ], targets: [ diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index e84a6aee..0001b022 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -18,6 +18,135 @@ import Foundation import Tokenizers import Hub + // MARK: - GPU Memory Configuration + + /// Controls Metal buffer pool behavior during and between MLX inference. + /// + /// MLX maintains a recycled buffer pool to avoid repeated Metal allocations. + /// This configuration sets the pool size during active inference (`activeCacheLimit`) + /// and between generations (`idleCacheLimit`). + /// + /// ```swift + /// // Automatic (scaled by device RAM): + /// let model = MLXLanguageModel(modelId: "...", gpuMemory: .automatic) + /// + /// // Custom: + /// let model = MLXLanguageModel(modelId: "...", gpuMemory: .init( + /// activeCacheLimit: 256_000_000, + /// idleCacheLimit: 50_000_000 + /// )) + /// ``` + public struct GPUMemoryConfiguration: Sendable, Equatable { + /// Maximum Metal buffer cache size in bytes during active inference. + public var activeCacheLimit: Int + + /// Maximum Metal buffer cache size in bytes when no inference is running. + public var idleCacheLimit: Int + + /// Whether to call `Memory.clearCache()` when a model is evicted. + public var clearCacheOnEviction: Bool + + public init( + activeCacheLimit: Int, + idleCacheLimit: Int, + clearCacheOnEviction: Bool = true + ) { + self.activeCacheLimit = activeCacheLimit + self.idleCacheLimit = idleCacheLimit + self.clearCacheOnEviction = clearCacheOnEviction + } + + /// Scaled by device RAM. Idle: 50 MB. Clear cache on eviction. + /// + /// Active limits: <4 GB → 128 MB, <6 GB → 256 MB, <8 GB → 512 MB, 8+ GB → 768 MB. + public static var automatic: GPUMemoryConfiguration { + let ramBytes = ProcessInfo.processInfo.physicalMemory + let ramGB = ramBytes / (1024 * 1024 * 1024) + + let active: Int + switch ramGB { + case ..<4: + active = 128_000_000 + case ..<6: + active = 256_000_000 + case ..<8: + active = 512_000_000 + default: + active = 768_000_000 + } + + return GPUMemoryConfiguration( + activeCacheLimit: active, + idleCacheLimit: 50_000_000, + clearCacheOnEviction: true + ) + } + + /// No management — MLX defaults, unbounded cache. + public static var unconstrained: GPUMemoryConfiguration { + GPUMemoryConfiguration( + activeCacheLimit: Int.max, + idleCacheLimit: Int.max, + clearCacheOnEviction: false + ) + } + } + + // MARK: - GPU Memory Manager + + /// Reference-counted active/idle toggling for the global Metal buffer cache. + /// + /// Multiple sessions can generate concurrently. The cache stays at `activeCacheLimit` + /// as long as ANY session is generating, and drops to `idleCacheLimit` only when ALL + /// sessions complete. + private final class GPUMemoryManager: @unchecked Sendable { + static let shared = GPUMemoryManager() + + private let lock = NSLock() + private var activeCount = 0 + private var config: GPUMemoryConfiguration = .automatic + + private init() { + Memory.cacheLimit = config.idleCacheLimit + } + + func configure(_ configuration: GPUMemoryConfiguration) { + lock.withLock { + config = configuration + if activeCount == 0 { + Memory.cacheLimit = configuration.idleCacheLimit + } + } + } + + func markActive() { + lock.withLock { + if activeCount == 0 { + Memory.cacheLimit = config.activeCacheLimit + } + activeCount += 1 + } + } + + func markIdle() { + lock.withLock { + activeCount = max(0, activeCount - 1) + if activeCount == 0 { + Memory.cacheLimit = config.idleCacheLimit + } + } + } + + func evict() { + lock.withLock { + Memory.cacheLimit = config.idleCacheLimit + if config.clearCacheOnEviction { + Memory.clearCache() + } + } + } + } + /// Wrapper to store ModelContext in NSCache (requires NSObject subclass). private final class CachedContext: NSObject, @unchecked Sendable { let context: ModelContext @@ -180,16 +309,22 @@ import Foundation /// The local directory containing the model files. public let directory: URL? + /// GPU memory management configuration for Metal buffer pools. + public let gpuMemory: GPUMemoryConfiguration + /// Creates an MLX language model. /// /// - Parameters: /// - modelId: The model identifier (for example, "mlx-community/Llama-3.2-3B-Instruct-4bit"). /// - hub: An optional Hub API instance for downloading models. If not provided, the default Hub API is used. /// - directory: An optional local directory URL containing the model files. If provided, the model is loaded from this directory instead of downloading. - public init(modelId: String, hub: HubApi? = nil, directory: URL? = nil) { + /// - gpuMemory: GPU memory configuration. Defaults to `.automatic` which scales by device RAM. + public init(modelId: String, hub: HubApi? = nil, directory: URL? = nil, gpuMemory: GPUMemoryConfiguration = .automatic) { self.modelId = modelId self.hub = hub self.directory = directory + self.gpuMemory = gpuMemory + GPUMemoryManager.shared.configure(gpuMemory) } /// Removes this model from the shared cache and cancels any in-flight load. @@ -199,11 +334,13 @@ import Foundation public func removeFromCache() async { let key = directory?.absoluteString ?? modelId await modelCache.removeAndCancel(for: key) + GPUMemoryManager.shared.evict() } /// Removes all MLX models from the shared cache and cancels in-flight loads. public static func removeAllFromCache() async { await modelCache.removeAllAndCancel() + GPUMemoryManager.shared.evict() } /// Get or load model context with caching @@ -229,6 +366,9 @@ import Foundation // Get cached or load fresh ModelContext let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + GPUMemoryManager.shared.markActive() + defer { GPUMemoryManager.shared.markIdle() } + if type != String.self { let jsonString = try await generateStructuredJSON( context: context, @@ -410,6 +550,9 @@ import Foundation let stream: AsyncThrowingStream.Snapshot, any Error> = .init { continuation in + let didMarkIdle = Locked(false) + GPUMemoryManager.shared.markActive() + let task = Task { @Sendable in do { // Get cached or load fresh ModelContext @@ -482,12 +625,23 @@ import Foundation let entry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) setSessionCache(entry, for: session) + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } continuation.finish() } catch { + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } continuation.finish(throwing: error) } } - continuation.onTermination = { _ in task.cancel() } + continuation.onTermination = { _ in + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } + task.cancel() + } } return LanguageModelSession.ResponseStream(stream: stream) @@ -503,6 +657,9 @@ import Foundation let directory = self.directory Task { + GPUMemoryManager.shared.markActive() + defer { GPUMemoryManager.shared.markIdle() } + do { let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) From c37ab33f256c96ec0332ffbbc33ea70f843c5445 Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Tue, 24 Feb 2026 20:41:06 -0600 Subject: [PATCH 3/3] Fix MLX tool calls failing due to Jinja template errors Gemma 3's Jinja chat template has no tool role support, causing tool result messages to crash the template engine during chat history replay. This fixes the issue by folding tool outputs into the preceding assistant message instead of using a separate .tool() role. Changes: - Fold tool results into assistant messages with [Tool result]: prefix to maintain strict user/assistant alternation required by Gemma 3 - Add max tool iteration guard (5) to prevent infinite tool-call loops - Fix convertToSendableJSONValue to return NSNull() instead of JSONValue.null so Jinja's Value(any:) can handle it - Check Bool before NSNumber to prevent booleans becoming integers - Record assistant text before tool calls in transcript for accurate chat replay and KV cache consistency - Move final text accumulation to after tool loop exit so only the final response is returned Fixes #112 Co-Authored-By: Claude Opus 4.6 --- .../Models/MLXLanguageModel.swift | 69 +++++++++++++++---- 1 file changed, 56 insertions(+), 13 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 107f8e84..592c58b5 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -479,8 +479,18 @@ import Foundation // on subsequent iterations (tool results added) we must rebuild. var isFirstIteration = true + // Guard against infinite tool-call loops (e.g. model keeps retrying the + // same tool call). After this many iterations, break and return whatever + // text has been accumulated. + let maxToolIterations = 5 + var toolIteration = 0 + // Loop until no more tool calls while true { + toolIteration += 1 + if toolIteration > maxToolIterations { + break + } // Build user input with current chat history and tools let userInput = MLXLMCommon.UserInput( chat: chat, @@ -552,7 +562,6 @@ import Foundation setSessionCache(cacheEntry, for: session) let assistantText = chunks.joined() - allTextChunks.append(assistantText) // Add assistant response to chat history if !assistantText.isEmpty { @@ -561,6 +570,17 @@ import Foundation // If there are tool calls, execute them and continue if !collectedToolCalls.isEmpty { + // Record the assistant text generated before the tool call + // as a transcript entry so convertTranscriptToMLXChat() can + // reproduce the exact same chat sequence on future turns + // (keeping the KV cache valid). + if !assistantText.isEmpty { + allEntries.append(.response(Transcript.Response( + assetIDs: [], + segments: [.text(.init(content: assistantText))] + ))) + } + let resolution = try await resolveToolCalls(collectedToolCalls, session: session) switch resolution { case .stop(let calls): @@ -576,13 +596,20 @@ import Foundation if !invocations.isEmpty { allEntries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) - // Execute each tool and add results to chat + // Execute each tool and fold results into the + // preceding assistant message to maintain strict + // user/assistant alternation for templates like Gemma 3. + var toolResults: [String] = [] for invocation in invocations { allEntries.append(.toolOutput(invocation.output)) - - // Convert tool output to JSON string for MLX - let toolResultJSON = toolOutputToJSON(invocation.output) - chat.append(.tool(toolResultJSON)) + toolResults.append(toolOutputToJSON(invocation.output)) + } + let combinedResults = toolResults.joined(separator: "\n") + if let lastIdx = chat.indices.last, chat[lastIdx].role == .assistant { + let existing = chat[lastIdx].content + chat[lastIdx] = .assistant(existing + "\n\n[Tool result]: " + combinedResults) + } else { + chat.append(.assistant("[Tool result]: " + combinedResults)) } // Continue loop to generate with tool results @@ -591,11 +618,13 @@ import Foundation } } - // No more tool calls, exit loop + // No more tool calls — this is the final response text + allTextChunks.append(assistantText) break } let text = allTextChunks.joined() + return LanguageModelSession.Response( content: text as! Content, rawContent: GeneratedContent(text), @@ -812,7 +841,10 @@ import Foundation chat.append(.init(role: .system, content: instructions)) } - // Convert each transcript entry + // Convert each transcript entry. + // Tool call/output entries are folded into adjacent assistant messages + // to maintain strict user/assistant alternation required by some templates + // (e.g. Gemma 3 which has no tool role support in its Jinja template). for entry in session.transcript { switch entry { case .instructions(let instr): @@ -826,12 +858,23 @@ import Foundation chat.append(.assistant(content)) case .toolCalls: - // Tool calls are handled inline during generation loop + // Skip — tool calls were already executed; the output is captured below break case .toolOutput(let toolOutput): - let content = toolOutput.segments.map { extractText(from: $0) }.joined(separator: "\n") - chat.append(.tool(content)) + // Fold tool output into the preceding assistant message to avoid + // injecting a .tool() role that breaks strict-alternation templates. + // Use toolOutputToJSON() — the same function used in the live tool + // loop — so the replayed chat produces identical tokens, keeping + // the KV cache valid for session continuation. + let content = toolOutputToJSON(toolOutput) + if let lastIdx = chat.indices.last, chat[lastIdx].role == .assistant { + let existing = chat[lastIdx].content + chat[lastIdx] = .assistant(existing + "\n\n[Tool result]: " + content) + } else { + // No preceding assistant message — wrap as assistant + chat.append(.assistant("[Tool result]: " + content)) + } } } @@ -947,7 +990,7 @@ import Foundation } private func convertToSendableJSONValue(_ value: Any) throws -> any Sendable { - if value is NSNull { return MLXLMCommon.JSONValue.null } + if value is NSNull { return NSNull() } if let stringValue = value as? String { return stringValue } if let boolValue = value as? Bool { return boolValue } if let intValue = value as? Int { return intValue } @@ -956,7 +999,7 @@ import Foundation return numberValue.doubleValue } if let arrayValue = value as? [Any] { - return try arrayValue.map { try convertToSendableJSONValue($0) } + return try arrayValue.map { try convertToSendableJSONValue($0) } as [any Sendable] } if let dictionaryValue = value as? [String: Any] { return try convertToSendableJSONObject(dictionaryValue)