use tar::Builder;
use tract_core::tract_data::itertools::Itertools;

use crate::ast::quant::write_quant_format;
use crate::ast::{Document, Identifier, ProtoModel, QuantFormat};
use crate::{internal::*, nnef};
use std::io::Read;
#[cfg(target_family = "unix")]
use std::os::unix::prelude::OsStrExt;
use std::path::Path;
use std::str::FromStr;

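/// Parse the NNEF standard library fragments bundled with this crate (`stdlib.nnef`).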
pub fn stdlib() -> Vec<FragmentDef> {
    crate::ast::parse::parse_fragments(include_str!("../stdlib.nnef")).unwrap()
}

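/// NNEF framework front-end: holds the fragment stdlib, the operator registries and the
/// resource loaders used to read and write models in tract's NNEF dialect.
///
/// Typical usage, as a rough sketch (assumes the public `tract_nnef::nnef()` constructor and
/// prelude; the path is illustrative):
///
/// ```ignore
/// use tract_nnef::prelude::*;
///
/// // NNEF with tract's core extensions enabled.
/// let nnef = tract_nnef::nnef().with_tract_core();
/// // Load from a directory, a .nnef.tar archive or a gzipped tarball.
/// let model = nnef.model_for_path("mobilenet.nnef.tar")?;
/// let runnable = model.into_optimized()?.into_runnable()?;
/// ```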
pub struct Nnef {
    pub stdlib: Vec<FragmentDef>,
    pub registries: Vec<Registry>,
    pub resource_loaders: Vec<Box<dyn ResourceLoader + 'static>>,
    pub allow_extended_identifier_syntax: bool,
}

impl Default for Nnef {
    fn default() -> Nnef {
        Nnef {
            stdlib: stdlib(),
            registries: vec![crate::ops::tract_nnef()],
            resource_loaders: vec![
                GraphNnefLoader.into_boxed(),
                DatLoader.into_boxed(),
                GraphQuantLoader.into_boxed(),
                TypedModelLoader::new(false).into_boxed(),
            ],
            allow_extended_identifier_syntax: false,
        }
    }
}

impl Nnef {
    pub fn with_registry(mut self, registry: Registry) -> Nnef {
        self.registries.push(registry);
        self
    }

    pub fn with_resource_loader(mut self, loader: impl ResourceLoader + 'static) -> Nnef {
        self.resource_loaders.push(Box::new(loader));
        self
    }

    pub fn enable_tract_core(&mut self) {
        self.registries.push(crate::ops::tract_core());
    }

    pub fn with_tract_core(mut self) -> Self {
        self.registries.push(crate::ops::tract_core());
        self
    }

    pub fn enable_tract_resource(&mut self) {
        self.registries.push(crate::ops::tract_resource());
    }

    pub fn with_tract_resource(mut self) -> Self {
        self.registries.push(crate::ops::tract_resource());
        self
    }

    pub fn allow_extended_identifier_syntax(&mut self, allow_extended_identifier_syntax: bool) {
        self.allow_extended_identifier_syntax = allow_extended_identifier_syntax;
    }

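    /// Lower a parsed `ProtoModel` into a `TypedModel`, building on top of `template`.
    /// On failure, the partially built model is returned alongside the error.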
    #[allow(clippy::result_large_err)]
    pub fn translate(
        &self,
        proto_model: &ProtoModel,
        template: TypedModel,
    ) -> Result<TypedModel, (TypedModel, TractError)> {
        ModelBuilder::new(self, proto_model, template).into_typed_model()
    }

    pub fn write(&self, model: &TypedModel, w: impl std::io::Write) -> TractResult<()> {
        self.write_to_tar(model, w)?;
        Ok(())
    }

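    /// Serialize `model` as an NNEF tar archive into `w` and return the writer.
    ///
    /// A minimal sketch (file name illustrative, a `TypedModel` named `model` assumed in scope):
    ///
    /// ```ignore
    /// let nnef = tract_nnef::nnef().with_tract_core();
    /// let file = std::fs::File::create("model.nnef.tar")?;
    /// // The archive contains graph.nnef, an optional graph.quant, and one .dat per tensor.
    /// nnef.write_to_tar(&model, file)?;
    /// ```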
    pub fn write_to_tar<W: std::io::Write>(&self, model: &TypedModel, w: W) -> TractResult<W> {
        let mut ar = tar::Builder::new(w);
        let timestamp =
            std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap();
        self._write_to_tar(model, &mut ar, false, timestamp)?;
        ar.into_inner().context("Finalizing tar")
    }

    pub fn write_to_tar_with_config<W: std::io::Write>(
        &self,
        model: &TypedModel,
        w: W,
        compress_nested_models: bool,
        deterministic: bool,
    ) -> TractResult<W> {
        let mut ar = tar::Builder::new(w);
        let timestamp = if deterministic {
            // Fixed timestamp (1980-01-01T00:00:00Z) so repeated dumps of the same model
            // produce byte-identical archives.
            std::time::Duration::from_secs(315532800)
        } else {
            std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap()
        };

        self._write_to_tar(model, &mut ar, compress_nested_models, timestamp)?;
        ar.into_inner().context("Finalizing tar")
    }

    fn _write_to_tar<W: std::io::Write>(
        &self,
        model: &TypedModel,
        ar: &mut Builder<W>,
        compress_nested_models: bool,
        timestamp: std::time::Duration,
    ) -> TractResult<()> {
        let proto_model =
            crate::ser::to_proto_model(self, model).context("Translating model to proto_model")?;

        let mut graph_data = vec![];
        crate::ast::dump::Dumper::new(self, &mut graph_data)
            .document(&proto_model.doc)
            .context("Serializing graph.nnef")?;

        let mut header = tar::Header::new_gnu();
        header.set_path("graph.nnef").context("Setting graph.nnef path")?;
        header.set_size(graph_data.len() as u64);
        header.set_mode(0o644);
        header.set_mtime(timestamp.as_secs());
        header.set_cksum();
        ar.append(&header, &mut &*graph_data).context("Appending graph.nnef")?;

        if let Some(mut quantization) = proto_model.quantization {
            let mut quant_data = vec![];

            let mut keys = quantization.keys().cloned().collect::<Vec<_>>();
            keys.sort();
            for name in keys {
                let format = quantization.remove(&name).unwrap();
                write_quant_format(
                    &mut quant_data,
                    &name,
                    format,
                    self.allow_extended_identifier_syntax,
                )
                .context("Serializing graph.quant")?;
            }

            header.set_path("graph.quant").context("Setting graph.quant path")?;
            header.set_size(quant_data.len() as u64);
            header.set_mode(0o644);
            header.set_mtime(timestamp.as_secs());
            header.set_cksum();
            ar.append(&header, &mut &*quant_data).context("Appending graph.quant")?;
        }

        let mut labels = proto_model.tensors.keys().collect::<Vec<_>>();
        labels.sort();
        for label in labels {
            let t = proto_model.tensors.get(label).unwrap();
            let mut label = label.0.to_string() + ".dat";
            if label.starts_with('/') {
                label.insert(0, '.');
            }
            let filename = std::path::Path::new(&label);
            let mut data = vec![];
            crate::tensors::write_tensor(&mut data, t)
                .with_context(|| format!("Serializing tensor {filename:?}: {t:?}"))?;

            let mut header = tar::Header::new_gnu();
            header.set_size(data.len() as u64);
            header.set_mode(0o644);
            header.set_mtime(timestamp.as_secs());
            header.set_cksum();

            ar.append_data(&mut header, filename, &mut &*data)
                .with_context(|| format!("Appending tensor {filename:?}"))?;
        }

        let mut labels = proto_model.resources.keys().collect::<Vec<_>>();
        labels.sort();
        for label in labels {
            let resource = proto_model.resources.get(label).unwrap();
            if let Some(typed_model_resource) = resource.downcast_ref::<TypedModelResource>() {
                let mut submodel_data = vec![];
                let mut filename = std::path::PathBuf::from_str(label)?;
                let typed_model = &typed_model_resource.0;

                if compress_nested_models {
                    filename.set_extension("nnef.tgz");
                    let encoder = flate2::write::GzEncoder::new(
                        &mut submodel_data,
                        flate2::Compression::default(),
                    );
                    self.write(typed_model, encoder)?;
                } else {
                    filename.set_extension("nnef.tar");
                    self.write(typed_model, &mut submodel_data)?;
                }

                let mut header = tar::Header::new_gnu();
                header.set_size(submodel_data.len() as u64);
                header.set_mode(0o644);
                header.set_mtime(timestamp.as_secs());
                header.set_cksum();

                ar.append_data(&mut header, filename, &mut &*submodel_data)
                    .with_context(|| format!("Appending submodel {label:?}"))?;
            }
        }
        Ok(())
    }

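    /// Dump `model` as an NNEF directory at `path` (graph.nnef, optional graph.quant, one .dat
    /// file per tensor). Refuses to overwrite an existing path.
    ///
    /// Sketch, with an illustrative output path:
    ///
    /// ```ignore
    /// let nnef = tract_nnef::nnef().with_tract_core();
    /// nnef.write_to_dir(&model, "exported.nnef")?;
    /// ```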
    pub fn write_to_dir(
        &self,
        model: &TypedModel,
        path: impl AsRef<std::path::Path>,
    ) -> TractResult<()> {
        let path = path.as_ref();
        if path.exists() {
            bail!("{:?} already exists. Won't overwrite.", path);
        }
        let proto_model = crate::ser::to_proto_model(self, model)?;
        std::fs::create_dir_all(path)?;
        let mut graph_nnef = std::fs::File::create(path.join("graph.nnef"))?;
        crate::ast::dump::Dumper::new(self, &mut graph_nnef).document(&proto_model.doc)?;

        if let Some(quantization) = proto_model.quantization {
            let mut graph_quant = std::fs::File::create(path.join("graph.quant"))?;
            for (name, format) in quantization.into_iter().sorted_by_key(|(x, _)| x.clone()) {
                write_quant_format(
                    &mut graph_quant,
                    &name,
                    format,
                    self.allow_extended_identifier_syntax,
                )?;
            }
        }

        for (label, t) in &proto_model.tensors {
            let label = label.0.to_string() + ".dat";
            let label = label.trim_start_matches('/');
            let parent = path.join(label).parent().unwrap().to_owned();
            std::fs::create_dir_all(&parent).with_context(|| format!("Creating dir {parent:?}"))?;
            let filename = path.join(label).to_owned();
            let mut file = std::fs::File::create(&filename)
                .with_context(|| format!("Creating file {filename:?}"))?;

            crate::tensors::write_tensor(&mut file, t)?;
        }
        Ok(())
    }
}

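/// `Framework` implementation: reads a `ProtoModel` from an NNEF directory, a tar archive or a
/// gzipped tar stream, dispatching every file to the registered resource loaders.
///
/// Loading sketch (assumes `Framework` is in scope through the prelude; path illustrative):
///
/// ```ignore
/// use tract_nnef::prelude::*;
///
/// let nnef = tract_nnef::nnef().with_tract_core();
/// // Accepts a directory produced by write_to_dir as well as .nnef.tar / .nnef.tgz archives.
/// let model = nnef.model_for_path("exported.nnef")?;
/// ```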
impl tract_core::prelude::Framework<ProtoModel, TypedModel> for Nnef {
    fn model_for_path(&self, p: impl AsRef<Path>) -> TractResult<TypedModel> {
        let proto = self.proto_model_for_path(p)?;
        self.model_for_proto_model(&proto)
    }

    fn proto_model_for_path(&self, path: impl AsRef<Path>) -> TractResult<ProtoModel> {
        let path = path.as_ref();
        if path.is_file() {
            let mut f = std::fs::File::open(path)?;
            return self.proto_model_for_read(&mut f);
        }

        let mut resources: HashMap<String, Arc<dyn Resource>> = Default::default();

        for entry in walkdir::WalkDir::new(path).min_depth(1) {
            let entry =
                entry.map_err(|e| format_err!("Can not walk directory {:?}: {:?}", path, e))?;
            if entry.path().is_dir() {
                continue;
            }
            let subpath = entry
                .path()
                .components()
                .skip(path.components().count())
                .collect::<std::path::PathBuf>();
            let mut stream = std::fs::File::open(entry.path())?;
            read_stream(&subpath, &mut stream, &mut resources, self)?;
        }
        proto_model_from_resources(resources)
    }

    fn proto_model_for_read(&self, reader: &mut dyn std::io::Read) -> TractResult<ProtoModel> {
        let mut resources: HashMap<String, Arc<dyn Resource>> = Default::default();

        // Sniff the first two bytes to detect a gzip stream, then chain them back in front of
        // the rest of the reader.
        let mut buffer = vec![0u8; 2];
        reader.read_exact(&mut buffer)?;
        let header = std::io::Cursor::new(buffer.clone());
        let stream = header.chain(reader);
        let mut tar = if buffer == [0x1f, 0x8b] {
            #[cfg(feature = "flate2")]
            {
                let f = flate2::read::GzDecoder::new(stream);
                tar::Archive::new(Box::new(f) as Box<dyn Read>)
            }
            #[cfg(not(feature = "flate2"))]
            bail!("Cannot read gzip file without flate2 enabled.");
        } else {
            tar::Archive::new(Box::new(stream) as Box<dyn Read>)
        };
        for entry in tar.entries()? {
            let mut entry = entry?;
            let mut path = entry.path()?.to_path_buf();
            if path.starts_with("./") {
                path = path.strip_prefix("./")?.to_path_buf();
            }
            read_stream(&path, &mut entry, &mut resources, self)?;
        }
        proto_model_from_resources(resources)
    }

    fn model_for_proto_model_with_model_template(
        &self,
        proto: &ProtoModel,
        template: TypedModel,
    ) -> TractResult<TypedModel> {
        self.translate(proto, template).map_err(|e| e.1)
    }
}

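/// Assemble a `ProtoModel` from the resources gathered by the loaders: nested submodels
/// (resources keyed `<dir>/graph.nnef`, one directory level deep) are recursively built into
/// `TypedModelResource`s first, then graph.nnef, graph.quant and tensor entries are split out.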
fn proto_model_from_resources(
    resources: HashMap<String, Arc<dyn Resource>>,
) -> TractResult<ProtoModel> {
    // Resources stored as "<submodel>/graph.nnef" (exactly one directory level deep) mark
    // nested submodels.
    let sub_models = resources
        .keys()
        .filter_map(|id| {
            let id_components = id.split('/').collect::<Vec<_>>();
            if id_components.last() == Some(&crate::resource::GRAPH_NNEF_FILENAME)
                && id_components.len() == 2
            {
                id_components.first().map(|it| it.to_string())
            } else {
                None
            }
        })
        .collect::<Vec<_>>();

    let mut new_resources = if !sub_models.is_empty() {
        sub_models.into_iter().try_fold(resources, |r, it| -> TractResult<HashMap<_, _>> {
            let (submodel_resources, mut resources): (HashMap<String, Arc<dyn Resource>>, _) =
                r.into_iter().partition(|(k, _v)| k.starts_with(&it));
            let submodel_resources = submodel_resources
                .into_iter()
                .map(|(k, v)| (k.split('/').next_back().unwrap().to_string(), v))
                .collect::<HashMap<String, Arc<dyn Resource>>>();
            let typed_model = nnef()
                .model_for_proto_model(&proto_model_from_resources(submodel_resources)?)?;
            resources.insert(it, Arc::new(TypedModelResource(typed_model)));
            Ok(resources)
        })?
    } else {
        resources
    };

    let doc = new_resources
        .remove(crate::resource::GRAPH_NNEF_FILENAME)
        .with_context(|| {
            anyhow!("Resource {} was not found in the model", crate::resource::GRAPH_NNEF_FILENAME)
        })?
        .downcast_arc::<Document>()
        .map_err(|_| anyhow!("Error while downcasting NNEF document resource"))?;

    let doc = Arc::try_unwrap(doc)
        .map_err(|_| anyhow!("Error while extracting NNEF Document from shared reference. Only one reference to the document is expected"))?;

    let tensors: HashMap<_, _> = new_resources
        .iter()
        .filter_map(|(key, resource)| {
            Arc::clone(resource)
                .downcast_arc::<Tensor>()
                .ok()
                .map(|r| (Identifier::from(&**key), r))
        })
        .collect();
    tensors.keys().for_each(|k| {
        new_resources.remove(&*k.0);
    });

    let quantization = if let Some(q_r) =
        new_resources.remove(crate::resource::GRAPH_QUANT_FILENAME)
    {
        let Ok(q_r) = q_r.downcast_arc::<HashMap<String, QuantFormat>>() else {
            bail!("Error while downcasting quantization format resource")
        };
        let Ok(q_r) = Arc::try_unwrap(q_r) else {
            bail!("Error while extracting quantization format resource from shared reference. Only one reference to it is expected")
        };
        Some(q_r.into_iter().map(|(k, v)| (Identifier(k), v)).collect())
    } else {
        None
    };

    let proto = ProtoModel { doc, tensors, quantization, resources: new_resources };
    proto.validate()?;
    Ok(proto)
}

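/// Offer one file (or archive entry) to each registered resource loader, in order; the first
/// loader that accepts it stores its resource under the identifier it returns.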
fn read_stream<R: std::io::Read>(
    path: &Path,
    reader: &mut R,
    resources: &mut HashMap<String, Arc<dyn Resource>>,
    framework: &Nnef,
) -> TractResult<()> {
    // Skip hidden entries (any path component starting with '.'), such as the metadata files
    // some tar implementations add.
    #[cfg(target_family = "unix")]
    if path.components().any(|name| name.as_os_str().as_bytes().first() == Some(&b'.')) {
        return Ok(());
    }
    let mut last_loader_name;
    for loader in framework.resource_loaders.iter() {
        last_loader_name = Some(loader.name());
        let loaded = loader.try_load(path, reader, framework).with_context(|| {
            anyhow!("Error while loading resource by {:?} at path {:?}", loader.name(), path)
        })?;
        if let Some((id, resource)) = loaded {
            ensure!(
                !resources.contains_key(&id),
                "Loader {:?} succeeded to load {:?} which has been already loaded by {:?}",
                loader.name(),
                id,
                last_loader_name
            );
            resources.insert(id, resource);
            break;
        }
    }
    Ok(())
}