1use std::path::Path;
2
3use crate::ast::QuantFormat;
4use crate::internal::*;
5use safetensors::SafeTensors;
6use tract_core::downcast_rs::{impl_downcast, DowncastSync};
7use tract_core::tract_data::itertools::Itertools;
8
9pub const GRAPH_NNEF_FILENAME: &str = "graph.nnef";
10pub const GRAPH_QUANT_FILENAME: &str = "graph.quant";
11
12pub fn resource_path_to_id(path: impl AsRef<Path>) -> TractResult<String> {
13 let mut path = path.as_ref().to_path_buf();
14 path.set_extension("");
15 path.to_str()
16 .ok_or_else(|| format_err!("Badly encoded filename for path: {:?}", path))
17 .map(|s| s.to_string())
18}
19
20pub trait Resource: DowncastSync + std::fmt::Debug + Send + Sync {
21 fn get(&self, _key: &str) -> TractResult<Value> {
23 bail!("No key access supported by this resource");
24 }
25
26 fn to_liquid_value(&self) -> Option<liquid::model::Value> {
27 None
28 }
29}
30
31impl_downcast!(sync Resource);
32
33pub trait ResourceLoader: Send + Sync {
34 fn name(&self) -> StaticName;
36 fn try_load(
39 &self,
40 path: &Path,
41 reader: &mut dyn std::io::Read,
42 framework: &Nnef,
43 ) -> TractResult<Option<(String, Arc<dyn Resource>)>>;
44
45 fn into_boxed(self) -> Box<dyn ResourceLoader>
46 where
47 Self: Sized + 'static,
48 {
49 Box::new(self)
50 }
51}
52
53#[derive(Debug)]
54pub struct GraphNnef(pub String);
55impl Resource for GraphNnef {}
56
57#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
58pub struct GraphNnefLoader;
59
60impl ResourceLoader for GraphNnefLoader {
61 fn name(&self) -> StaticName {
62 "GraphNnefLoader".into()
63 }
64
65 fn try_load(
66 &self,
67 path: &Path,
68 reader: &mut dyn std::io::Read,
69 _framework: &Nnef,
70 ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
71 if path.ends_with(GRAPH_NNEF_FILENAME) {
72 let mut text = String::new();
73 reader.read_to_string(&mut text)?;
74 Ok(Some((path.to_string_lossy().to_string(), Arc::new(GraphNnef(text)))))
75 } else {
76 Ok(None)
77 }
78 }
79}
80
81impl Resource for Tensor {}
82
83#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
84pub struct DatLoader;
85
86impl ResourceLoader for DatLoader {
87 fn name(&self) -> StaticName {
88 "DatLoader".into()
89 }
90
91 fn try_load(
92 &self,
93 path: &Path,
94 reader: &mut dyn std::io::Read,
95 _framework: &Nnef,
96 ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
97 if path.extension().map(|e| e == "dat").unwrap_or(false) {
98 let tensor = crate::tensors::read_tensor(reader)
99 .with_context(|| format!("Error while reading tensor {path:?}"))?;
100 Ok(Some((resource_path_to_id(path)?, Arc::new(tensor))))
101 } else {
102 Ok(None)
103 }
104 }
105}
106
107impl Resource for HashMap<String, QuantFormat> {}
108
109#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
110pub struct GraphQuantLoader;
111
112impl ResourceLoader for GraphQuantLoader {
113 fn name(&self) -> StaticName {
114 "GraphQuantLoader".into()
115 }
116
117 fn try_load(
118 &self,
119 path: &Path,
120 reader: &mut dyn std::io::Read,
121 _framework: &Nnef,
122 ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
123 if path.ends_with(GRAPH_QUANT_FILENAME) {
124 let mut text = String::new();
125 reader.read_to_string(&mut text)?;
126 let quant = crate::ast::quant::parse_quantization(&text)?;
127 let quant: HashMap<String, QuantFormat> =
128 quant.into_iter().map(|(k, v)| (k.0, v)).collect();
129 Ok(Some((path.to_str().unwrap().to_string(), Arc::new(quant))))
130 } else {
131 Ok(None)
132 }
133 }
134}
135
/// Loader for nested NNEF models packed as `.nnef.tgz` or `.nnef.tar` archives.
pub struct TypedModelLoader {
    /// When `true`, each loaded model is passed through `into_optimized()`.
    pub optimized_model: bool,
}

impl TypedModelLoader {
    /// Builds a loader. `optimized_model` selects whether nested models get optimized.
    pub fn new(optimized_model: bool) -> Self {
        TypedModelLoader { optimized_model }
    }
}
145
146impl ResourceLoader for TypedModelLoader {
147 fn name(&self) -> StaticName {
148 "TypedModelLoader".into()
149 }
150
151 fn try_load(
152 &self,
153 path: &Path,
154 reader: &mut dyn std::io::Read,
155 framework: &Nnef,
156 ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
157 const NNEF_TGZ: &str = ".nnef.tgz";
158 const NNEF_TAR: &str = ".nnef.tar";
159 let path_str = path.to_str().unwrap_or("");
160 if path_str.ends_with(NNEF_TGZ) || path_str.ends_with(NNEF_TAR) {
161 let model = if self.optimized_model {
162 framework.model_for_read(reader)?.into_optimized()?
163 } else {
164 framework.model_for_read(reader)?
165 };
166
167 let label = if path_str.ends_with(NNEF_TGZ) {
168 path.to_str()
169 .ok_or_else(|| anyhow!("invalid model resource path"))?
170 .trim_end_matches(NNEF_TGZ)
171 } else {
172 path.to_str()
173 .ok_or_else(|| anyhow!("invalid model resource path"))?
174 .trim_end_matches(NNEF_TAR)
175 };
176 Ok(Some((resource_path_to_id(label)?, Arc::new(TypedModelResource(model)))))
177 } else {
178 Ok(None)
179 }
180 }
181}
182
183#[derive(Debug, Clone)]
184pub struct TypedModelResource(pub TypedModel);
185
186impl Resource for TypedModelResource {}
187
188pub struct SafeTensorsLoader;
189
190impl ResourceLoader for SafeTensorsLoader {
191 fn name(&self) -> StaticName {
192 "SafeTensorsLoader".into()
193 }
194
195 fn try_load(
196 &self,
197 path: &Path,
198 reader: &mut dyn std::io::Read,
199 _framework: &Nnef,
200 ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
201 if path.extension().is_some_and(|e| e == "safetensors") {
202 let mut buffer = vec![];
203 reader.read_to_end(&mut buffer)?;
204 let tensors: Vec<(String, Arc<Tensor>)> = SafeTensors::deserialize(&buffer)?
205 .tensors()
206 .into_iter()
207 .map(|(name, t)| {
208 let dt = match t.dtype() {
209 safetensors::Dtype::F32 => DatumType::F32,
210 safetensors::Dtype::F16 => DatumType::F16,
211 _ => panic!(),
212 };
213 let tensor = unsafe { Tensor::from_raw_dt(dt, t.shape(), t.data()).unwrap() };
214 (name, tensor.into_arc_tensor())
215 })
216 .collect_vec();
217 return Ok(Some((path.to_string_lossy().to_string(), Arc::new(tensors))));
218 }
219 Ok(None)
220 }
221}
222
223impl Resource for Vec<(String, Arc<Tensor>)> {}