1use anyhow::{Error as E, Result};
2use candle_nn::VarBuilder;
3use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
4use hf_hub::{api::sync::Api, Repo, RepoType};
5use tokenizers::{Tokenizer, TruncationParams};
6
/// Options controlling which BERT checkpoint is loaded and how it is configured.
///
/// `Default` yields every flag `false` and no explicit model id or revision, so
/// `Args { cpu: true, ..Default::default() }` can be used to override single options.
#[derive(Debug, Clone, Default)]
pub struct Args {
    /// Forwarded to `crate::util::device` when selecting the compute device
    /// (presumably forces CPU execution when `true` — confirm against that helper).
    pub cpu: bool,

    /// Tracing toggle. Not read by `build_model_and_tokenizer`; callers are expected
    /// to interpret it themselves.
    pub tracing: bool,

    /// Hugging Face Hub repository id of the model. When `None`,
    /// `sentence-transformers/all-MiniLM-L6-v2` is used.
    pub model_id: Option<String>,

    /// Repository revision to fetch. When `None`, `"main"` is used if `model_id`
    /// is set, otherwise `"refs/pr/21"`.
    pub revision: Option<String>,

    /// Load weights from `pytorch_model.bin` instead of `model.safetensors`.
    pub use_pth: bool,

    /// Embedding-normalization toggle. Not read by `build_model_and_tokenizer`;
    /// callers are expected to apply it to the produced embeddings themselves.
    pub normalize_embeddings: bool,

    /// Override the configuration's activation with `HiddenAct::GeluApproximate`.
    pub approximate_gelu: bool,
}
22
23impl Args {
24 pub fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
25 let device = crate::util::device(self.cpu)?;
26 let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
27 let default_revision = "refs/pr/21".to_string();
28 let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
29 (Some(model_id), Some(revision)) => (model_id, revision),
30 (Some(model_id), None) => (model_id, "main".to_string()),
31 (None, Some(revision)) => (default_model, revision),
32 (None, None) => (default_model, default_revision),
33 };
34
35 let repo = Repo::with_revision(model_id, RepoType::Model, revision);
36 let (config_filename, tokenizer_filename, weights_filename) = {
37 let api = Api::new()?;
38 let api = api.repo(repo);
39 let config = api.get("config.json")?;
40 let tokenizer = api.get("tokenizer.json")?;
41 let weights = if self.use_pth {
42 api.get("pytorch_model.bin")?
43 } else {
44 api.get("model.safetensors")?
45 };
46 (config, tokenizer, weights)
47 };
48 let config = std::fs::read_to_string(config_filename)?;
49 let mut config: Config = serde_json::from_str(&config)?;
50 let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
51 let _ = tokenizer.with_truncation(Some(TruncationParams {
52 max_length: 512,
53 ..Default::default()
54 }));
55 let vb = if self.use_pth {
56 VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
57 } else {
58 unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
59 };
60 if self.approximate_gelu {
61 config.hidden_act = HiddenAct::GeluApproximate;
62 }
63 let model = BertModel::load(vb, &config)?;
64 Ok((model, tokenizer))
65 }
66}