diff --git a/Package.resolved b/Package.resolved index 837d776..4887716 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,13 +1,13 @@ { - "originHash" : "f7b86b800200fa069a2b288e06bafe53bc937a1851b6effeebba326a62be227e", + "originHash" : "a10eccc84124921a1b0c29a27934296127979574607ae6358a0605b0a15081ca", "pins" : [ { "identity" : "eventsource", "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/EventSource.git", + "location" : "https://github.com/mattt/EventSource", "state" : { - "revision" : "ca2a9d90cbe49e09b92f4b6ebd922c03ebea51d0", - "version" : "1.3.0" + "revision" : "bd64824505da71a1a403adb221f6e25413c0bc7f", + "version" : "1.4.0" } }, { @@ -19,33 +19,6 @@ "version" : "1.3.1" } }, - { - "identity" : "llama.swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/llama.swift", - "state" : { - "revision" : "4d57cff84ba85914baa39850157e7c27684db9c8", - "version" : "2.7966.0" - } - }, - { - "identity" : "mlx-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/ml-explore/mlx-swift", - "state" : { - "revision" : "072b684acaae80b6a463abab3a103732f33774bf", - "version" : "0.29.1" - } - }, - { - "identity" : "mlx-swift-lm", - "kind" : "remoteSourceControl", - "location" : "https://github.com/ml-explore/mlx-swift-lm", - "state" : { - "revision" : "5064b8c5d8ed3b0bbb71385c4124f0fc102e74a2", - "version" : "2.29.3" - } - }, { "identity" : "partialjsondecoder", "kind" : "remoteSourceControl", @@ -64,24 +37,6 @@ "version" : "1.3.0" } }, - { - "identity" : "swift-jinja", - "kind" : "remoteSourceControl", - "location" : "https://github.com/huggingface/swift-jinja.git", - "state" : { - "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0", - "version" : "2.3.1" - } - }, - { - "identity" : "swift-numerics", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-numerics", - "state" : { - "revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2", - "version" : "1.1.1" - } 
- }, { "identity" : "swift-syntax", "kind" : "remoteSourceControl", @@ -90,15 +45,6 @@ "revision" : "0687f71944021d616d34d922343dcef086855920", "version" : "600.0.1" } - }, - { - "identity" : "swift-transformers", - "kind" : "remoteSourceControl", - "location" : "https://github.com/huggingface/swift-transformers", - "state" : { - "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0", - "version" : "1.1.6" - } } ], "version" : 3 diff --git a/Package.swift b/Package.swift index 3916bf0..d3a6893 100644 --- a/Package.swift +++ b/Package.swift @@ -33,8 +33,9 @@ let package = Package( .package(url: "https://github.com/mattt/JSONSchema", from: "1.3.0"), .package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")), .package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"), - // mlx-swift-lm must be >= 2.25.5 for ToolSpec/tool calls and UserInput(chat:processing:tools:). - .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.25.5"), + // mlx-swift-lm >= 2.30.3 for fast SDPA, Gemma3n per-layer intermediate_size, + // cache race fix, Memory API, and chat rehydration. >= 2.25.5 for ToolSpec/tool calls. + .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.30.3"), .package(url: "https://github.com/swiftlang/swift-syntax", from: "600.0.0"), ], targets: [ diff --git a/Sources/AnyLanguageModel/DownloadableLanguageModel.swift b/Sources/AnyLanguageModel/DownloadableLanguageModel.swift new file mode 100644 index 0000000..450476b --- /dev/null +++ b/Sources/AnyLanguageModel/DownloadableLanguageModel.swift @@ -0,0 +1,93 @@ +#if MLX + import Foundation + + // MARK: - Download Progress + + /// Reports progress of a model download. + public struct DownloadProgress: Sendable, Equatable { + /// Fraction of the download completed, from 0.0 to 1.0. + public var fractionCompleted: Double + + /// Number of bytes downloaded so far, if known. + public var completedBytes: Int64? + + /// Total expected bytes, if known. 
+ public var totalBytes: Int64? + + /// Current download speed in bytes per second, if known. + public var bytesPerSecond: Double? + + public init( + fractionCompleted: Double, + completedBytes: Int64? = nil, + totalBytes: Int64? = nil, + bytesPerSecond: Double? = nil + ) { + self.fractionCompleted = fractionCompleted + self.completedBytes = completedBytes + self.totalBytes = totalBytes + self.bytesPerSecond = bytesPerSecond + } + } + + // MARK: - Download State + + /// The on-disk state of a downloadable model. + public enum ModelDownloadState: Sendable, Equatable { + /// The model files are not present on disk. + case notDownloaded + /// The model is currently being downloaded. + case downloading(DownloadProgress) + /// The model is fully downloaded and ready to load. + case downloaded + + public static func == (lhs: ModelDownloadState, rhs: ModelDownloadState) -> Bool { + switch (lhs, rhs) { + case (.notDownloaded, .notDownloaded): + return true + case (.downloaded, .downloaded): + return true + case (.downloading(let a), .downloading(let b)): + return a == b + default: + return false + } + } + } + + // MARK: - DownloadableLanguageModel Protocol + + /// A language model whose weights can be downloaded, inspected, and deleted. + /// + /// Backends that run locally (e.g. MLX) conform to this protocol to expose + /// download management. API-only backends (OpenAI, Anthropic, etc.) do not + /// need to conform — consumers can check conformance with: + /// + /// ```swift + /// if let downloadable = model as? any DownloadableLanguageModel { ... } + /// ``` + public protocol DownloadableLanguageModel: LanguageModel { + /// Whether the model's files are fully present on disk. + var isDownloaded: Bool { get } + + /// The current download state of this model. + var downloadState: ModelDownloadState { get } + + /// Starts downloading the model and returns a stream of progress updates. + /// + /// If the model is already downloaded, the stream completes immediately. 
+ /// Cancelling the consuming `Task` cancels the download. + func download() -> AsyncStream + + /// Removes the downloaded model files from disk. + /// + /// Also cancels any in-flight download and removes the model from the + /// in-memory cache. + func deleteDownload() async throws + + /// The total size of the downloaded model on disk, in bytes. + /// + /// Returns `nil` if the model is not downloaded or the size cannot be determined. + var downloadedSizeOnDisk: Int64? { get } + } +#endif diff --git a/Sources/AnyLanguageModel/GenerationOptions.swift b/Sources/AnyLanguageModel/GenerationOptions.swift index 5d0f633..17da525 100644 --- a/Sources/AnyLanguageModel/GenerationOptions.swift +++ b/Sources/AnyLanguageModel/GenerationOptions.swift @@ -121,6 +121,25 @@ public struct GenerationOptions: Sendable, Equatable, Codable { /// an error will be thrown. public var maximumResponseTokens: Int? + /// Maximum number of tokens to retain in the KV cache. + /// + /// When set, uses a rotating cache that evicts oldest tokens beyond this limit. + /// When `nil` (default), the cache grows unbounded. + /// + /// Recommended values: 2048–4096 for iPhone, `nil` for Mac. + public var maxKVSize: Int? + + /// Bit width for KV cache quantization (for example, 4 or 8). + /// + /// Reduces cache memory usage at slight quality cost. + /// When `nil` (default), the cache uses full precision. + public var kvBits: Int? + + /// Group size for KV cache quantization. + /// + /// Only meaningful when ``kvBits`` is set. Default is 64. + public var kvGroupSize: Int + /// Storage for model-specific custom options. private var customOptionsStorage: CustomOptionsStorage = .init() @@ -157,14 +176,23 @@ public struct GenerationOptions: Sendable, Equatable, Codable { /// responses. Must be between `0` and `1`, inclusive. /// - maximumResponseTokens: The maximum number of tokens the model is allowed /// to produce before being artificially halted. Must be positive. 
+ /// - maxKVSize: Maximum tokens in the KV cache. When set, enables a rotating cache. + /// - kvBits: Bit width for KV cache quantization. + /// - kvGroupSize: Group size for KV cache quantization. Default is 64. public init( sampling: SamplingMode? = nil, temperature: Double? = nil, - maximumResponseTokens: Int? = nil + maximumResponseTokens: Int? = nil, + maxKVSize: Int? = nil, + kvBits: Int? = nil, + kvGroupSize: Int = 64 ) { self.sampling = sampling self.temperature = temperature self.maximumResponseTokens = maximumResponseTokens + self.maxKVSize = maxKVSize + self.kvBits = kvBits + self.kvGroupSize = kvGroupSize } } diff --git a/Sources/AnyLanguageModel/LanguageModel.swift b/Sources/AnyLanguageModel/LanguageModel.swift index 635de68..9e96e96 100644 --- a/Sources/AnyLanguageModel/LanguageModel.swift +++ b/Sources/AnyLanguageModel/LanguageModel.swift @@ -14,7 +14,8 @@ public protocol LanguageModel: Sendable { func prewarm( for session: LanguageModelSession, - promptPrefix: Prompt? + promptPrefix: Prompt?, + tools: [any Tool]? ) func respond( @@ -39,6 +40,13 @@ public protocol LanguageModel: Sendable { issues: [LanguageModelFeedback.Issue], desiredOutput: Transcript.Entry? ) -> Data + + /// Invalidates any cached state associated with the given session. + /// + /// Called when a session's transcript is replaced. Models that maintain + /// per-session caches (such as MLX KV caches) should evict the entry + /// for the given session. The default implementation does nothing. + func invalidateCache(for session: LanguageModelSession) } // MARK: - Default Implementation @@ -54,7 +62,8 @@ extension LanguageModel { public func prewarm( for session: LanguageModelSession, - promptPrefix: Prompt? = nil + promptPrefix: Prompt? = nil, + tools: [any Tool]? 
= nil ) { return } @@ -67,6 +76,8 @@ extension LanguageModel { ) -> Data { return Data() } + + public func invalidateCache(for session: LanguageModelSession) {} } extension LanguageModel where UnavailableReason == Never { diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index ba38550..35b26a6 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -98,7 +98,7 @@ public final class LanguageModelSession: @unchecked Sendable { } public func prewarm(promptPrefix: Prompt? = nil) { - model.prewarm(for: self, promptPrefix: promptPrefix) + model.prewarm(for: self, promptPrefix: promptPrefix, tools: tools.isEmpty ? nil : tools) } nonisolated private func beginResponding() { @@ -914,6 +914,58 @@ private enum ResponseStreamError: Error { case noSnapshots } +// MARK: - Transcript Snapshot and Replacement + +extension LanguageModelSession { + /// An error that occurs when attempting to replace a transcript. + public enum TranscriptError: Error { + /// The session is currently generating a response. + case respondingInProgress + } + + /// Captures the current transcript for later restoration. + /// + /// It is safe to call this while the session is responding, but the + /// returned transcript may include a prompt entry without a corresponding + /// response. For deterministic snapshots, call this when ``isResponding`` + /// is `false`. + nonisolated public func snapshotTranscript() -> Transcript { + state.withLock { $0.transcript } + } + + /// Replaces the session's transcript and invalidates any associated + /// backend cache (for example, the MLX KV cache). + /// + /// Use this to temporarily swap a session's context for a different + /// task (such as running a tool agent), then restore the original + /// transcript afterward. The next call to ``respond(to:options:)`` + /// will rebuild any caches from the new transcript. 
+ /// + /// - Parameter transcript: The transcript to install. + /// - Throws: ``TranscriptError/respondingInProgress`` if the session + /// is currently generating a response. + nonisolated public func replaceTranscript(_ transcript: Transcript) throws { + // Atomic check-and-mutate: verify not responding and swap transcript + // in a single lock acquisition to prevent TOCTOU races. + let didReplace: Bool = state.withLock { lockedState in + guard !lockedState.isResponding else { return false } + lockedState.transcript = transcript + return true + } + + guard didReplace else { + throw TranscriptError.respondingInProgress + } + + // Notify observation after releasing the lock + withMutation(keyPath: \.transcript) {} + + // Invalidate backend caches (e.g. MLX KV cache) outside the lock + // to avoid lock inversion if the model calls back into the session. + model.invalidateCache(for: self) + } +} + // MARK: - private struct State: Equatable, Sendable { diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index f4be593..6a05257 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -17,6 +17,144 @@ import Foundation import MLXVLM import Tokenizers import Hub + import Observation + + // MARK: - GPU Memory Configuration + + /// Controls Metal buffer pool behavior during and between MLX inference. + /// + /// MLX maintains a recycled buffer pool to avoid repeated Metal allocations. + /// This configuration sets the pool size during active inference (`activeCacheLimit`) + /// and between generations (`idleCacheLimit`). 
+ /// + /// ```swift + /// // Automatic (scaled by device RAM): + /// let model = MLXLanguageModel(modelId: "...", gpuMemory: .automatic) + /// + /// // Custom: + /// let model = MLXLanguageModel(modelId: "...", gpuMemory: .init( + /// activeCacheLimit: 256_000_000, + /// idleCacheLimit: 50_000_000 + /// )) + /// ``` + public struct GPUMemoryConfiguration: Sendable, Equatable { + /// Maximum Metal buffer cache size in bytes during active inference. + public var activeCacheLimit: Int + + /// Maximum Metal buffer cache size in bytes when no inference is running. + public var idleCacheLimit: Int + + /// Whether to call `Memory.clearCache()` when a model is evicted. + public var clearCacheOnEviction: Bool + + public init( + activeCacheLimit: Int, + idleCacheLimit: Int, + clearCacheOnEviction: Bool = true + ) { + self.activeCacheLimit = activeCacheLimit + self.idleCacheLimit = idleCacheLimit + self.clearCacheOnEviction = clearCacheOnEviction + } + + /// Scaled by device RAM. Idle: 50 MB. Clear cache on eviction. + /// + /// Active limits: <4 GB → 128 MB, <6 GB → 256 MB, <8 GB → 512 MB, 8+ GB → 768 MB. + public static var automatic: GPUMemoryConfiguration { + let ramBytes = ProcessInfo.processInfo.physicalMemory + let ramGB = ramBytes / (1024 * 1024 * 1024) + + let active: Int + switch ramGB { + case ..<4: + active = 128_000_000 + case ..<6: + active = 256_000_000 + case ..<8: + active = 512_000_000 + default: + active = 768_000_000 + } + + return GPUMemoryConfiguration( + activeCacheLimit: active, + idleCacheLimit: 50_000_000, + clearCacheOnEviction: true + ) + } + + /// No management — MLX defaults, unbounded cache. + public static var unconstrained: GPUMemoryConfiguration { + GPUMemoryConfiguration( + activeCacheLimit: Int.max, + idleCacheLimit: Int.max, + clearCacheOnEviction: false + ) + } + } + + // MARK: - GPU Memory Manager + + /// Reference-counted active/idle toggling for the global Metal buffer cache. + /// + /// Multiple sessions can generate concurrently. 
The cache stays at `activeCacheLimit` + /// as long as ANY session is generating, and drops to `idleCacheLimit` only when ALL + /// sessions complete. + private final class GPUMemoryManager: @unchecked Sendable { + static let shared = GPUMemoryManager() + + private let lock = NSLock() + private var activeCount = 0 + private var config: GPUMemoryConfiguration = .automatic + private var hasCustomConfig = false + + private init() { + Memory.cacheLimit = config.idleCacheLimit + } + + /// Applies a GPU memory configuration. First custom configuration wins — + /// subsequent calls with a different configuration are ignored to prevent + /// multiple MLXLanguageModel instances from silently overwriting each other. + func configure(_ configuration: GPUMemoryConfiguration) { + lock.withLock { + if hasCustomConfig && config != configuration { + return + } + hasCustomConfig = true + config = configuration + if activeCount == 0 { + Memory.cacheLimit = configuration.idleCacheLimit + } + } + } + + func markActive() { + lock.withLock { + if activeCount == 0 { + Memory.cacheLimit = config.activeCacheLimit + } + activeCount += 1 + } + } + + func markIdle() { + lock.withLock { + activeCount = max(0, activeCount - 1) + if activeCount == 0 { + Memory.cacheLimit = config.idleCacheLimit + } + } + } + + func evict() { + lock.withLock { + Memory.cacheLimit = config.idleCacheLimit + if config.clearCacheOnEviction { + Memory.clearCache() + } + } + } + } /// Wrapper to store model availability state in NSCache. private final class CachedModelState: NSObject, @unchecked Sendable { @@ -172,6 +310,86 @@ import Foundation /// Shared cache across MLXLanguageModel instances. private nonisolated(unsafe) let modelCache = ModelContextCache(countLimit: 3) + // MARK: - Session KV Cache Store + + /// Stores a KV cache and its prefill token count for a session. 
+ private final class SessionCacheEntry: NSObject, @unchecked Sendable { + var kvCache: [MLXLMCommon.KVCache] + var prefillTokenCount: Int + var prefillTokenHash: Int + + init(kvCache: [MLXLMCommon.KVCache], prefillTokenCount: Int, prefillTokenHash: Int) { + self.kvCache = kvCache + self.prefillTokenCount = prefillTokenCount + self.prefillTokenHash = prefillTokenHash + } + } + + /// Maps LanguageModelSession (weak key) → SessionCacheEntry. + /// When a session is deallocated, its cache entry is automatically released. + private nonisolated(unsafe) let sessionKVCache = NSMapTable.weakToStrongObjects() + private let sessionKVCacheLock = NSLock() + + private func getSessionCache(_ session: LanguageModelSession) -> SessionCacheEntry? { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + return sessionKVCache.object(forKey: session) + } + + private func setSessionCache(_ entry: SessionCacheEntry, for session: LanguageModelSession) { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + sessionKVCache.setObject(entry, forKey: session) + } + + private func removeSessionCache(for session: LanguageModelSession) { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + sessionKVCache.removeObject(forKey: session) + } + + /// Hashes up to the first `count` tokens of an MLXArray for cache identity checks. + private func hashTokenPrefix(_ tokens: MLXArray, count: Int = 64) -> Int { + let tokenCount = tokens.dim(0) + let n = min(count, tokenCount) + guard n > 0 else { return 0 } + let prefix = tokens[0.. 
(cache: [MLXLMCommon.KVCache], input: MLXLMCommon.LMInput) { + let existingEntry = getSessionCache(session) + let fullTokenCount = lmInput.text.tokens.dim(0) + let currentHash = hashTokenPrefix(lmInput.text.tokens) + + if let existingEntry, + existingEntry.prefillTokenCount > 0, + fullTokenCount > existingEntry.prefillTokenCount, + existingEntry.prefillTokenHash == currentHash, + lmInput.image == nil + { + let cachedCount = existingEntry.prefillTokenCount + let newTokens = lmInput.text.tokens[cachedCount...] + let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) + return (cache: existingEntry.kvCache, input: MLXLMCommon.LMInput(text: partialText)) + } + + if existingEntry != nil { + removeSessionCache(for: session) + } + let freshCache = context.model.newCache(parameters: generateParameters) + return (cache: freshCache, input: lmInput) + } + // MARK: - MLXLanguageModel /// A language model that runs locally using MLX. @@ -200,16 +418,22 @@ import Foundation /// The local directory containing the model files. public let directory: URL? + /// GPU memory management configuration for Metal buffer pools. + public let gpuMemory: GPUMemoryConfiguration + /// Creates an MLX language model. /// /// - Parameters: /// - modelId: The model identifier (for example, "mlx-community/Llama-3.2-3B-Instruct-4bit"). /// - hub: An optional Hub API instance for downloading models. If not provided, the default Hub API is used. /// - directory: An optional local directory URL containing the model files. If provided, the model is loaded from this directory instead of downloading. - public init(modelId: String, hub: HubApi? = nil, directory: URL? = nil) { + /// - gpuMemory: GPU memory configuration. Defaults to `.automatic` which scales by device RAM. + public init(modelId: String, hub: HubApi? = nil, directory: URL? 
= nil, gpuMemory: GPUMemoryConfiguration = .automatic) { self.modelId = modelId self.hub = hub self.directory = directory + self.gpuMemory = gpuMemory + GPUMemoryManager.shared.configure(gpuMemory) } /// The current availability of this model in memory. @@ -233,11 +457,20 @@ import Foundation public func removeFromCache() async { let key = directory?.absoluteString ?? modelId await modelCache.removeAndCancel(for: key) + GPUMemoryManager.shared.evict() } /// Removes all MLX models from the shared cache and cancels in-flight loads. public static func removeAllFromCache() async { await modelCache.removeAllAndCancel() + sessionKVCacheLock.withLock { + sessionKVCache.removeAllObjects() + } + GPUMemoryManager.shared.evict() + } + + public func invalidateCache(for session: LanguageModelSession) { + removeSessionCache(for: session) } /// Get or load model context with caching @@ -249,7 +482,18 @@ import Foundation return try await loadModel(directory: directory) } - return try await loadModel(hub: hub ?? HubApi(), id: modelId) + let effectiveHub = hub ?? HubApi() + let context = try await loadModel(hub: effectiveHub, id: modelId) { progress in + let dp = DownloadProgress( + fractionCompleted: progress.fractionCompleted, + completedBytes: progress.completedUnitCount, + totalBytes: progress.totalUnitCount > 0 ? progress.totalUnitCount : nil, + bytesPerSecond: progress.userInfo[.throughputKey] as? 
Double + ) + MLXModelDownloadManager.shared.updateState(for: modelId, to: .downloading(dp)) + } + MLXModelDownloadManager.shared.updateState(for: modelId, to: .downloaded) + return context } } @@ -263,6 +507,9 @@ import Foundation // Get cached or load fresh ModelContext let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + GPUMemoryManager.shared.markActive() + defer { GPUMemoryManager.shared.markIdle() } + if type != String.self { let jsonString = try await generateStructuredJSON( context: context, @@ -298,8 +545,24 @@ import Foundation var allTextChunks: [String] = [] var allEntries: [Transcript.Entry] = [] + // Track the KV cache across the tool-calling loop. + // On the first iteration we try to reuse the session's cached KV state; + // on subsequent iterations (tool results added) we must rebuild. + var isFirstIteration = true + + // Guard against infinite tool-call loops (e.g. model keeps retrying the + // same tool call). After this many iterations, break and return whatever + // text has been accumulated. + let maxToolIterations = 5 + var toolIteration = 0 + var previousToolCallSignature: String? 
+ // Loop until no more tool calls while true { + toolIteration += 1 + if toolIteration > maxToolIterations { + break + } // Build user input with current chat history and tools let userInput = MLXLMCommon.UserInput( chat: chat, @@ -308,9 +571,31 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) + // Determine cache and input for generation + let cache: [MLXLMCommon.KVCache] + let inputForGeneration: MLXLMCommon.LMInput + + if isFirstIteration { + let resolved = resolveCache( + for: session, + lmInput: lmInput, + generateParameters: generateParameters, + context: context + ) + cache = resolved.cache + inputForGeneration = resolved.input + } else { + // Tool-calling iterations: fresh cache (chat has been mutated) + cache = context.model.newCache(parameters: generateParameters) + inputForGeneration = lmInput + } + + isFirstIteration = false + // Generate let stream = try MLXLMCommon.generate( - input: lmInput, + input: inputForGeneration, + cache: cache, parameters: generateParameters, context: context ) @@ -329,8 +614,16 @@ import Foundation } } + // Update session cache with current offset after generation + let currentOffset = cache.first?.offset ?? 0 + let cacheEntry = SessionCacheEntry( + kvCache: cache, + prefillTokenCount: currentOffset, + prefillTokenHash: hashTokenPrefix(lmInput.text.tokens) + ) + setSessionCache(cacheEntry, for: session) + let assistantText = chunks.joined() - allTextChunks.append(assistantText) // Add assistant response to chat history if !assistantText.isEmpty { @@ -339,6 +632,29 @@ import Foundation // If there are tool calls, execute them and continue if !collectedToolCalls.isEmpty { + // Detect repeated tool calls — if the model generates the exact + // same tool call(s) as the previous iteration, it's stuck in a + // loop. Break and return whatever text we have so far. 
+ let signature = collectedToolCalls + .map { "\($0.function.name):\($0.function.arguments)" } + .joined(separator: "|") + if signature == previousToolCallSignature { + allTextChunks.append(assistantText) + break + } + previousToolCallSignature = signature + + // Record the assistant text generated before the tool call + // as a transcript entry so convertTranscriptToMLXChat() can + // reproduce the exact same chat sequence on future turns + // (keeping the KV cache valid). + if !assistantText.isEmpty { + allEntries.append(.response(Transcript.Response( + assetIDs: [], + segments: [.text(.init(content: assistantText))] + ))) + } + let resolution = try await resolveToolCalls(collectedToolCalls, session: session) switch resolution { case .stop(let calls): @@ -354,13 +670,20 @@ import Foundation if !invocations.isEmpty { allEntries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) - // Execute each tool and add results to chat + // Execute each tool and fold results into the + // preceding assistant message to maintain strict + // user/assistant alternation for templates like Gemma 3. 
+ var toolResults: [String] = [] for invocation in invocations { allEntries.append(.toolOutput(invocation.output)) - - // Convert tool output to JSON string for MLX - let toolResultJSON = toolOutputToJSON(invocation.output) - chat.append(.tool(toolResultJSON)) + toolResults.append(toolOutputToJSON(invocation.output)) + } + let combinedResults = toolResults.joined(separator: "\n") + if let lastIdx = chat.indices.last, chat[lastIdx].role == .assistant { + let existing = chat[lastIdx].content + chat[lastIdx] = .assistant(existing + "\n\n[Tool result]: " + combinedResults) + } else { + chat.append(.assistant("[Tool result]: " + combinedResults)) } // Continue loop to generate with tool results @@ -369,11 +692,13 @@ import Foundation } } - // No more tool calls, exit loop + // No more tool calls — this is the final response text + allTextChunks.append(assistantText) break } let text = allTextChunks.joined() + return LanguageModelSession.Response( content: text as! Content, rawContent: GeneratedContent(text), @@ -398,6 +723,9 @@ import Foundation let stream: AsyncThrowingStream.Snapshot, any Error> = .init { continuation in + let didMarkIdle = Locked(false) + GPUMemoryManager.shared.markActive() + let task = Task { @Sendable in do { // Get cached or load fresh ModelContext @@ -414,8 +742,19 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) + // Resolve KV cache for this session + let resolved = resolveCache( + for: session, + lmInput: lmInput, + generateParameters: generateParameters, + context: context + ) + let cache = resolved.cache + let inputForGeneration = resolved.input + let mlxStream = try MLXLMCommon.generate( - input: lmInput, + input: inputForGeneration, + cache: cache, parameters: generateParameters, context: context ) @@ -436,29 +775,88 @@ import Foundation } } + // Update the session cache with current offset + let currentOffset = cache.first?.offset ?? 
0 + let entry = SessionCacheEntry( + kvCache: cache, + prefillTokenCount: currentOffset, + prefillTokenHash: hashTokenPrefix(lmInput.text.tokens) + ) + setSessionCache(entry, for: session) + + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } continuation.finish() } catch { + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } continuation.finish(throwing: error) } } - continuation.onTermination = { _ in task.cancel() } + continuation.onTermination = { _ in + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } + task.cancel() + } } return LanguageModelSession.ResponseStream(stream: stream) } - /// Prewarms the model + /// Prewarms the model by loading it and optionally prefilling the system prompt into a KV cache. + /// + /// - Parameters: + /// - session: The session whose instructions will be prefilled. + /// - promptPrefix: An optional prompt prefix (reserved for future use). + /// - tools: Tools that will be used with this session. Pass the same tools here + /// so the prefilled cache includes tool definitions in its tokenization, + /// avoiding a cache miss on the first real request. public func prewarm( for session: LanguageModelSession, - promptPrefix: Prompt? + promptPrefix: Prompt?, + tools: [any Tool]? 
= nil ) { let modelId = self.modelId let hub = self.hub let directory = self.directory Task { + GPUMemoryManager.shared.markActive() + defer { GPUMemoryManager.shared.markIdle() } + do { - _ = try await loadContext(modelId: modelId, hub: hub, directory: directory) + let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + + // Prefill the system prompt into a KV cache so the first turn is faster + if let instructions = session.instructions?.description, !instructions.isEmpty { + let params = MLXLMCommon.GenerateParameters() + let newCache = context.model.newCache(parameters: params) + let chat: [MLXLMCommon.Chat.Message] = [.init(role: .system, content: instructions)] + + // Convert tools to MLX ToolSpec format so the prefill tokenization + // matches what respond() will produce, ensuring cache hits. + let toolSpecs: [ToolSpec]? = tools.flatMap { toolList in + toolList.isEmpty ? nil : toolList.map { convertToolToMLXSpec($0) } + } + + let userInput = MLXLMCommon.UserInput( + chat: chat, + processing: .init(resize: .init(width: 512, height: 512)), + tools: toolSpecs + ) + let lmInput = try await context.processor.prepare(input: userInput) + _ = try context.model.prepare(lmInput, cache: newCache, windowSize: nil) + + let entry = SessionCacheEntry( + kvCache: newCache, + prefillTokenCount: newCache.first?.offset ?? 0, + prefillTokenHash: hashTokenPrefix(lmInput.text.tokens) + ) + setSessionCache(entry, for: session) + } } catch { // Ignore errors during prewarm } @@ -471,9 +869,9 @@ import Foundation private func toGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, - maxKVSize: nil, - kvBits: nil, - kvGroupSize: 64, + maxKVSize: options.maxKVSize, + kvBits: options.kvBits, + kvGroupSize: options.kvGroupSize, quantizedKVStart: 0, temperature: Float(options.temperature ?? 
0.6), topP: 1.0, @@ -519,7 +917,10 @@ import Foundation chat.append(.init(role: .system, content: instructions)) } - // Convert each transcript entry + // Convert each transcript entry. + // Tool call/output entries are folded into adjacent assistant messages + // to maintain strict user/assistant alternation required by some templates + // (e.g. Gemma 3 which has no tool role support in its Jinja template). for entry in session.transcript { switch entry { case .instructions(let instr): @@ -533,12 +934,23 @@ import Foundation chat.append(.assistant(content)) case .toolCalls: - // Tool calls are handled inline during generation loop + // Skip — tool calls were already executed; the output is captured below break case .toolOutput(let toolOutput): - let content = toolOutput.segments.map { extractText(from: $0) }.joined(separator: "\n") - chat.append(.tool(content)) + // Fold tool output into the preceding assistant message to avoid + // injecting a .tool() role that breaks strict-alternation templates. + // Use toolOutputToJSON() — the same function used in the live tool + // loop — so the replayed chat produces identical tokens, keeping + // the KV cache valid for session continuation. + let content = toolOutputToJSON(toolOutput) + if let lastIdx = chat.indices.last, chat[lastIdx].role == .assistant { + let existing = chat[lastIdx].content + chat[lastIdx] = .assistant(existing + "\n\n[Tool result]: " + content) + } else { + // No preceding assistant message — wrap as assistant + chat.append(.assistant("[Tool result]: " + content)) + } } } @@ -654,7 +1066,7 @@ import Foundation } private func convertToSendableJSONValue(_ value: Any) throws -> any Sendable { - if value is NSNull { return MLXLMCommon.JSONValue.null } + if value is NSNull { return NSNull() } if let stringValue = value as? String { return stringValue } if let boolValue = value as? Bool { return boolValue } if let intValue = value as? 
Int { return intValue } @@ -663,7 +1075,7 @@ import Foundation return numberValue.doubleValue } if let arrayValue = value as? [Any] { - return try arrayValue.map { try convertToSendableJSONValue($0) } + return try arrayValue.map { try convertToSendableJSONValue($0) } as [any Sendable] } if let dictionaryValue = value as? [String: Any] { return try convertToSendableJSONObject(dictionaryValue) @@ -1120,4 +1532,225 @@ import Foundation return sampledToken.item(Int.self) } } + + // MARK: - MLXModelDownloadManager + + /// Manages download state for MLX models. + /// + /// `@Observable` class (not actor) for direct SwiftUI binding. Thread safety + /// is provided by an `NSLock` for the state dictionary. + @Observable + public final class MLXModelDownloadManager: @unchecked Sendable { + public static let shared = MLXModelDownloadManager() + + /// Current download state per model ID. + public private(set) var states: [String: ModelDownloadState] = [:] + + /// In-flight download tasks, keyed by model ID. + private let inFlightDownloads = Locked<[String: Task<Void, Never>]>([:]) + + /// Lock for thread-safe state mutations that trigger @Observable updates. + private let stateLock = NSLock() + + private init() {} + + // MARK: - State Queries + + /// Returns the current download state for a model. + /// + /// On first access for a given model, checks disk for existing files. + public func state(for modelId: String, hub: HubApi = HubApi()) -> ModelDownloadState { + stateLock.lock() + if let cached = states[modelId] { + stateLock.unlock() + return cached + } + stateLock.unlock() + + // Check disk + let repo = Hub.Repo(id: modelId) + let repoDir = hub.localRepoLocation(repo) + let configPath = repoDir.appendingPathComponent("config.json") + + let isOnDisk = FileManager.default.fileExists(atPath: configPath.path) + let state: ModelDownloadState = isOnDisk ? 
.downloaded : .notDownloaded + + stateLock.lock() + // Don't overwrite an in-progress download + if states[modelId] == nil { + withMutation(keyPath: \.states) { + states[modelId] = state + } + } + stateLock.unlock() + + return states[modelId] ?? state + } + + /// Updates the download state for a model. Called from download handlers. + public func updateState(for modelId: String, to newState: ModelDownloadState) { + stateLock.lock() + withMutation(keyPath: \.states) { + states[modelId] = newState + } + stateLock.unlock() + } + + /// Clears cached state for a model, forcing a fresh disk check on next access. + public func invalidateState(for modelId: String) { + stateLock.lock() + withMutation(keyPath: \.states) { + states[modelId] = nil + } + stateLock.unlock() + } + + // MARK: - Download + + /// Downloads a model and returns a stream of progress updates. + /// + /// If the model is already downloaded, the stream completes immediately. + /// Cancelling the consuming `Task` cancels the download (snapshot checks + /// `Task.isCancelled` between files). + public func download(modelId: String, hub: HubApi = HubApi()) -> AsyncStream<DownloadProgress> { + let currentState = state(for: modelId, hub: hub) + if case .downloaded = currentState { + return AsyncStream { $0.finish() } + } + + return AsyncStream { continuation in + let task = Task { [weak self] in + guard let self else { + continuation.finish() + return + } + do { + let repo = Hub.Repo(id: modelId) + try await hub.snapshot( + from: repo, + matching: ["*.safetensors", "*.json", "tokenizer.model"] + ) { progress, speed in + let dp = DownloadProgress( + fractionCompleted: progress.fractionCompleted, + completedBytes: progress.completedUnitCount, + totalBytes: progress.totalUnitCount > 0 ? 
progress.totalUnitCount : nil, + bytesPerSecond: speed + ) + self.updateState(for: modelId, to: .downloading(dp)) + continuation.yield(dp) + } + self.updateState(for: modelId, to: .downloaded) + continuation.finish() + } catch { + self.updateState(for: modelId, to: .notDownloaded) + continuation.finish() + } + self.inFlightDownloads.withLock { $0[modelId] = nil } + } + + inFlightDownloads.withLock { $0[modelId] = task } + + continuation.onTermination = { [weak self] _ in + self?.inFlightDownloads.withLock { downloads in + downloads[modelId]?.cancel() + downloads[modelId] = nil + } + } + } + } + + // MARK: - Delete + + /// Removes downloaded model files from disk and resets state. + /// + /// Also cancels any in-flight download and evicts the model from + /// the in-memory model cache. + public func deleteDownload(modelId: String, hub: HubApi = HubApi()) async throws { + // Cancel in-flight download + inFlightDownloads.withLock { downloads in + downloads[modelId]?.cancel() + downloads[modelId] = nil + } + + // Remove from in-memory model cache + let cacheModel = MLXLanguageModel(modelId: modelId, hub: hub) + await cacheModel.removeFromCache() + + // Delete files on disk + let repo = Hub.Repo(id: modelId) + let repoDir = hub.localRepoLocation(repo) + if FileManager.default.fileExists(atPath: repoDir.path) { + try FileManager.default.removeItem(at: repoDir) + } + + updateState(for: modelId, to: .notDownloaded) + } + + // MARK: - Disk Size + + /// Returns the total size of the downloaded model on disk, in bytes. + public func downloadedSize(for modelId: String, hub: HubApi = HubApi()) -> Int64? 
{ + let repo = Hub.Repo(id: modelId) + let repoDir = hub.localRepoLocation(repo) + guard FileManager.default.fileExists(atPath: repoDir.path) else { return nil } + return Self.directorySize(at: repoDir) + } + + static func directorySize(at url: URL) -> Int64 { + let fm = FileManager.default + guard let enumerator = fm.enumerator( + at: url, + includingPropertiesForKeys: [.fileSizeKey, .isDirectoryKey], + options: [.skipsHiddenFiles] + ) else { return 0 } + + var totalSize: Int64 = 0 + for case let fileURL as URL in enumerator { + guard let values = try? fileURL.resourceValues(forKeys: [.fileSizeKey, .isDirectoryKey]), + values.isDirectory != true, + let size = values.fileSize + else { continue } + totalSize += Int64(size) + } + return totalSize + } + } + + // MARK: - DownloadableLanguageModel Conformance + + extension MLXLanguageModel: DownloadableLanguageModel { + public var isDownloaded: Bool { + if directory != nil { return true } + return MLXModelDownloadManager.shared.state(for: modelId, hub: hub ?? HubApi()) == .downloaded + } + + public var downloadState: ModelDownloadState { + if directory != nil { return .downloaded } + return MLXModelDownloadManager.shared.state(for: modelId, hub: hub ?? HubApi()) + } + + public func download() -> AsyncStream<DownloadProgress> { + if directory != nil { + return AsyncStream { $0.finish() } + } + return MLXModelDownloadManager.shared.download(modelId: modelId, hub: hub ?? HubApi()) + } + + public func deleteDownload() async throws { + if directory != nil { return } + try await MLXModelDownloadManager.shared.deleteDownload(modelId: modelId, hub: hub ?? HubApi()) + } + + public var downloadedSizeOnDisk: Int64? { + if directory != nil { return MLXModelDownloadManager.directorySize(at: directory!) } + return MLXModelDownloadManager.shared.downloadedSize(for: modelId, hub: hub ?? 
HubApi()) + } + } + + // MARK: - Hub Re-exports + + /// Re-export Hub types so consumers can configure custom HubApi instances + /// (download location, HF token) without importing Hub directly. + public typealias HubRepo = Hub.Repo + public typealias HubRepoType = Hub.RepoType #endif // MLX