sapient_models/
weights.rs1use std::collections::HashMap;
4use std::path::PathBuf;
5
6use anyhow::{bail, Context, Result};
7use sapient_core::Tensor;
8use sapient_io::SafetensorsLoader;
9
10pub fn load_hf_weights(paths: &[PathBuf]) -> Result<HashMap<String, Tensor>> {
12 let mut merged = HashMap::new();
13 for path in paths {
14 let shard = SafetensorsLoader::load(path)
15 .with_context(|| format!("failed to load weights from {}", path.display()))?;
16 for (k, v) in shard {
17 if merged.insert(k.clone(), v).is_some() {
18 bail!("duplicate weight key '{k}' in shard {}", path.display());
19 }
20 }
21 }
22 Ok(merged)
23}
24
25pub fn detect_weight_prefix(weights: &HashMap<String, Tensor>) -> String {
27 const CANDIDATES: &[&str] = &[
28 "model.text_model.",
29 "model.language_model.",
30 "transformer.",
31 "model.",
32 "gpt_neox.",
33 ];
34
35 for prefix in CANDIDATES {
36 let embed_key = format!("{prefix}embed_tokens.weight");
37 if weights.contains_key(&embed_key) {
38 return prefix.to_string();
39 }
40 }
41
42 if weights.contains_key("embed_tokens.weight") {
43 return String::new();
44 }
45
46 weights
48 .keys()
49 .find(|k| k.ends_with("embed_tokens.weight"))
50 .map(|k| {
51 k.strip_suffix("embed_tokens.weight")
52 .unwrap_or("")
53 .to_string()
54 })
55 .unwrap_or_else(|| "model.".to_string())
56}
57
58pub fn resolve_weight<'a>(
60 weights: &'a HashMap<String, Tensor>,
61 prefix: &str,
62 suffix: &str,
63) -> Result<&'a Tensor> {
64 let key = format!("{prefix}{suffix}.weight");
65 weights
66 .get(&key)
67 .or_else(|| weights.get(suffix))
68 .with_context(|| format!("missing weight '{key}'"))
69}
70
71pub fn resolve_bias<'a>(
74 weights: &'a HashMap<String, Tensor>,
75 prefix: &str,
76 suffix: &str,
77) -> Option<&'a Tensor> {
78 let key = format!("{prefix}{suffix}.bias");
79 weights
80 .get(&key)
81 .or_else(|| weights.get(&format!("{suffix}.bias")))
82}
83
84pub fn resolve_lm_head<'a>(
86 weights: &'a HashMap<String, Tensor>,
87 prefix: &str,
88 tie_word_embeddings: bool,
89 embed_key: &str,
90) -> Result<&'a Tensor> {
91 if tie_word_embeddings {
92 return weights
93 .get(embed_key)
94 .with_context(|| format!("missing tied embedding weight '{embed_key}'"));
95 }
96
97 weights
98 .get("lm_head.weight")
99 .or_else(|| weights.get(&format!("{prefix}lm_head.weight")))
100 .with_context(|| "missing lm_head.weight")
101}
102
103pub fn tie_word_embeddings_from_config(raw: &serde_json::Value) -> bool {
104 raw.get("tie_word_embeddings")
105 .and_then(|v| v.as_bool())
106 .or_else(|| {
107 raw.get("text_config")
108 .and_then(|tc| tc.get("tie_word_embeddings"))
109 .and_then(|v| v.as_bool())
110 })
111 .unwrap_or(false)
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// An embedding stored under `model.text_model.` selects that prefix.
    #[test]
    fn detect_text_model_prefix() {
        let embedding = Tensor::zeros(vec![1, 1], sapient_core::DType::F32).unwrap();
        let weights: HashMap<String, Tensor> = HashMap::from([(
            "model.text_model.embed_tokens.weight".to_string(),
            embedding,
        )]);

        let detected = detect_weight_prefix(&weights);

        assert_eq!(detected, "model.text_model.");
    }
}