From c91f665605fe3168c619df2a1209d5edfb07a6bc Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Mon, 23 Feb 2026 09:02:27 -0600 Subject: [PATCH 1/7] Add MLX KV cache reuse for incremental prefill Persist KV caches across respond()/streamResponse() calls within the same LanguageModelSession. On subsequent turns only the new tokens are prefilled instead of re-encoding the entire conversation history, dramatically reducing time to first token. - Add maxKVSize, kvBits, kvGroupSize to GenerationOptions - Add SessionCacheEntry store with NSMapTable weak keys - Implement incremental prefill in streamResponse() and respond() - Enhance prewarm() to prefill system prompt into KV cache Co-Authored-By: Claude Opus 4.6 --- .../AnyLanguageModel/GenerationOptions.swift | 30 +++- .../Models/MLXLanguageModel.swift | 150 +++++++++++++++++- 2 files changed, 172 insertions(+), 8 deletions(-) diff --git a/Sources/AnyLanguageModel/GenerationOptions.swift b/Sources/AnyLanguageModel/GenerationOptions.swift index 5d0f6338..17da525c 100644 --- a/Sources/AnyLanguageModel/GenerationOptions.swift +++ b/Sources/AnyLanguageModel/GenerationOptions.swift @@ -121,6 +121,25 @@ public struct GenerationOptions: Sendable, Equatable, Codable { /// an error will be thrown. public var maximumResponseTokens: Int? + /// Maximum number of tokens to retain in the KV cache. + /// + /// When set, uses a rotating cache that evicts oldest tokens beyond this limit. + /// When `nil` (default), the cache grows unbounded. + /// + /// Recommended values: 2048–4096 for iPhone, `nil` for Mac. + public var maxKVSize: Int? + + /// Bit width for KV cache quantization (for example, 4 or 8). + /// + /// Reduces cache memory usage at slight quality cost. + /// When `nil` (default), the cache uses full precision. + public var kvBits: Int? + + /// Group size for KV cache quantization. + /// + /// Only meaningful when ``kvBits`` is set. Default is 64. + public var kvGroupSize: Int + /// Storage for model-specific custom options. private var customOptionsStorage: CustomOptionsStorage = .init() @@ -157,14 +176,23 @@ public struct GenerationOptions: Sendable, Equatable, Codable { /// responses. Must be between `0` and `1`, inclusive. /// - maximumResponseTokens: The maximum number of tokens the model is allowed /// to produce before being artificially halted. Must be positive. + /// - maxKVSize: Maximum tokens in the KV cache. When set, enables a rotating cache. + /// - kvBits: Bit width for KV cache quantization. + /// - kvGroupSize: Group size for KV cache quantization. Default is 64. public init( sampling: SamplingMode? = nil, temperature: Double? = nil, - maximumResponseTokens: Int? = nil + maximumResponseTokens: Int? = nil, + maxKVSize: Int? = nil, + kvBits: Int? = nil, + kvGroupSize: Int = 64 ) { self.sampling = sampling self.temperature = temperature self.maximumResponseTokens = maximumResponseTokens + self.maxKVSize = maxKVSize + self.kvBits = kvBits + self.kvGroupSize = kvGroupSize } } diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 4ffb877a..e84a6aee 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -120,6 +120,42 @@ import Foundation /// Shared cache across MLXLanguageModel instances. private nonisolated(unsafe) let modelCache = ModelContextCache(countLimit: 3) + // MARK: - Session KV Cache Store + + /// Stores a KV cache and its prefill token count for a session. + private final class SessionCacheEntry: NSObject, @unchecked Sendable { + var kvCache: [MLXLMCommon.KVCache] + var prefillTokenCount: Int + + init(kvCache: [MLXLMCommon.KVCache], prefillTokenCount: Int) { + self.kvCache = kvCache + self.prefillTokenCount = prefillTokenCount + } + } + + /// Maps LanguageModelSession (weak key) → SessionCacheEntry. + /// When a session is deallocated, its cache entry is automatically released. + private nonisolated(unsafe) let sessionKVCache = NSMapTable.weakToStrongObjects() + private let sessionKVCacheLock = NSLock() + + private func getSessionCache(_ session: LanguageModelSession) -> SessionCacheEntry? { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + return sessionKVCache.object(forKey: session) + } + + private func setSessionCache(_ entry: SessionCacheEntry, for session: LanguageModelSession) { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + sessionKVCache.setObject(entry, forKey: session) + } + + private func removeSessionCache(for session: LanguageModelSession) { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + sessionKVCache.removeObject(forKey: session) + } + // MARK: - MLXLanguageModel /// A language model that runs locally using MLX. @@ -228,6 +264,11 @@ import Foundation var allTextChunks: [String] = [] var allEntries: [Transcript.Entry] = [] + // Track the KV cache across the tool-calling loop. + // On the first iteration we try to reuse the session's cached KV state; + // on subsequent iterations (tool results added) we must rebuild. + var isFirstIteration = true + // Loop until no more tool calls while true { // Build user input with current chat history and tools @@ -238,9 +279,45 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) + // Determine cache and input for generation + let cache: [MLXLMCommon.KVCache] + let inputForGeneration: MLXLMCommon.LMInput + + if isFirstIteration { + let existingEntry = getSessionCache(session) + let fullTokenCount = lmInput.text.tokens.dim(0) + + if let existingEntry, + existingEntry.prefillTokenCount > 0, + fullTokenCount > existingEntry.prefillTokenCount, + lmInput.image == nil + { + // Cache HIT: only prefill new tokens + let cachedCount = existingEntry.prefillTokenCount + let newTokens = lmInput.text.tokens[cachedCount...] + let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) + inputForGeneration = MLXLMCommon.LMInput(text: partialText) + cache = existingEntry.kvCache + } else { + // Cache MISS: create fresh cache + if existingEntry != nil { + removeSessionCache(for: session) + } + cache = context.model.newCache(parameters: generateParameters) + inputForGeneration = lmInput + } + } else { + // Tool-calling iterations: fresh cache (chat has been mutated) + cache = context.model.newCache(parameters: generateParameters) + inputForGeneration = lmInput + } + + isFirstIteration = false + // Generate let stream = try MLXLMCommon.generate( - input: lmInput, + input: inputForGeneration, + cache: cache, parameters: generateParameters, context: context ) @@ -259,6 +336,11 @@ import Foundation } } + // Update session cache with current offset after generation + let currentOffset = cache.first?.offset ?? 0 + let cacheEntry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) + setSessionCache(cacheEntry, for: session) + let assistantText = chunks.joined() allTextChunks.append(assistantText) @@ -344,8 +426,37 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) + // Check for existing KV cache for this session + let existingEntry = getSessionCache(session) + let cache: [MLXLMCommon.KVCache] + let inputForGeneration: MLXLMCommon.LMInput + + let fullTokenCount = lmInput.text.tokens.dim(0) + + if let existingEntry, + existingEntry.prefillTokenCount > 0, + fullTokenCount > existingEntry.prefillTokenCount, + lmInput.image == nil + { + // Cache HIT: only prefill new tokens + let cachedCount = existingEntry.prefillTokenCount + let newTokens = lmInput.text.tokens[cachedCount...] + let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) + inputForGeneration = MLXLMCommon.LMInput(text: partialText) + cache = existingEntry.kvCache + } else { + // Cache MISS: create fresh cache, prefill everything + if existingEntry != nil { + removeSessionCache(for: session) + } + let newCache = context.model.newCache(parameters: generateParameters) + cache = newCache + inputForGeneration = lmInput + } + let mlxStream = try MLXLMCommon.generate( - input: lmInput, + input: inputForGeneration, + cache: cache, parameters: generateParameters, context: context ) @@ -366,6 +477,11 @@ import Foundation } } + // Update the session cache with current offset + let currentOffset = cache.first?.offset ?? 0 + let entry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) + setSessionCache(entry, for: session) + continuation.finish() } catch { continuation.finish(throwing: error) @@ -377,7 +493,7 @@ import Foundation return LanguageModelSession.ResponseStream(stream: stream) } - /// Prewarms the model + /// Prewarms the model by loading it and optionally prefilling the system prompt into a KV cache. public func prewarm( for session: LanguageModelSession, promptPrefix: Prompt? @@ -388,7 +504,27 @@ import Foundation Task { do { - _ = try await loadContext(modelId: modelId, hub: hub, directory: directory) + let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + + // Prefill the system prompt into a KV cache so the first turn is faster + if let instructions = session.instructions?.description, !instructions.isEmpty { + let params = MLXLMCommon.GenerateParameters() + let newCache = context.model.newCache(parameters: params) + let chat: [MLXLMCommon.Chat.Message] = [.init(role: .system, content: instructions)] + let userInput = MLXLMCommon.UserInput( + chat: chat, + processing: .init(resize: .init(width: 512, height: 512)), + tools: nil + ) + let lmInput = try await context.processor.prepare(input: userInput) + _ = try context.model.prepare(lmInput, cache: newCache, windowSize: nil) + + let entry = SessionCacheEntry( + kvCache: newCache, + prefillTokenCount: newCache.first?.offset ?? 0 + ) + setSessionCache(entry, for: session) + } } catch { // Ignore errors during prewarm } @@ -401,9 +537,9 @@ import Foundation private func toGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, - maxKVSize: nil, - kvBits: nil, - kvGroupSize: 64, + maxKVSize: options.maxKVSize, + kvBits: options.kvBits, + kvGroupSize: options.kvGroupSize, quantizedKVStart: 0, temperature: Float(options.temperature ?? 0.6), topP: 1.0, From 2d2269e21f4a341c59cdf0c92964e376f8aecc99 Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Mon, 23 Feb 2026 14:39:17 -0600 Subject: [PATCH 2/7] Add GPU memory management and upgrade mlx-swift-lm to 2.30.6 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add GPUMemoryConfiguration struct with .automatic (RAM-scaled) and .unconstrained presets for controlling Metal buffer pool limits - Add GPUMemoryManager singleton with reference-counted active/idle toggling — cache stays high during concurrent generations, drops to idle limit only when all sessions complete - Wrap respond(), streamResponse(), and prewarm() with markActive/markIdle - Call evict() on removeFromCache/removeAllFromCache to reclaim GPU buffers - Upgrade mlx-swift from 0.29.1 to 0.30.6 (fast SDPA, cache race fix, Memory API, wired memory, iPhone 16 Pro NAX fix) - Upgrade mlx-swift-lm from 2.29.3 to 2.30.6 (Gemma3n per-layer intermediate_size, model loading perf, chat rehydration, tool calling) - Migrate deprecated GPU.set(cacheLimit:)/GPU.clearCache() to Memory.* Co-Authored-By: Claude Opus 4.6 --- Package.resolved | 58 ++++--- Package.swift | 5 +- .../Models/MLXLanguageModel.swift | 161 +++++++++++++++++- 3 files changed, 200 insertions(+), 24 deletions(-) diff --git a/Package.resolved b/Package.resolved index 837d7768..1495dc88 100644 --- a/Package.resolved +++ b/Package.resolved @@ -4,10 +4,10 @@ { "identity" : "eventsource", "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/EventSource.git", + "location" : "https://github.com/mattt/EventSource", "state" : { - "revision" : "ca2a9d90cbe49e09b92f4b6ebd922c03ebea51d0", - "version" : "1.3.0" + "revision" : "bd64824505da71a1a403adb221f6e25413c0bc7f", + "version" : "1.4.0" } }, { @@ -19,22 +19,13 @@ "version" : "1.3.1" } }, - { - "identity" : "llama.swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/llama.swift", - "state" : { - "revision" : "4d57cff84ba85914baa39850157e7c27684db9c8", - "version" : "2.7966.0" - } - }, { "identity" : "mlx-swift", "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "revision" : "072b684acaae80b6a463abab3a103732f33774bf", - "version" : "0.29.1" + "revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d", + "version" : "0.30.6" } }, { @@ -42,8 +33,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift-lm", "state" : { - "revision" : "5064b8c5d8ed3b0bbb71385c4124f0fc102e74a2", - "version" : "2.29.3" + "revision" : "7e19e09027923d89ac47dd087d9627f610e5a91a", + "version" : "2.30.6" } }, { @@ -55,6 +46,15 @@ "version" : "1.0.0" } }, + { + "identity" : "swift-asn1", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-asn1.git", + "state" : { + "revision" : "810496cf121e525d660cd0ea89a758740476b85f", + "version" : "1.5.1" + } + }, { "identity" : "swift-collections", "kind" : "remoteSourceControl", @@ -64,13 +64,22 @@ "version" : "1.3.0" } }, + { + "identity" : "swift-crypto", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-crypto.git", + "state" : { + "revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095", + "version" : "4.2.0" + } + }, { "identity" : "swift-jinja", "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-jinja.git", "state" : { - "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0", - "version" : "2.3.1" + "revision" : "f731f03bf746481d4fda07f817c3774390c4d5b9", + "version" : "2.3.2" } }, { @@ -96,8 +105,17 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers", "state" : { - "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0", - "version" : "1.1.6" + "revision" : "3aecdf18e62303fb5a5543f8e87502b13474573f", + "version" : "1.1.7" + } + }, + { + "identity" : "yyjson", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ibireme/yyjson.git", + "state" : { + "revision" : "8b4a38dc994a110abaec8a400615567bd996105f", + "version" : "0.12.0" } } ], diff --git a/Package.swift b/Package.swift index 3916bf01..d3a6893e 100644 --- a/Package.swift +++ b/Package.swift @@ -33,8 +33,9 @@ let package = Package( .package(url: "https://github.com/mattt/JSONSchema", from: "1.3.0"), .package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")), .package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"), - // mlx-swift-lm must be >= 2.25.5 for ToolSpec/tool calls and UserInput(chat:processing:tools:). - .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.25.5"), + // mlx-swift-lm >= 2.30.3 for fast SDPA, Gemma3n per-layer intermediate_size, + // cache race fix, Memory API, and chat rehydration. >= 2.25.5 for ToolSpec/tool calls. + .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.30.3"), .package(url: "https://github.com/swiftlang/swift-syntax", from: "600.0.0"), ], targets: [ diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index e84a6aee..0001b022 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -18,6 +18,135 @@ import Foundation import Tokenizers import Hub + // MARK: - GPU Memory Configuration + + /// Controls Metal buffer pool behavior during and between MLX inference. + /// + /// MLX maintains a recycled buffer pool to avoid repeated Metal allocations. + /// This configuration sets the pool size during active inference (`activeCacheLimit`) + /// and between generations (`idleCacheLimit`). + /// + /// ```swift + /// // Automatic (scaled by device RAM): + /// let model = MLXLanguageModel(modelId: "...", gpuMemory: .automatic) + /// + /// // Custom: + /// let model = MLXLanguageModel(modelId: "...", gpuMemory: .init( + /// activeCacheLimit: 256_000_000, + /// idleCacheLimit: 50_000_000 + /// )) + /// ``` + public struct GPUMemoryConfiguration: Sendable, Equatable { + /// Maximum Metal buffer cache size in bytes during active inference. + public var activeCacheLimit: Int + + /// Maximum Metal buffer cache size in bytes when no inference is running. + public var idleCacheLimit: Int + + /// Whether to call `Memory.clearCache()` when a model is evicted. + public var clearCacheOnEviction: Bool + + public init( + activeCacheLimit: Int, + idleCacheLimit: Int, + clearCacheOnEviction: Bool = true + ) { + self.activeCacheLimit = activeCacheLimit + self.idleCacheLimit = idleCacheLimit + self.clearCacheOnEviction = clearCacheOnEviction + } + + /// Scaled by device RAM. Idle: 50 MB. Clear cache on eviction. + /// + /// Active limits: <4 GB → 128 MB, <6 GB → 256 MB, <8 GB → 512 MB, 8+ GB → 768 MB. + public static var automatic: GPUMemoryConfiguration { + let ramBytes = ProcessInfo.processInfo.physicalMemory + let ramGB = ramBytes / (1024 * 1024 * 1024) + + let active: Int + switch ramGB { + case ..<4: + active = 128_000_000 + case ..<6: + active = 256_000_000 + case ..<8: + active = 512_000_000 + default: + active = 768_000_000 + } + + return GPUMemoryConfiguration( + activeCacheLimit: active, + idleCacheLimit: 50_000_000, + clearCacheOnEviction: true + ) + } + + /// No management — MLX defaults, unbounded cache. + public static var unconstrained: GPUMemoryConfiguration { + GPUMemoryConfiguration( + activeCacheLimit: Int.max, + idleCacheLimit: Int.max, + clearCacheOnEviction: false + ) + } + } + + // MARK: - GPU Memory Manager + + /// Reference-counted active/idle toggling for the global Metal buffer cache. + /// + /// Multiple sessions can generate concurrently. The cache stays at `activeCacheLimit` + /// as long as ANY session is generating, and drops to `idleCacheLimit` only when ALL + /// sessions complete. + private final class GPUMemoryManager: @unchecked Sendable { + static let shared = GPUMemoryManager() + + private let lock = NSLock() + private var activeCount = 0 + private var config: GPUMemoryConfiguration = .automatic + + private init() { + Memory.cacheLimit = config.idleCacheLimit + } + + func configure(_ configuration: GPUMemoryConfiguration) { + lock.withLock { + config = configuration + if activeCount == 0 { + Memory.cacheLimit = configuration.idleCacheLimit + } + } + } + + func markActive() { + lock.withLock { + if activeCount == 0 { + Memory.cacheLimit = config.activeCacheLimit + } + activeCount += 1 + } + } + + func markIdle() { + lock.withLock { + activeCount = max(0, activeCount - 1) + if activeCount == 0 { + Memory.cacheLimit = config.idleCacheLimit + } + } + } + + func evict() { + lock.withLock { + Memory.cacheLimit = config.idleCacheLimit + if config.clearCacheOnEviction { + Memory.clearCache() + } + } + } + } + /// Wrapper to store ModelContext in NSCache (requires NSObject subclass). private final class CachedContext: NSObject, @unchecked Sendable { let context: ModelContext @@ -180,16 +309,22 @@ import Foundation /// The local directory containing the model files. public let directory: URL? + /// GPU memory management configuration for Metal buffer pools. + public let gpuMemory: GPUMemoryConfiguration + /// Creates an MLX language model. /// /// - Parameters: /// - modelId: The model identifier (for example, "mlx-community/Llama-3.2-3B-Instruct-4bit"). /// - hub: An optional Hub API instance for downloading models. If not provided, the default Hub API is used. /// - directory: An optional local directory URL containing the model files. If provided, the model is loaded from this directory instead of downloading. - public init(modelId: String, hub: HubApi? = nil, directory: URL? = nil) { + /// - gpuMemory: GPU memory configuration. Defaults to `.automatic` which scales by device RAM. + public init(modelId: String, hub: HubApi? = nil, directory: URL? = nil, gpuMemory: GPUMemoryConfiguration = .automatic) { self.modelId = modelId self.hub = hub self.directory = directory + self.gpuMemory = gpuMemory + GPUMemoryManager.shared.configure(gpuMemory) } /// Removes this model from the shared cache and cancels any in-flight load. @@ -199,11 +334,13 @@ import Foundation public func removeFromCache() async { let key = directory?.absoluteString ?? modelId await modelCache.removeAndCancel(for: key) + GPUMemoryManager.shared.evict() } /// Removes all MLX models from the shared cache and cancels in-flight loads. public static func removeAllFromCache() async { await modelCache.removeAllAndCancel() + GPUMemoryManager.shared.evict() } /// Get or load model context with caching @@ -229,6 +366,9 @@ import Foundation // Get cached or load fresh ModelContext let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + GPUMemoryManager.shared.markActive() + defer { GPUMemoryManager.shared.markIdle() } + if type != String.self { let jsonString = try await generateStructuredJSON( context: context, @@ -410,6 +550,9 @@ import Foundation let stream: AsyncThrowingStream.Snapshot, any Error> = .init { continuation in + let didMarkIdle = Locked(false) + GPUMemoryManager.shared.markActive() + let task = Task { @Sendable in do { // Get cached or load fresh ModelContext @@ -482,12 +625,23 @@ import Foundation let entry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) setSessionCache(entry, for: session) + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } continuation.finish() } catch { + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } continuation.finish(throwing: error) } } - continuation.onTermination = { _ in task.cancel() } + continuation.onTermination = { _ in + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } + task.cancel() + } } return LanguageModelSession.ResponseStream(stream: stream) @@ -503,6 +657,9 @@ import Foundation let directory = self.directory Task { + GPUMemoryManager.shared.markActive() + defer { GPUMemoryManager.shared.markIdle() } + do { let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) From c37ab33f256c96ec0332ffbbc33ea70f843c5445 Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Tue, 24 Feb 2026 20:41:06 -0600 Subject: [PATCH 3/7] Fix MLX tool calls failing due to Jinja template errors Gemma 3's Jinja chat template has no tool role support, causing tool result messages to crash the template engine during chat history replay. This fixes the issue by folding tool outputs into the preceding assistant message instead of using a separate .tool() role. Changes: - Fold tool results into assistant messages with [Tool result]: prefix to maintain strict user/assistant alternation required by Gemma 3 - Add max tool iteration guard (5) to prevent infinite tool-call loops - Fix convertToSendableJSONValue to return NSNull() instead of JSONValue.null so Jinja's Value(any:) can handle it - Check Bool before NSNumber to prevent booleans becoming integers - Record assistant text before tool calls in transcript for accurate chat replay and KV cache consistency - Move final text accumulation to after tool loop exit so only the final response is returned Fixes #112 Co-Authored-By: Claude Opus 4.6 --- .../Models/MLXLanguageModel.swift | 69 +++++++++++++++---- 1 file changed, 56 insertions(+), 13 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 107f8e84..592c58b5 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -479,8 +479,18 @@ import Foundation // on subsequent iterations (tool results added) we must rebuild. var isFirstIteration = true + // Guard against infinite tool-call loops (e.g. model keeps retrying the + // same tool call). After this many iterations, break and return whatever + // text has been accumulated. + let maxToolIterations = 5 + var toolIteration = 0 + // Loop until no more tool calls while true { + toolIteration += 1 + if toolIteration > maxToolIterations { + break + } // Build user input with current chat history and tools let userInput = MLXLMCommon.UserInput( chat: chat, @@ -552,7 +562,6 @@ import Foundation setSessionCache(cacheEntry, for: session) let assistantText = chunks.joined() - allTextChunks.append(assistantText) // Add assistant response to chat history if !assistantText.isEmpty { @@ -561,6 +570,17 @@ import Foundation // If there are tool calls, execute them and continue if !collectedToolCalls.isEmpty { + // Record the assistant text generated before the tool call + // as a transcript entry so convertTranscriptToMLXChat() can + // reproduce the exact same chat sequence on future turns + // (keeping the KV cache valid). + if !assistantText.isEmpty { + allEntries.append(.response(Transcript.Response( + assetIDs: [], + segments: [.text(.init(content: assistantText))] + ))) + } + let resolution = try await resolveToolCalls(collectedToolCalls, session: session) switch resolution { case .stop(let calls): @@ -576,13 +596,20 @@ import Foundation if !invocations.isEmpty { allEntries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) - // Execute each tool and add results to chat + // Execute each tool and fold results into the + // preceding assistant message to maintain strict + // user/assistant alternation for templates like Gemma 3. + var toolResults: [String] = [] for invocation in invocations { allEntries.append(.toolOutput(invocation.output)) - - // Convert tool output to JSON string for MLX - let toolResultJSON = toolOutputToJSON(invocation.output) - chat.append(.tool(toolResultJSON)) + toolResults.append(toolOutputToJSON(invocation.output)) + } + let combinedResults = toolResults.joined(separator: "\n") + if let lastIdx = chat.indices.last, chat[lastIdx].role == .assistant { + let existing = chat[lastIdx].content + chat[lastIdx] = .assistant(existing + "\n\n[Tool result]: " + combinedResults) + } else { + chat.append(.assistant("[Tool result]: " + combinedResults)) } // Continue loop to generate with tool results @@ -591,11 +618,13 @@ import Foundation } } - // No more tool calls, exit loop + // No more tool calls — this is the final response text + allTextChunks.append(assistantText) break } let text = allTextChunks.joined() + return LanguageModelSession.Response( content: text as! Content, rawContent: GeneratedContent(text), @@ -812,7 +841,10 @@ import Foundation chat.append(.init(role: .system, content: instructions)) } - // Convert each transcript entry + // Convert each transcript entry. + // Tool call/output entries are folded into adjacent assistant messages + // to maintain strict user/assistant alternation required by some templates + // (e.g. Gemma 3 which has no tool role support in its Jinja template). for entry in session.transcript { switch entry { case .instructions(let instr): @@ -826,12 +858,23 @@ import Foundation chat.append(.assistant(content)) case .toolCalls: - // Tool calls are handled inline during generation loop + // Skip — tool calls were already executed; the output is captured below break case .toolOutput(let toolOutput): - let content = toolOutput.segments.map { extractText(from: $0) }.joined(separator: "\n") - chat.append(.tool(content)) + // Fold tool output into the preceding assistant message to avoid + // injecting a .tool() role that breaks strict-alternation templates. + // Use toolOutputToJSON() — the same function used in the live tool + // loop — so the replayed chat produces identical tokens, keeping + // the KV cache valid for session continuation. + let content = toolOutputToJSON(toolOutput) + if let lastIdx = chat.indices.last, chat[lastIdx].role == .assistant { + let existing = chat[lastIdx].content + chat[lastIdx] = .assistant(existing + "\n\n[Tool result]: " + content) + } else { + // No preceding assistant message — wrap as assistant + chat.append(.assistant("[Tool result]: " + content)) + } } } @@ -947,7 +990,7 @@ import Foundation } private func convertToSendableJSONValue(_ value: Any) throws -> any Sendable { - if value is NSNull { return MLXLMCommon.JSONValue.null } + if value is NSNull { return NSNull() } if let stringValue = value as? String { return stringValue } if let boolValue = value as? Bool { return boolValue } if let intValue = value as? Int { return intValue } @@ -956,7 +999,7 @@ import Foundation return numberValue.doubleValue } if let arrayValue = value as? [Any] { - return try arrayValue.map { try convertToSendableJSONValue($0) } + return try arrayValue.map { try convertToSendableJSONValue($0) } as [any Sendable] } if let dictionaryValue = value as? [String: Any] { return try convertToSendableJSONObject(dictionaryValue) From ec343604f95a5fdf884381db5ba0e1d79dabec84 Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Wed, 25 Feb 2026 12:05:29 -0600 Subject: [PATCH 4/7] Address PR review: cache hash validation, dedup, GPU config safety, tool-aware prewarm - Add prefillTokenHash to SessionCacheEntry to detect stale cache from replaced conversations (not just token count) - Extract resolveCache() helper to deduplicate cache hit/miss logic between respond() and streamResponse() - GPUMemoryManager.configure() now uses first-write-wins to prevent multiple MLXLanguageModel instances from silently overwriting config - prewarm() accepts tools via protocol and session automatically forwards registered tools so prefill tokenization matches respond() Co-Authored-By: Claude Opus 4.6 --- Sources/AnyLanguageModel/LanguageModel.swift | 6 +- .../LanguageModelSession.swift | 2 +- .../Models/MLXLanguageModel.swift | 154 +++++++++++------- 3 files changed, 104 insertions(+), 58 deletions(-) diff --git a/Sources/AnyLanguageModel/LanguageModel.swift b/Sources/AnyLanguageModel/LanguageModel.swift index 635de68c..2d801da8 100644 --- a/Sources/AnyLanguageModel/LanguageModel.swift +++ b/Sources/AnyLanguageModel/LanguageModel.swift @@ -14,7 +14,8 @@ public protocol LanguageModel: Sendable { func prewarm( for session: LanguageModelSession, - promptPrefix: Prompt? + promptPrefix: Prompt?, + tools: [any Tool]? ) func respond( @@ -54,7 +55,8 @@ extension LanguageModel { public func prewarm( for session: LanguageModelSession, - promptPrefix: Prompt? = nil + promptPrefix: Prompt? = nil, + tools: [any Tool]? = nil ) { return } diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index ba38550e..d3ea39b3 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -98,7 +98,7 @@ public final class LanguageModelSession: @unchecked Sendable { } public func prewarm(promptPrefix: Prompt? = nil) { - model.prewarm(for: self, promptPrefix: promptPrefix) + model.prewarm(for: self, promptPrefix: promptPrefix, tools: tools.isEmpty ? nil : tools) } nonisolated private func beginResponding() { diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 107f8e84..e843c132 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -105,13 +105,21 @@ import Foundation private let lock = NSLock() private var activeCount = 0 private var config: GPUMemoryConfiguration = .automatic + private var hasCustomConfig = false private init() { Memory.cacheLimit = config.idleCacheLimit } + /// Applies a GPU memory configuration. First custom configuration wins — + /// subsequent calls with a different configuration are ignored to prevent + /// multiple MLXLanguageModel instances from silently overwriting each other. func configure(_ configuration: GPUMemoryConfiguration) { lock.withLock { + if hasCustomConfig && config != configuration { + return + } + hasCustomConfig = true config = configuration if activeCount == 0 { Memory.cacheLimit = configuration.idleCacheLimit @@ -307,10 +315,12 @@ import Foundation private final class SessionCacheEntry: NSObject, @unchecked Sendable { var kvCache: [MLXLMCommon.KVCache] var prefillTokenCount: Int + var prefillTokenHash: Int - init(kvCache: [MLXLMCommon.KVCache], prefillTokenCount: Int) { + init(kvCache: [MLXLMCommon.KVCache], prefillTokenCount: Int, prefillTokenHash: Int) { self.kvCache = kvCache self.prefillTokenCount = prefillTokenCount + self.prefillTokenHash = prefillTokenHash } } @@ -337,6 +347,48 @@ import Foundation sessionKVCache.removeObject(forKey: session) } + /// Hashes up to the first `count` tokens of an MLXArray for cache identity checks. + private func hashTokenPrefix(_ tokens: MLXArray, count: Int = 64) -> Int { + let tokenCount = tokens.dim(0) + let n = min(count, tokenCount) + guard n > 0 else { return 0 } + let prefix = tokens[0.. (cache: [MLXLMCommon.KVCache], input: MLXLMCommon.LMInput) { + let existingEntry = getSessionCache(session) + let fullTokenCount = lmInput.text.tokens.dim(0) + let currentHash = hashTokenPrefix(lmInput.text.tokens) + + if let existingEntry, + existingEntry.prefillTokenCount > 0, + fullTokenCount > existingEntry.prefillTokenCount, + existingEntry.prefillTokenHash == currentHash, + lmInput.image == nil + { + let cachedCount = existingEntry.prefillTokenCount + let newTokens = lmInput.text.tokens[cachedCount...] + let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) + return (cache: existingEntry.kvCache, input: MLXLMCommon.LMInput(text: partialText)) + } + + if existingEntry != nil { + removeSessionCache(for: session) + } + let freshCache = context.model.newCache(parameters: generateParameters) + return (cache: freshCache, input: lmInput) + } + // MARK: - MLXLanguageModel /// A language model that runs locally using MLX. @@ -494,28 +546,14 @@ import Foundation let inputForGeneration: MLXLMCommon.LMInput if isFirstIteration { - let existingEntry = getSessionCache(session) - let fullTokenCount = lmInput.text.tokens.dim(0) - - if let existingEntry, - existingEntry.prefillTokenCount > 0, - fullTokenCount > existingEntry.prefillTokenCount, - lmInput.image == nil - { - // Cache HIT: only prefill new tokens - let cachedCount = existingEntry.prefillTokenCount - let newTokens = lmInput.text.tokens[cachedCount...] - let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) - inputForGeneration = MLXLMCommon.LMInput(text: partialText) - cache = existingEntry.kvCache - } else { - // Cache MISS: create fresh cache - if existingEntry != nil { - removeSessionCache(for: session) - } - cache = context.model.newCache(parameters: generateParameters) - inputForGeneration = lmInput - } + let resolved = resolveCache( + for: session, + lmInput: lmInput, + generateParameters: generateParameters, + context: context + ) + cache = resolved.cache + inputForGeneration = resolved.input } else { // Tool-calling iterations: fresh cache (chat has been mutated) cache = context.model.newCache(parameters: generateParameters) @@ -548,7 +586,11 @@ import Foundation // Update session cache with current offset after generation let currentOffset = cache.first?.offset ?? 0 - let cacheEntry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) + let cacheEntry = SessionCacheEntry( + kvCache: cache, + prefillTokenCount: currentOffset, + prefillTokenHash: hashTokenPrefix(lmInput.text.tokens) + ) setSessionCache(cacheEntry, for: session) let assistantText = chunks.joined() @@ -639,33 +681,15 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) - // Check for existing KV cache for this session - let existingEntry = getSessionCache(session) - let cache: [MLXLMCommon.KVCache] - let inputForGeneration: MLXLMCommon.LMInput - - let fullTokenCount = lmInput.text.tokens.dim(0) - - if let existingEntry, - existingEntry.prefillTokenCount > 0, - fullTokenCount > existingEntry.prefillTokenCount, - lmInput.image == nil - { - // Cache HIT: only prefill new tokens - let cachedCount = existingEntry.prefillTokenCount - let newTokens = lmInput.text.tokens[cachedCount...] - let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) - inputForGeneration = MLXLMCommon.LMInput(text: partialText) - cache = existingEntry.kvCache - } else { - // Cache MISS: create fresh cache, prefill everything - if existingEntry != nil { - removeSessionCache(for: session) - } - let newCache = context.model.newCache(parameters: generateParameters) - cache = newCache - inputForGeneration = lmInput - } + // Resolve KV cache for this session + let resolved = resolveCache( + for: session, + lmInput: lmInput, + generateParameters: generateParameters, + context: context + ) + let cache = resolved.cache + let inputForGeneration = resolved.input let mlxStream = try MLXLMCommon.generate( input: inputForGeneration, @@ -692,7 +716,11 @@ import Foundation // Update the session cache with current offset let currentOffset = cache.first?.offset ?? 0 - let entry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) + let entry = SessionCacheEntry( + kvCache: cache, + prefillTokenCount: currentOffset, + prefillTokenHash: hashTokenPrefix(lmInput.text.tokens) + ) setSessionCache(entry, for: session) didMarkIdle.withLock { done in @@ -718,9 +746,17 @@ import Foundation } /// Prewarms the model by loading it and optionally prefilling the system prompt into a KV cache. + /// + /// - Parameters: + /// - session: The session whose instructions will be prefilled. + /// - promptPrefix: An optional prompt prefix (reserved for future use). + /// - tools: Tools that will be used with this session. Pass the same tools here + /// so the prefilled cache includes tool definitions in its tokenization, + /// avoiding a cache miss on the first real request. public func prewarm( for session: LanguageModelSession, - promptPrefix: Prompt? + promptPrefix: Prompt?, + tools: [any Tool]? = nil ) { let modelId = self.modelId let hub = self.hub @@ -738,17 +774,25 @@ import Foundation let params = MLXLMCommon.GenerateParameters() let newCache = context.model.newCache(parameters: params) let chat: [MLXLMCommon.Chat.Message] = [.init(role: .system, content: instructions)] + + // Convert tools to MLX ToolSpec format so the prefill tokenization + // matches what respond() will produce, ensuring cache hits. + let toolSpecs: [ToolSpec]? = tools.flatMap { toolList in + toolList.isEmpty ? nil : toolList.map { convertToolToMLXSpec($0) } + } + let userInput = MLXLMCommon.UserInput( chat: chat, processing: .init(resize: .init(width: 512, height: 512)), - tools: nil + tools: toolSpecs ) let lmInput = try await context.processor.prepare(input: userInput) _ = try context.model.prepare(lmInput, cache: newCache, windowSize: nil) let entry = SessionCacheEntry( kvCache: newCache, - prefillTokenCount: newCache.first?.offset ?? 0 + prefillTokenCount: newCache.first?.offset ?? 0, + prefillTokenHash: hashTokenPrefix(lmInput.text.tokens) ) setSessionCache(entry, for: session) } From 4ea4fdbc5531c128064a97b89753fd0267bda690 Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Wed, 25 Feb 2026 13:39:42 -0600 Subject: [PATCH 5/7] Fix duplicate tool call loops and clear KV caches on eviction - Detect when the MLX tool loop generates the same tool call signature as the previous iteration and break early instead of retrying - Clear sessionKVCache in removeAllFromCache() so memory warning handlers actually free GPU memory from cached KV states Co-Authored-By: Claude Opus 4.6 --- .../Models/MLXLanguageModel.swift | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index e843c132..9d1f02e3 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -462,6 +462,9 @@ import Foundation /// Removes all MLX models from the shared cache and cancels in-flight loads. public static func removeAllFromCache() async { await modelCache.removeAllAndCancel() + sessionKVCacheLock.withLock { + sessionKVCache.removeAllObjects() + } GPUMemoryManager.shared.evict() } @@ -531,6 +534,13 @@ import Foundation // on subsequent iterations (tool results added) we must rebuild. var isFirstIteration = true + // Guard against infinite tool-call loops (e.g. model keeps retrying the + // same tool call). After this many iterations, break and return whatever + // text has been accumulated. + let maxToolIterations = 5 + var toolIteration = 0 + var previousToolCallSignature: String? + // Loop until no more tool calls while true { // Build user input with current chat history and tools @@ -603,6 +613,29 @@ import Foundation // If there are tool calls, execute them and continue if !collectedToolCalls.isEmpty { + // Detect repeated tool calls — if the model generates the exact + // same tool call(s) as the previous iteration, it's stuck in a + // loop. Break and return whatever text we have so far. + let signature = collectedToolCalls + .map { "\($0.function.name):\($0.function.arguments)" } + .joined(separator: "|") + if signature == previousToolCallSignature { + allTextChunks.append(assistantText) + break + } + previousToolCallSignature = signature + + // Record the assistant text generated before the tool call + // as a transcript entry so convertTranscriptToMLXChat() can + // reproduce the exact same chat sequence on future turns + // (keeping the KV cache valid). + if !assistantText.isEmpty { + allEntries.append(.response(Transcript.Response( + assetIDs: [], + segments: [.text(.init(content: assistantText))] + ))) + } + let resolution = try await resolveToolCalls(collectedToolCalls, session: session) switch resolution { case .stop(let calls): From 4dc86c7b6fa7308b9f2d888f2157daa77930a3c9 Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Thu, 26 Feb 2026 12:08:06 -0600 Subject: [PATCH 6/7] Add session transcript swap API for memory-efficient session reuse MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds snapshotTranscript() and replaceTranscript() to LanguageModelSession, enabling consumers to swap a session's context without creating a second session. This is critical for MLX where each session allocates a KV cache in app memory — swapping transcripts keeps peak memory at one cache. - LanguageModel protocol: add invalidateCache(for:) with default no-op - LanguageModelSession: add snapshotTranscript() and replaceTranscript() with atomic check-and-mutate to prevent TOCTOU races - MLXLanguageModel: override invalidateCache to evict KV cache entry Co-Authored-By: Claude Opus 4.6 --- Sources/AnyLanguageModel/LanguageModel.swift | 9 ++++ .../LanguageModelSession.swift | 52 +++++++++++++++++++ .../Models/MLXLanguageModel.swift | 4 ++ 3 files changed, 65 insertions(+) diff --git a/Sources/AnyLanguageModel/LanguageModel.swift b/Sources/AnyLanguageModel/LanguageModel.swift index 2d801da8..9e96e965 100644 --- a/Sources/AnyLanguageModel/LanguageModel.swift +++ b/Sources/AnyLanguageModel/LanguageModel.swift @@ -40,6 +40,13 @@ public protocol LanguageModel: Sendable { issues: [LanguageModelFeedback.Issue], desiredOutput: Transcript.Entry? ) -> Data + + /// Invalidates any cached state associated with the given session. + /// + /// Called when a session's transcript is replaced. Models that maintain + /// per-session caches (such as MLX KV caches) should evict the entry + /// for the given session. The default implementation does nothing. + func invalidateCache(for session: LanguageModelSession) } // MARK: - Default Implementation @@ -69,6 +76,8 @@ extension LanguageModel { ) -> Data { return Data() } + + public func invalidateCache(for session: LanguageModelSession) {} } extension LanguageModel where UnavailableReason == Never { diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index d3ea39b3..35b26a65 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -914,6 +914,58 @@ private enum ResponseStreamError: Error { case noSnapshots } +// MARK: - Transcript Snapshot and Replacement + +extension LanguageModelSession { + /// An error that occurs when attempting to replace a transcript. + public enum TranscriptError: Error { + /// The session is currently generating a response. + case respondingInProgress + } + + /// Captures the current transcript for later restoration. + /// + /// It is safe to call this while the session is responding, but the + /// returned transcript may include a prompt entry without a corresponding + /// response. For deterministic snapshots, call this when ``isResponding`` + /// is `false`. + nonisolated public func snapshotTranscript() -> Transcript { + state.withLock { $0.transcript } + } + + /// Replaces the session's transcript and invalidates any associated + /// backend cache (for example, the MLX KV cache). + /// + /// Use this to temporarily swap a session's context for a different + /// task (such as running a tool agent), then restore the original + /// transcript afterward. The next call to ``respond(to:options:)`` + /// will rebuild any caches from the new transcript. + /// + /// - Parameter transcript: The transcript to install. + /// - Throws: ``TranscriptError/respondingInProgress`` if the session + /// is currently generating a response. + nonisolated public func replaceTranscript(_ transcript: Transcript) throws { + // Atomic check-and-mutate: verify not responding and swap transcript + // in a single lock acquisition to prevent TOCTOU races. + let didReplace: Bool = state.withLock { lockedState in + guard !lockedState.isResponding else { return false } + lockedState.transcript = transcript + return true + } + + guard didReplace else { + throw TranscriptError.respondingInProgress + } + + // Notify observation after releasing the lock + withMutation(keyPath: \.transcript) {} + + // Invalidate backend caches (e.g. MLX KV cache) outside the lock + // to avoid lock inversion if the model calls back into the session. + model.invalidateCache(for: self) + } +} + // MARK: - private struct State: Equatable, Sendable { diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 26ccad88..3f76e964 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -468,6 +468,10 @@ import Foundation GPUMemoryManager.shared.evict() } + public func invalidateCache(for session: LanguageModelSession) { + removeSessionCache(for: session) + } + /// Get or load model context with caching private func loadContext(modelId: String, hub: HubApi?, directory: URL?) async throws -> ModelContext { let key = directory?.absoluteString ?? modelId From 4ad8b3684caf0b627972af2d22edbf2767d4c292 Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Fri, 27 Feb 2026 13:42:26 -0600 Subject: [PATCH 7/7] Add DownloadableLanguageModel protocol and MLXModelDownloadManager Expose proactive download management for MLX models: - DownloadProgress struct, ModelDownloadState enum, DownloadableLanguageModel protocol - MLXModelDownloadManager (@Observable) with disk-state scanning, AsyncStream downloads, and delete support - MLXLanguageModel conforms to DownloadableLanguageModel - Wire progress reporting into loadContext() so lazy downloads also update state - Re-export Hub types for consumer access Co-Authored-By: Claude Opus 4.6 --- Package.resolved | 74 +----- .../DownloadableLanguageModel.swift | 93 +++++++ .../Models/MLXLanguageModel.swift | 235 +++++++++++++++++- 3 files changed, 328 insertions(+), 74 deletions(-) create mode 100644 Sources/AnyLanguageModel/DownloadableLanguageModel.swift diff --git a/Package.resolved b/Package.resolved index 1495dc88..48877163 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "f7b86b800200fa069a2b288e06bafe53bc937a1851b6effeebba326a62be227e", + "originHash" : "a10eccc84124921a1b0c29a27934296127979574607ae6358a0605b0a15081ca", "pins" : [ { "identity" : "eventsource", @@ -19,24 +19,6 @@ "version" : "1.3.1" } }, - { - "identity" : "mlx-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/ml-explore/mlx-swift", - "state" : { - "revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d", - "version" : "0.30.6" - } - }, - { - "identity" : "mlx-swift-lm", - "kind" : "remoteSourceControl", - "location" : "https://github.com/ml-explore/mlx-swift-lm", - "state" : { - "revision" : "7e19e09027923d89ac47dd087d9627f610e5a91a", - "version" : "2.30.6" - } - }, { "identity" : "partialjsondecoder", "kind" : "remoteSourceControl", @@ -46,15 +28,6 @@ "version" : "1.0.0" } }, - { - "identity" : "swift-asn1", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-asn1.git", - "state" : { - "revision" : "810496cf121e525d660cd0ea89a758740476b85f", - "version" : "1.5.1" - } - }, { "identity" : "swift-collections", "kind" : "remoteSourceControl", @@ -64,33 +37,6 @@ "version" : "1.3.0" } }, - { - "identity" : "swift-crypto", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-crypto.git", - "state" : { - "revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095", - "version" : "4.2.0" - } - }, - { - "identity" : "swift-jinja", - "kind" : "remoteSourceControl", - "location" : "https://github.com/huggingface/swift-jinja.git", - "state" : { - "revision" : "f731f03bf746481d4fda07f817c3774390c4d5b9", - "version" : "2.3.2" - } - }, - { - "identity" : "swift-numerics", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-numerics", - "state" : { - "revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2", - "version" : "1.1.1" - } - }, { "identity" : "swift-syntax", "kind" : "remoteSourceControl", @@ -99,24 +45,6 @@ "revision" : "0687f71944021d616d34d922343dcef086855920", "version" : "600.0.1" } - }, - { - "identity" : "swift-transformers", - "kind" : "remoteSourceControl", - "location" : "https://github.com/huggingface/swift-transformers", - "state" : { - "revision" : "3aecdf18e62303fb5a5543f8e87502b13474573f", - "version" : "1.1.7" - } - }, - { - "identity" : "yyjson", - "kind" : "remoteSourceControl", - "location" : "https://github.com/ibireme/yyjson.git", - "state" : { - "revision" : "8b4a38dc994a110abaec8a400615567bd996105f", - "version" : "0.12.0" - } } ], "version" : 3 diff --git a/Sources/AnyLanguageModel/DownloadableLanguageModel.swift b/Sources/AnyLanguageModel/DownloadableLanguageModel.swift new file mode 100644 index 00000000..450476b7 --- /dev/null +++ b/Sources/AnyLanguageModel/DownloadableLanguageModel.swift @@ -0,0 +1,93 @@ +#if MLX + import Foundation + + // MARK: - Download Progress + + /// Reports progress of a model download. + public struct DownloadProgress: Sendable, Equatable { + /// Fraction of the download completed, from 0.0 to 1.0. + public var fractionCompleted: Double + + /// Number of bytes downloaded so far, if known. + public var completedBytes: Int64? + + /// Total expected bytes, if known. + public var totalBytes: Int64? + + /// Current download speed in bytes per second, if known. + public var bytesPerSecond: Double? + + public init( + fractionCompleted: Double, + completedBytes: Int64? = nil, + totalBytes: Int64? = nil, + bytesPerSecond: Double? = nil + ) { + self.fractionCompleted = fractionCompleted + self.completedBytes = completedBytes + self.totalBytes = totalBytes + self.bytesPerSecond = bytesPerSecond + } + } + + // MARK: - Download State + + /// The on-disk state of a downloadable model. + public enum ModelDownloadState: Sendable, Equatable { + /// The model files are not present on disk. + case notDownloaded + /// The model is currently being downloaded. + case downloading(DownloadProgress) + /// The model is fully downloaded and ready to load. + case downloaded + + public static func == (lhs: ModelDownloadState, rhs: ModelDownloadState) -> Bool { + switch (lhs, rhs) { + case (.notDownloaded, .notDownloaded): + return true + case (.downloaded, .downloaded): + return true + case (.downloading(let a), .downloading(let b)): + return a == b + default: + return false + } + } + } + + // MARK: - DownloadableLanguageModel Protocol + + /// A language model whose weights can be downloaded, inspected, and deleted. + /// + /// Backends that run locally (e.g. MLX) conform to this protocol to expose + /// download management. API-only backends (OpenAI, Anthropic, etc.) do not + /// need to conform — consumers can check conformance with: + /// + /// ```swift + /// if let downloadable = model as? any DownloadableLanguageModel { ... } + /// ``` + public protocol DownloadableLanguageModel: LanguageModel { + /// Whether the model's files are fully present on disk. + var isDownloaded: Bool { get } + + /// The current download state of this model. + var downloadState: ModelDownloadState { get } + + /// Starts downloading the model and returns a stream of progress updates. + /// + /// If the model is already downloaded, the stream completes immediately. + /// Cancelling the consuming `Task` cancels the download. + func download() -> AsyncStream + + /// Removes the downloaded model files from disk. + /// + /// Also cancels any in-flight download and removes the model from the + /// in-memory cache. + func deleteDownload() async throws + + /// The total size of the downloaded model on disk, in bytes. + /// + /// Returns `nil` if the model is not downloaded or the size cannot be determined. + var downloadedSizeOnDisk: Int64? { get } + } +#endif diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 3f76e964..6a052573 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -17,6 +17,7 @@ import Foundation import MLXVLM import Tokenizers import Hub + import Observation // MARK: - GPU Memory Configuration @@ -481,7 +482,18 @@ import Foundation return try await loadModel(directory: directory) } - return try await loadModel(hub: hub ?? HubApi(), id: modelId) + let effectiveHub = hub ?? HubApi() + let context = try await loadModel(hub: effectiveHub, id: modelId) { progress in + let dp = DownloadProgress( + fractionCompleted: progress.fractionCompleted, + completedBytes: progress.completedUnitCount, + totalBytes: progress.totalUnitCount > 0 ? progress.totalUnitCount : nil, + bytesPerSecond: progress.userInfo[.throughputKey] as? Double + ) + MLXModelDownloadManager.shared.updateState(for: modelId, to: .downloading(dp)) + } + MLXModelDownloadManager.shared.updateState(for: modelId, to: .downloaded) + return context } } @@ -1520,4 +1532,225 @@ import Foundation return sampledToken.item(Int.self) } } + + // MARK: - MLXModelDownloadManager + + /// Manages download state for MLX models. + /// + /// `@Observable` class (not actor) for direct SwiftUI binding. Thread safety + /// is provided by `Locked<>` for the state dictionary. + @Observable + public final class MLXModelDownloadManager: @unchecked Sendable { + public static let shared = MLXModelDownloadManager() + + /// Current download state per model ID. + public private(set) var states: [String: ModelDownloadState] = [:] + + /// In-flight download tasks, keyed by model ID. + private let inFlightDownloads = Locked<[String: Task]>([:]) + + /// Lock for thread-safe state mutations that trigger @Observable updates. + private let stateLock = NSLock() + + private init() {} + + // MARK: - State Queries + + /// Returns the current download state for a model. + /// + /// On first access for a given model, checks disk for existing files. + public func state(for modelId: String, hub: HubApi = HubApi()) -> ModelDownloadState { + stateLock.lock() + if let cached = states[modelId] { + stateLock.unlock() + return cached + } + stateLock.unlock() + + // Check disk + let repo = Hub.Repo(id: modelId) + let repoDir = hub.localRepoLocation(repo) + let configPath = repoDir.appendingPathComponent("config.json") + + let isOnDisk = FileManager.default.fileExists(atPath: configPath.path) + let state: ModelDownloadState = isOnDisk ? .downloaded : .notDownloaded + + stateLock.lock() + // Don't overwrite an in-progress download + if states[modelId] == nil { + withMutation(keyPath: \.states) { + states[modelId] = state + } + } + stateLock.unlock() + + return states[modelId] ?? state + } + + /// Updates the download state for a model. Called from download handlers. + public func updateState(for modelId: String, to newState: ModelDownloadState) { + stateLock.lock() + withMutation(keyPath: \.states) { + states[modelId] = newState + } + stateLock.unlock() + } + + /// Clears cached state for a model, forcing a fresh disk check on next access. + public func invalidateState(for modelId: String) { + stateLock.lock() + withMutation(keyPath: \.states) { + states[modelId] = nil + } + stateLock.unlock() + } + + // MARK: - Download + + /// Downloads a model and returns a stream of progress updates. + /// + /// If the model is already downloaded, the stream completes immediately. + /// Cancelling the consuming `Task` cancels the download (snapshot checks + /// `Task.isCancelled` between files). + public func download(modelId: String, hub: HubApi = HubApi()) -> AsyncStream { + let currentState = state(for: modelId, hub: hub) + if case .downloaded = currentState { + return AsyncStream { $0.finish() } + } + + return AsyncStream { continuation in + let task = Task { [weak self] in + guard let self else { + continuation.finish() + return + } + do { + let repo = Hub.Repo(id: modelId) + try await hub.snapshot( + from: repo, + matching: ["*.safetensors", "*.json", "tokenizer.model"] + ) { progress, speed in + let dp = DownloadProgress( + fractionCompleted: progress.fractionCompleted, + completedBytes: progress.completedUnitCount, + totalBytes: progress.totalUnitCount > 0 ? progress.totalUnitCount : nil, + bytesPerSecond: speed + ) + self.updateState(for: modelId, to: .downloading(dp)) + continuation.yield(dp) + } + self.updateState(for: modelId, to: .downloaded) + continuation.finish() + } catch { + self.updateState(for: modelId, to: .notDownloaded) + continuation.finish() + } + self.inFlightDownloads.withLock { $0[modelId] = nil } + } + + inFlightDownloads.withLock { $0[modelId] = task } + + continuation.onTermination = { [weak self] _ in + self?.inFlightDownloads.withLock { downloads in + downloads[modelId]?.cancel() + downloads[modelId] = nil + } + } + } + } + + // MARK: - Delete + + /// Removes downloaded model files from disk and resets state. + /// + /// Also cancels any in-flight download and evicts the model from + /// the in-memory model cache. + public func deleteDownload(modelId: String, hub: HubApi = HubApi()) async throws { + // Cancel in-flight download + inFlightDownloads.withLock { downloads in + downloads[modelId]?.cancel() + downloads[modelId] = nil + } + + // Remove from in-memory model cache + let cacheModel = MLXLanguageModel(modelId: modelId, hub: hub) + await cacheModel.removeFromCache() + + // Delete files on disk + let repo = Hub.Repo(id: modelId) + let repoDir = hub.localRepoLocation(repo) + if FileManager.default.fileExists(atPath: repoDir.path) { + try FileManager.default.removeItem(at: repoDir) + } + + updateState(for: modelId, to: .notDownloaded) + } + + // MARK: - Disk Size + + /// Returns the total size of the downloaded model on disk, in bytes. + public func downloadedSize(for modelId: String, hub: HubApi = HubApi()) -> Int64? { + let repo = Hub.Repo(id: modelId) + let repoDir = hub.localRepoLocation(repo) + guard FileManager.default.fileExists(atPath: repoDir.path) else { return nil } + return Self.directorySize(at: repoDir) + } + + static func directorySize(at url: URL) -> Int64 { + let fm = FileManager.default + guard let enumerator = fm.enumerator( + at: url, + includingPropertiesForKeys: [.fileSizeKey, .isDirectoryKey], + options: [.skipsHiddenFiles] + ) else { return 0 } + + var totalSize: Int64 = 0 + for case let fileURL as URL in enumerator { + guard let values = try? fileURL.resourceValues(forKeys: [.fileSizeKey, .isDirectoryKey]), + values.isDirectory != true, + let size = values.fileSize + else { continue } + totalSize += Int64(size) + } + return totalSize + } + } + + // MARK: - DownloadableLanguageModel Conformance + + extension MLXLanguageModel: DownloadableLanguageModel { + public var isDownloaded: Bool { + if directory != nil { return true } + return MLXModelDownloadManager.shared.state(for: modelId, hub: hub ?? HubApi()) == .downloaded + } + + public var downloadState: ModelDownloadState { + if directory != nil { return .downloaded } + return MLXModelDownloadManager.shared.state(for: modelId, hub: hub ?? HubApi()) + } + + public func download() -> AsyncStream { + if directory != nil { + return AsyncStream { $0.finish() } + } + return MLXModelDownloadManager.shared.download(modelId: modelId, hub: hub ?? HubApi()) + } + + public func deleteDownload() async throws { + if directory != nil { return } + try await MLXModelDownloadManager.shared.deleteDownload(modelId: modelId, hub: hub ?? HubApi()) + } + + public var downloadedSizeOnDisk: Int64? { + if directory != nil { return MLXModelDownloadManager.directorySize(at: directory!) } + return MLXModelDownloadManager.shared.downloadedSize(for: modelId, hub: hub ?? HubApi()) + } + } + + // MARK: - Hub Re-exports + + /// Re-export Hub types so consumers can configure custom HubApi instances + /// (download location, HF token) without importing Hub directly. + public typealias HubRepo = Hub.Repo + public typealias HubRepoType = Hub.RepoType #endif // MLX