diff --git a/src/model.rs b/src/model.rs
index a4e22e9..d798f20 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -20,60 +20,38 @@ pub struct StaticModel {
     unk_token_id: Option<u32>,
 }
 
+#[derive(Debug, Clone)]
+struct ModelFiles {
+    tokenizer: std::path::PathBuf,
+    model: std::path::PathBuf,
+    config: std::path::PathBuf,
+}
+
 impl StaticModel {
-    /// Load a Model2Vec model from a local folder or the HuggingFace Hub.
+    /// Load a Model2Vec model directly from in-memory bytes.
     ///
-    /// # Arguments
-    /// * `repo_or_path` - HuggingFace repo ID or local path to the model folder.
-    /// * `token` - Optional HuggingFace token for authenticated downloads.
-    /// * `normalize` - Optional flag to normalize embeddings (default from config.json).
-    /// * `subfolder` - Optional subfolder within the repo or path to look for model files.
-    pub fn from_pretrained<P: AsRef<Path>>(
-        repo_or_path: P,
-        token: Option<&str>,
+    /// This path is useful for runtimes that fetch model assets as bytes
+    /// rather than reading them from a local filesystem.
+    pub fn from_bytes<T, M, C>(
+        tokenizer_bytes: T,
+        model_bytes: M,
+        config_bytes: C,
         normalize: Option<bool>,
-        subfolder: Option<&str>,
-    ) -> Result<Self> {
-        // If provided, set HF token for authenticated downloads
-        if let Some(tok) = token {
-            env::set_var("HF_HUB_TOKEN", tok);
-        }
-
-        // Locate tokenizer.json, model.safetensors, config.json
-        let (tok_path, mdl_path, cfg_path) = {
-            let base = repo_or_path.as_ref();
-            if base.exists() {
-                let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
-                let t = folder.join("tokenizer.json");
-                let m = folder.join("model.safetensors");
-                let c = folder.join("config.json");
-                if !t.exists() || !m.exists() || !c.exists() {
-                    return Err(anyhow!("local path {folder:?} missing tokenizer / model / config"));
-                }
-                (t, m, c)
-            } else {
-                let api = Api::new().context("hf-hub API init failed")?;
-                let repo = api.model(repo_or_path.as_ref().to_string_lossy().into_owned());
-                let prefix = subfolder.map(|s| format!("{}/", s)).unwrap_or_default();
-                let t = repo.get(&format!("{prefix}tokenizer.json"))?;
-                let m = repo.get(&format!("{prefix}model.safetensors"))?;
-                let c = repo.get(&format!("{prefix}config.json"))?;
-                (t, m, c)
-            }
-        };
-
-        // Load the tokenizer
-        let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
+    ) -> Result<Self>
+    where
+        T: AsRef<[u8]>,
+        M: AsRef<[u8]>,
+        C: AsRef<[u8]>,
+    {
+        let tokenizer = Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
 
         // Read normalize default from config.json
-        let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
-        let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
+        let cfg: Value = serde_json::from_slice(config_bytes.as_ref()).context("failed to parse config.json")?;
         let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
         let normalize = normalize.unwrap_or(cfg_norm);
 
         // Load the safetensors
-        let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
-        let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
+        let safet = SafeTensors::deserialize(model_bytes.as_ref()).context("failed to parse safetensors")?;
         let tensor = safet
             .tensor("embeddings")
             .or_else(|_| safet.tensor("0"))
@@ -137,6 +115,26 @@ impl StaticModel {
         Self::from_owned(tokenizer, floats, rows, cols, normalize, weights, token_mapping)
     }
 
+    /// Load a Model2Vec model from a local folder or the HuggingFace Hub.
+    ///
+    /// # Arguments
+    /// * `repo_or_path` - HuggingFace repo ID or local path to the model folder.
+    /// * `token` - Optional HuggingFace token for authenticated downloads.
+    /// * `normalize` - Optional flag to normalize embeddings (default from config.json).
+    /// * `subfolder` - Optional subfolder within the repo or path to look for model files.
+    pub fn from_pretrained<P: AsRef<Path>>(
+        repo_or_path: P,
+        token: Option<&str>,
+        normalize: Option<bool>,
+        subfolder: Option<&str>,
+    ) -> Result<Self> {
+        let files = resolve_model_files(repo_or_path, token, subfolder)?;
+        let tokenizer_bytes = fs::read(&files.tokenizer).context("failed to read tokenizer.json")?;
+        let model_bytes = fs::read(&files.model).context("failed to read model.safetensors")?;
+        let config_bytes = fs::read(&files.config).context("failed to read config.json")?;
+        Self::from_bytes(tokenizer_bytes, model_bytes, config_bytes, normalize)
+    }
+
     /// Construct from owned data.
     ///
     /// # Arguments
@@ -375,3 +373,42 @@ impl StaticModel {
         sum
     }
 }
+
+fn resolve_model_files<P: AsRef<Path>>(
+    repo_or_path: P,
+    token: Option<&str>,
+    subfolder: Option<&str>,
+) -> Result<ModelFiles> {
+    if let Some(tok) = token {
+        env::set_var("HF_HUB_TOKEN", tok);
+    }
+
+    let (tokenizer, model, config) = {
+        let base = repo_or_path.as_ref();
+        if base.exists() {
+            let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
+            let tokenizer = folder.join("tokenizer.json");
+            let model = folder.join("model.safetensors");
+            let config = folder.join("config.json");
+            if !tokenizer.exists() || !model.exists() || !config.exists() {
+                return Err(anyhow!("local path {folder:?} missing tokenizer / model / config"));
+            }
+            (tokenizer, model, config)
+        } else {
+            let api = Api::new().context("hf-hub API init failed")?;
+            let repo = api.model(repo_or_path.as_ref().to_string_lossy().into_owned());
+            let prefix = subfolder.map(|s| format!("{s}/")).unwrap_or_default();
+            (
+                repo.get(&format!("{prefix}tokenizer.json"))?,
+                repo.get(&format!("{prefix}model.safetensors"))?,
+                repo.get(&format!("{prefix}config.json"))?,
+            )
+        }
+    };
+
+    Ok(ModelFiles {
+        tokenizer,
+        model,
+        config,
+    })
+}
diff --git a/tests/test_model.rs b/tests/test_model.rs
index 03581dd..33f8b8f 100644
--- a/tests/test_model.rs
+++ b/tests/test_model.rs
@@ -1,6 +1,7 @@
 mod common;
 use common::load_test_model;
 use model2vec_rs::model::StaticModel;
+use std::fs;
 
 /// Test that encoding an empty input slice yields an empty output
 #[test]
@@ -75,7 +76,6 @@ fn test_normalization_flag_override() {
 #[test]
 fn test_from_borrowed() {
     use safetensors::SafeTensors;
-    use std::fs;
     use tokenizers::Tokenizer;
 
     let path = "tests/fixtures/test-model-float32";
@@ -97,3 +97,28 @@ fn test_from_borrowed() {
     let emb = model.encode_single("hello");
     assert!(!emb.is_empty());
 }
+
+#[test]
+fn test_from_bytes_matches_from_pretrained_for_local_model() {
+    let path = "tests/fixtures/test-model-float32";
+    let from_path = StaticModel::from_pretrained(path, None, None, None).unwrap();
+    let from_bytes = StaticModel::from_bytes(
+        fs::read(format!("{path}/tokenizer.json")).unwrap(),
+        fs::read(format!("{path}/model.safetensors")).unwrap(),
+        fs::read(format!("{path}/config.json")).unwrap(),
+        None,
+    )
+    .unwrap();
+
+    let query = "hello world";
+    let path_embedding = from_path.encode_single(query);
+    let bytes_embedding = from_bytes.encode_single(query);
+
+    assert_eq!(path_embedding.len(), bytes_embedding.len());
+    for (left, right) in path_embedding.iter().zip(bytes_embedding.iter()) {
+        assert!(
+            (left - right).abs() < 1e-6,
+            "expected byte-loaded model to match path-loaded model"
+        );
+    }
+}
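
Usage note for reviewers: below is a minimal sketch of calling the new `from_bytes` constructor from downstream code. The `include_bytes!` paths are hypothetical (no such files ship with this repo); any byte source whose type implements `AsRef<[u8]>`, such as a buffer fetched over HTTP or pulled from an object store, works the same way.

```rust
use model2vec_rs::model::StaticModel;

fn main() -> anyhow::Result<()> {
    // Hypothetical embedded assets; substitute whatever byte source the
    // runtime provides (HTTP download, object store, include_bytes!, ...).
    let tokenizer_bytes = include_bytes!("../model/tokenizer.json");
    let model_bytes = include_bytes!("../model/model.safetensors");
    let config_bytes = include_bytes!("../model/config.json");

    // Passing `None` for `normalize` defers to the `normalize` key in
    // config.json (defaulting to true), the same fallback from_pretrained uses.
    let model = StaticModel::from_bytes(tokenizer_bytes, model_bytes, config_bytes, None)?;

    let embedding = model.encode_single("hello world");
    println!("embedding dims: {}", embedding.len());
    Ok(())
}
```

Since `from_pretrained` now resolves the three files via `resolve_model_files` and delegates to `from_bytes`, both constructors share a single deserialization path; the new `test_from_bytes_matches_from_pretrained_for_local_model` test pins that equivalence.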