Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 82 additions & 45 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,60 +20,38 @@ pub struct StaticModel {
unk_token_id: Option<usize>,
}

/// Resolved filesystem locations of the three assets that make up a
/// Model2Vec model, whether they live in a local folder or in the
/// hf-hub download cache.
#[derive(Debug, Clone)]
struct ModelFiles {
    // Path to `tokenizer.json`.
    tokenizer: std::path::PathBuf,
    // Path to `model.safetensors`.
    model: std::path::PathBuf,
    // Path to `config.json`.
    config: std::path::PathBuf,
}

impl StaticModel {
/// Load a Model2Vec model from a local folder or the HuggingFace Hub.
/// Load a Model2Vec model directly from in-memory bytes.
///
/// # Arguments
/// * `repo_or_path` - HuggingFace repo ID or local path to the model folder.
/// * `token` - Optional HuggingFace token for authenticated downloads.
/// * `normalize` - Optional flag to normalize embeddings (default from config.json).
/// * `subfolder` - Optional subfolder within the repo or path to look for model files.
pub fn from_pretrained<P: AsRef<Path>>(
repo_or_path: P,
token: Option<&str>,
/// This path is useful for runtimes that fetch model assets as bytes
/// rather than reading them from a local filesystem.
pub fn from_bytes<T, M, C>(
tokenizer_bytes: T,
model_bytes: M,
config_bytes: C,
normalize: Option<bool>,
subfolder: Option<&str>,
) -> Result<Self> {
// If provided, set HF token for authenticated downloads
if let Some(tok) = token {
env::set_var("HF_HUB_TOKEN", tok);
}

// Locate tokenizer.json, model.safetensors, config.json
let (tok_path, mdl_path, cfg_path) = {
let base = repo_or_path.as_ref();
if base.exists() {
let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
let t = folder.join("tokenizer.json");
let m = folder.join("model.safetensors");
let c = folder.join("config.json");
if !t.exists() || !m.exists() || !c.exists() {
return Err(anyhow!("local path {folder:?} missing tokenizer / model / config"));
}
(t, m, c)
} else {
let api = Api::new().context("hf-hub API init failed")?;
let repo = api.model(repo_or_path.as_ref().to_string_lossy().into_owned());
let prefix = subfolder.map(|s| format!("{}/", s)).unwrap_or_default();
let t = repo.get(&format!("{prefix}tokenizer.json"))?;
let m = repo.get(&format!("{prefix}model.safetensors"))?;
let c = repo.get(&format!("{prefix}config.json"))?;
(t, m, c)
}
};

// Load the tokenizer
let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
) -> Result<Self>
where
T: AsRef<[u8]>,
M: AsRef<[u8]>,
C: AsRef<[u8]>,
{
let tokenizer = Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

// Read normalize default from config.json
let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
let cfg: Value = serde_json::from_slice(config_bytes.as_ref()).context("failed to parse config.json")?;
let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
let normalize = normalize.unwrap_or(cfg_norm);

// Load the safetensors
let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
let safet = SafeTensors::deserialize(model_bytes.as_ref()).context("failed to parse safetensors")?;
let tensor = safet
.tensor("embeddings")
.or_else(|_| safet.tensor("0"))
Expand Down Expand Up @@ -137,6 +115,26 @@ impl StaticModel {
Self::from_owned(tokenizer, floats, rows, cols, normalize, weights, token_mapping)
}

/// Load a Model2Vec model from a local folder or the HuggingFace Hub.
///
/// Resolves the three model assets (`tokenizer.json`, `model.safetensors`,
/// `config.json`), reads them into memory, and defers to [`Self::from_bytes`]
/// for the actual construction.
///
/// # Arguments
/// * `repo_or_path` - HuggingFace repo ID or local path to the model folder.
/// * `token` - Optional HuggingFace token for authenticated downloads.
/// * `normalize` - Optional flag to normalize embeddings (default from config.json).
/// * `subfolder` - Optional subfolder within the repo or path to look for model files.
///
/// # Errors
/// Fails when the files cannot be located, downloaded, read, or parsed.
pub fn from_pretrained<P: AsRef<Path>>(
    repo_or_path: P,
    token: Option<&str>,
    normalize: Option<bool>,
    subfolder: Option<&str>,
) -> Result<Self> {
    let located = resolve_model_files(repo_or_path, token, subfolder)?;
    // Read each asset eagerly; `from_bytes` owns all parsing/validation.
    let tokenizer_bytes = fs::read(&located.tokenizer).context("failed to read tokenizer.json")?;
    let model_bytes = fs::read(&located.model).context("failed to read model.safetensors")?;
    let config_bytes = fs::read(&located.config).context("failed to read config.json")?;
    Self::from_bytes(tokenizer_bytes, model_bytes, config_bytes, normalize)
}

/// Construct from owned data.
///
/// # Arguments
Expand Down Expand Up @@ -375,3 +373,42 @@ impl StaticModel {
sum
}
}

/// Locate `tokenizer.json`, `model.safetensors`, and `config.json` for a
/// model, either inside a local directory or by downloading them from the
/// HuggingFace Hub when `repo_or_path` does not exist on disk.
///
/// # Arguments
/// * `repo_or_path` - HuggingFace repo ID or local path to the model folder.
/// * `token` - Optional HuggingFace token for authenticated downloads.
/// * `subfolder` - Optional subfolder within the repo or path to look in.
///
/// # Errors
/// Fails when a local folder is missing any of the three files, or when the
/// hf-hub client cannot be initialized or a download fails.
fn resolve_model_files<P: AsRef<Path>>(
    repo_or_path: P,
    token: Option<&str>,
    subfolder: Option<&str>,
) -> Result<ModelFiles> {
    // Expose an explicit token to the hf-hub client through the environment.
    if let Some(tok) = token {
        env::set_var("HF_HUB_TOKEN", tok);
    }

    let base = repo_or_path.as_ref();
    if base.exists() {
        // Local directory: all three files must sit side by side.
        let folder = match subfolder {
            Some(sub) => base.join(sub),
            None => base.to_path_buf(),
        };
        let tokenizer = folder.join("tokenizer.json");
        let model = folder.join("model.safetensors");
        let config = folder.join("config.json");
        if !(tokenizer.exists() && model.exists() && config.exists()) {
            return Err(anyhow!("local path {folder:?} missing tokenizer / model / config"));
        }
        Ok(ModelFiles { tokenizer, model, config })
    } else {
        // Otherwise treat the argument as a repo ID and fetch from the Hub.
        let api = Api::new().context("hf-hub API init failed")?;
        let repo = api.model(base.to_string_lossy().into_owned());
        let prefix = subfolder.map(|s| format!("{s}/")).unwrap_or_default();
        Ok(ModelFiles {
            tokenizer: repo.get(&format!("{prefix}tokenizer.json"))?,
            model: repo.get(&format!("{prefix}model.safetensors"))?,
            config: repo.get(&format!("{prefix}config.json"))?,
        })
    }
}
27 changes: 26 additions & 1 deletion tests/test_model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod common;
use common::load_test_model;
use model2vec_rs::model::StaticModel;
use std::fs;
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A module-level use std::fs; was added, but test_from_borrowed already imports std::fs inside the function. This duplication is easy to miss and can trigger lint noise; consider removing the inner import and using the module-level one consistently (or vice versa).

Copilot uses AI. Check for mistakes.

/// Test that encoding an empty input slice yields an empty output
#[test]
Expand Down Expand Up @@ -75,7 +76,6 @@ fn test_normalization_flag_override() {
#[test]
fn test_from_borrowed() {
use safetensors::SafeTensors;
use std::fs;
use tokenizers::Tokenizer;

let path = "tests/fixtures/test-model-float32";
Expand All @@ -97,3 +97,28 @@ fn test_from_borrowed() {
let emb = model.encode_single("hello");
assert!(!emb.is_empty());
}

/// Loading a model from raw bytes must produce the same embeddings as
/// loading the identical model through the path-based constructor.
#[test]
fn test_from_bytes_matches_from_pretrained_for_local_model() {
    let path = "tests/fixtures/test-model-float32";
    let from_path = StaticModel::from_pretrained(path, None, None, None).unwrap();

    // Feed the same three assets to the byte-based constructor.
    let tokenizer_bytes = fs::read(format!("{path}/tokenizer.json")).unwrap();
    let model_bytes = fs::read(format!("{path}/model.safetensors")).unwrap();
    let config_bytes = fs::read(format!("{path}/config.json")).unwrap();
    let from_bytes =
        StaticModel::from_bytes(tokenizer_bytes, model_bytes, config_bytes, None).unwrap();

    let query = "hello world";
    let path_embedding = from_path.encode_single(query);
    let bytes_embedding = from_bytes.encode_single(query);

    assert_eq!(path_embedding.len(), bytes_embedding.len());
    for (left, right) in path_embedding.iter().zip(bytes_embedding.iter()) {
        assert!(
            (left - right).abs() < 1e-6,
            "expected byte-loaded model to match path-loaded model"
        );
    }
}
Loading