Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 4 additions & 58 deletions Package.resolved

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ let package = Package(
.package(url: "https://github.com/mattt/JSONSchema", from: "1.3.0"),
.package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")),
.package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"),
// mlx-swift-lm must be >= 2.25.5 for ToolSpec/tool calls and UserInput(chat:processing:tools:).
.package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.25.5"),
// mlx-swift-lm >= 2.30.3 for fast SDPA, Gemma3n per-layer intermediate_size,
// cache race fix, Memory API, and chat rehydration. >= 2.25.5 for ToolSpec/tool calls.
.package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.30.3"),
.package(url: "https://github.com/swiftlang/swift-syntax", from: "600.0.0"),
],
targets: [
Expand Down
93 changes: 93 additions & 0 deletions Sources/AnyLanguageModel/DownloadableLanguageModel.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#if MLX
import Foundation

// MARK: - Download Progress

/// Reports progress of a model download.
/// A snapshot of how far a model download has advanced.
///
/// All byte-level fields are optional because some transports only report a
/// completed fraction without sizes or throughput.
public struct DownloadProgress: Sendable, Equatable {
    /// Completed fraction of the download, in the range 0.0...1.0.
    public var fractionCompleted: Double

    /// Bytes received so far, when the transport reports them.
    public var completedBytes: Int64?

    /// Expected total byte count, when the transport reports it.
    public var totalBytes: Int64?

    /// Instantaneous throughput in bytes per second, when known.
    public var bytesPerSecond: Double?

    /// Creates a progress snapshot.
    ///
    /// - Parameters:
    ///   - fractionCompleted: Completed fraction, 0.0 through 1.0.
    ///   - completedBytes: Bytes received so far, if known.
    ///   - totalBytes: Expected total bytes, if known.
    ///   - bytesPerSecond: Current throughput, if known.
    public init(
        fractionCompleted: Double,
        completedBytes: Int64? = nil,
        totalBytes: Int64? = nil,
        bytesPerSecond: Double? = nil
    ) {
        self.fractionCompleted = fractionCompleted
        self.bytesPerSecond = bytesPerSecond
        self.completedBytes = completedBytes
        self.totalBytes = totalBytes
    }
}

// MARK: - Download State

/// The on-disk state of a downloadable model.
/// The on-disk state of a downloadable model.
///
/// `Equatable` conformance is synthesized by the compiler: every associated
/// value (`DownloadProgress`) is itself `Equatable`, so the hand-written `==`
/// this replaces was redundant boilerplate that could silently fall out of
/// sync if new cases were ever added.
public enum ModelDownloadState: Sendable, Equatable {
    /// The model files are not present on disk.
    case notDownloaded
    /// The model is currently being downloaded, with the latest progress.
    case downloading(DownloadProgress)
    /// The model is fully downloaded and ready to load.
    case downloaded
}

// MARK: - DownloadableLanguageModel Protocol

/// A language model whose weights can be downloaded, inspected, and deleted.
///
/// Backends that run locally (e.g. MLX) conform to this protocol to expose
/// download management. API-only backends (OpenAI, Anthropic, etc.) do not
/// need to conform — consumers can check conformance at runtime with:
///
/// ```swift
/// if let downloadable = model as? any DownloadableLanguageModel { ... }
/// ```
public protocol DownloadableLanguageModel: LanguageModel {
/// Whether the model's files are fully present on disk.
///
/// Equivalent to `downloadState == .downloaded`, offered as a cheaper
/// boolean check. NOTE(review): conformers should keep the two in sync.
var isDownloaded: Bool { get }

/// The current download state of this model (not downloaded,
/// downloading with progress, or downloaded).
var downloadState: ModelDownloadState { get }

/// Starts downloading the model and returns a stream of progress updates.
///
/// If the model is already downloaded, the stream completes immediately.
/// Cancelling the consuming `Task` cancels the download.
func download() -> AsyncStream<DownloadProgress>

/// Removes the downloaded model files from disk.
///
/// Also cancels any in-flight download and removes the model from the
/// in-memory cache.
func deleteDownload() async throws

/// The total size of the downloaded model on disk, in bytes.
///
/// Returns `nil` if the model is not downloaded or the size cannot be determined.
var downloadedSizeOnDisk: Int64? { get }
}
#endif
30 changes: 29 additions & 1 deletion Sources/AnyLanguageModel/GenerationOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,25 @@ public struct GenerationOptions: Sendable, Equatable, Codable {
/// an error will be thrown.
public var maximumResponseTokens: Int?

/// Maximum number of tokens to retain in the KV cache.
///
/// When set, uses a rotating cache that evicts oldest tokens beyond this limit.
/// When `nil` (default), the cache grows unbounded.
///
/// Recommended values: 2048–4096 for iPhone, `nil` for Mac.
public var maxKVSize: Int?

/// Bit width for KV cache quantization (for example, 4 or 8).
///
/// Reduces cache memory usage at slight quality cost.
/// When `nil` (default), the cache uses full precision.
public var kvBits: Int?

/// Group size for KV cache quantization.
///
/// Only meaningful when ``kvBits`` is set. Default is 64.
public var kvGroupSize: Int

/// Storage for model-specific custom options.
private var customOptionsStorage: CustomOptionsStorage = .init()

Expand Down Expand Up @@ -157,14 +176,23 @@ public struct GenerationOptions: Sendable, Equatable, Codable {
/// responses. Must be between `0` and `1`, inclusive.
/// - maximumResponseTokens: The maximum number of tokens the model is allowed
/// to produce before being artificially halted. Must be positive.
/// - maxKVSize: Maximum tokens in the KV cache. When set, enables a rotating cache.
/// - kvBits: Bit width for KV cache quantization.
/// - kvGroupSize: Group size for KV cache quantization. Default is 64.
public init(
    sampling: SamplingMode? = nil,
    temperature: Double? = nil,
    // Diff residue removed: the scrape carried both the old
    // `maximumResponseTokens: Int? = nil` (no trailing comma) and the new
    // version of the line, producing a duplicate parameter. Only the merged,
    // post-change parameter list is valid Swift.
    maximumResponseTokens: Int? = nil,
    maxKVSize: Int? = nil,
    kvBits: Int? = nil,
    kvGroupSize: Int = 64
) {
    self.sampling = sampling
    self.temperature = temperature
    self.maximumResponseTokens = maximumResponseTokens
    self.maxKVSize = maxKVSize
    self.kvBits = kvBits
    self.kvGroupSize = kvGroupSize
}
}

Expand Down
15 changes: 13 additions & 2 deletions Sources/AnyLanguageModel/LanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ public protocol LanguageModel: Sendable {

// Diff residue removed: the scrape showed both the old `promptPrefix: Prompt?`
// line (no comma) and the new one; only the merged requirement is valid.
/// Optionally pre-computes state for `session` so the first response is faster.
/// - Parameters:
///   - session: The session whose context should be prewarmed.
///   - promptPrefix: An optional prompt prefix to prewarm with.
///   - tools: Tools to include in the prewarmed prompt, or `nil` for none.
func prewarm(
    for session: LanguageModelSession,
    promptPrefix: Prompt?,
    tools: [any Tool]?
)

func respond<Content>(
Expand All @@ -39,6 +40,13 @@ public protocol LanguageModel: Sendable {
issues: [LanguageModelFeedback.Issue],
desiredOutput: Transcript.Entry?
) -> Data

/// Invalidates any cached state associated with the given session.
///
/// Called when a session's transcript is replaced. Models that maintain
/// per-session caches (such as MLX KV caches) should evict the entry
/// for the given session. The default implementation does nothing.
func invalidateCache(for session: LanguageModelSession)
}

// MARK: - Default Implementation
Expand All @@ -54,7 +62,8 @@ extension LanguageModel {

/// Default no-op prewarm: backends without a prewarm path ignore the request.
///
/// Diff residue removed: the scrape carried both the old
/// `promptPrefix: Prompt? = nil` line (no comma) and the new one; only the
/// merged, post-change parameter list is valid Swift.
public func prewarm(
    for session: LanguageModelSession,
    promptPrefix: Prompt? = nil,
    tools: [any Tool]? = nil
) {
    return
}
Expand All @@ -67,6 +76,8 @@ extension LanguageModel {
) -> Data {
return Data()
}

public func invalidateCache(for session: LanguageModelSession) {}
}

extension LanguageModel where UnavailableReason == Never {
Expand Down
54 changes: 53 additions & 1 deletion Sources/AnyLanguageModel/LanguageModelSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public final class LanguageModelSession: @unchecked Sendable {
}

/// Prewarms the backend for this session, forwarding the session's tools so
/// the backend can bake them into the prewarmed prompt (`nil` when empty).
///
/// Diff residue removed: the scrape carried both the old call (without the
/// `tools:` argument) and the new call; only the post-change call is kept.
public func prewarm(promptPrefix: Prompt? = nil) {
    model.prewarm(for: self, promptPrefix: promptPrefix, tools: tools.isEmpty ? nil : tools)
}

nonisolated private func beginResponding() {
Expand Down Expand Up @@ -914,6 +914,58 @@ private enum ResponseStreamError: Error {
case noSnapshots
}

// MARK: - Transcript Snapshot and Replacement

extension LanguageModelSession {
    /// An error that occurs when attempting to replace a transcript.
    public enum TranscriptError: Error {
        /// The session is currently generating a response.
        case respondingInProgress
    }

    /// Captures the current transcript for later restoration.
    ///
    /// Calling this while the session is responding is safe, but the result
    /// may contain a prompt entry with no matching response entry. For a
    /// deterministic snapshot, wait until ``isResponding`` is `false`.
    nonisolated public func snapshotTranscript() -> Transcript {
        state.withLock { $0.transcript }
    }

    /// Replaces the session's transcript and invalidates any associated
    /// backend cache (for example, the MLX KV cache).
    ///
    /// Use this to temporarily swap a session's context for a different
    /// task (such as running a tool agent), then restore the original
    /// transcript afterward. The next call to ``respond(to:options:)``
    /// will rebuild any caches from the new transcript.
    ///
    /// - Parameter transcript: The transcript to install.
    /// - Throws: ``TranscriptError/respondingInProgress`` if the session
    ///   is currently generating a response.
    nonisolated public func replaceTranscript(_ transcript: Transcript) throws {
        // Check-and-swap under a single lock acquisition so a response cannot
        // begin between the check and the mutation (TOCTOU race).
        let replaced = state.withLock { locked -> Bool in
            if locked.isResponding { return false }
            locked.transcript = transcript
            return true
        }

        if !replaced {
            throw TranscriptError.respondingInProgress
        }

        // Fire observation only after the lock has been released.
        withMutation(keyPath: \.transcript) {}

        // Evict backend caches (e.g. the MLX KV cache) outside the lock so a
        // model that calls back into the session cannot deadlock on it.
        model.invalidateCache(for: self)
    }
}

// MARK: -

private struct State: Equatable, Sendable {
Expand Down
Loading