1use tar::Builder;
2use tract_core::tract_data::itertools::Itertools;
3use tract_core::transform::ModelTransform;
4
5use crate::ast::quant::write_quant_format;
6use crate::ast::{Identifier, ProtoModel, QuantFormat};
7use crate::resource::{GraphNnef, SafeTensorsLoader};
8use crate::{internal::*, nnef};
9use std::io::Read;
10#[cfg(target_family = "unix")]
11use std::os::unix::prelude::OsStrExt;
12use std::path::Path;
13use std::str::FromStr;
14
15pub fn stdlib() -> Vec<FragmentDef> {
16 crate::ast::parse::parse_fragments(include_str!("../stdlib.nnef")).unwrap()
17}
18
/// NNEF framework front-end: holds everything needed to parse, build,
/// and serialize NNEF models with tract.
pub struct Nnef {
    /// Fragment definitions from the bundled NNEF standard library.
    pub stdlib: Vec<FragmentDef>,
    /// Operator registries used to translate primitives to/from tract ops.
    pub registries: Vec<Registry>,
    /// Loaders tried in order against each file found in an archive or directory.
    pub resource_loaders: Vec<Box<dyn ResourceLoader + 'static>>,
    /// When true, identifiers outside the strict NNEF grammar are accepted
    /// (affects quantization-file serialization as well).
    pub allow_extended_identifier_syntax: bool,
}
25
26impl Default for Nnef {
27 fn default() -> Nnef {
28 Nnef {
29 stdlib: stdlib(),
30 registries: vec![crate::ops::tract_nnef()],
31 resource_loaders: vec![
32 GraphNnefLoader.into_boxed(),
33 DatLoader.into_boxed(),
34 GraphQuantLoader.into_boxed(),
35 TypedModelLoader::new(false).into_boxed(),
36 SafeTensorsLoader.into_boxed(),
37 ],
38 allow_extended_identifier_syntax: false,
39 }
40 }
41}
42
43impl Nnef {
44 pub fn with_registry(mut self, registry: Registry) -> Nnef {
45 self.registries.push(registry);
46 self
47 }
48
49 pub fn with_resource_loader(mut self, loader: impl ResourceLoader + 'static) -> Nnef {
50 self.resource_loaders.push(Box::new(loader));
51 self
52 }
53
54 pub fn enable_tract_core(&mut self) {
55 self.registries.push(crate::ops::tract_core());
56 }
57
58 pub fn with_tract_core(mut self) -> Self {
59 self.registries.push(crate::ops::tract_core());
60 self
61 }
62
63 pub fn enable_tract_resource(&mut self) {
64 self.registries.push(crate::ops::tract_resource());
65 }
66
67 pub fn with_tract_resource(mut self) -> Self {
68 self.registries.push(crate::ops::tract_resource());
69 self
70 }
71
72 pub fn get_transform(&self, spec: &str) -> TractResult<Option<Box<dyn ModelTransform>>> {
73 for reg in &self.registries {
74 if let Some(transform) = (reg.transforms)(spec)? {
75 return Ok(Some(transform));
76 }
77 }
78 Ok(None)
79 }
80
81 pub fn allow_extended_identifier_syntax(&mut self, allow_extended_identifier_syntax: bool) {
82 self.allow_extended_identifier_syntax = allow_extended_identifier_syntax;
83 }
84
85 #[allow(clippy::result_large_err)]
86 pub fn translate(
87 &self,
88 proto_model: &ProtoModel,
89 template: TypedModel,
90 ) -> Result<TypedModel, (TypedModel, TractError)> {
91 debug!("Build TypedModel from NNEF proto model (translate)");
92 ModelBuilder::new(self, proto_model, template).into_typed_model()
93 }
94
95 pub fn write(&self, model: &TypedModel, w: impl std::io::Write) -> TractResult<()> {
96 self.write_to_tar(model, w)?;
97 Ok(())
98 }
99
100 pub fn write_to_tar<W: std::io::Write>(&self, model: &TypedModel, w: W) -> TractResult<W> {
101 let mut ar = tar::Builder::new(w);
102 let timestamp =
103 std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap();
104 self._write_to_tar(model, &mut ar, false, timestamp)?;
105 ar.into_inner().context("Finalizing tar")
106 }
107
108 pub fn write_to_tar_with_config<W: std::io::Write>(
109 &self,
110 model: &TypedModel,
111 w: W,
112 compress_nested_models: bool,
113 deterministic: bool,
114 ) -> TractResult<W> {
115 let mut ar = tar::Builder::new(w);
116 let timestamp = if deterministic {
117 std::time::Duration::from_secs(315532800)
119 } else {
120 std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap()
121 };
122
123 self._write_to_tar(model, &mut ar, compress_nested_models, timestamp)?;
124 ar.into_inner().context("Finalizing tar")
125 }
126
127 fn _write_to_tar<W: std::io::Write>(
128 &self,
129 model: &TypedModel,
130 ar: &mut Builder<W>,
131 compress_nested_models: bool,
132 timestamp: std::time::Duration,
133 ) -> TractResult<()> {
134 let proto_model =
135 crate::ser::to_proto_model(self, model).context("Translating model to proto_model")?;
136
137 let mut graph_data = vec![];
138 crate::ast::dump::Dumper::new(self, &mut graph_data)
139 .document(&proto_model.doc)
140 .context("Serializing graph.nnef")?;
141
142 let mut header = tar::Header::new_gnu();
143 header.set_path("graph.nnef").context("Setting graph.nnef path")?;
144 header.set_size(graph_data.len() as u64);
145 header.set_mode(0o644);
146 header.set_mtime(timestamp.as_secs());
147 header.set_cksum();
148 ar.append(&header, &mut &*graph_data).context("Appending graph.nnef")?;
149
150 if let Some(mut quantization) = proto_model.quantization {
151 let mut quant_data = vec![];
152
153 let mut keys = quantization.keys().cloned().collect::<Vec<_>>();
154 keys.sort();
155 for name in keys {
156 let format = quantization.remove(&name).unwrap();
157 write_quant_format(
158 &mut quant_data,
159 &name,
160 format,
161 self.allow_extended_identifier_syntax,
162 )
163 .context("Serializing graph.quant")?;
164 }
165
166 header.set_path("graph.quant").context("Setting graph.quant path")?;
167 header.set_size(quant_data.len() as u64);
168 header.set_mode(0o644);
169 header.set_mtime(timestamp.as_secs());
170 header.set_cksum();
171 ar.append(&header, &mut &*quant_data).context("Appending graph.quant")?;
172 }
173
174 let mut labels = proto_model.tensors.keys().collect::<Vec<_>>();
175 labels.sort();
176 for label in labels {
177 let t = proto_model.tensors.get(label).unwrap();
178 let mut label = label.0.to_string() + ".dat";
179 if label.starts_with('/') {
180 label.insert(0, '.');
181 }
182 let filename = std::path::Path::new(&label);
183 let mut data = vec![];
184 crate::tensors::write_tensor(&mut data, t)
185 .with_context(|| format!("Serializing tensor {filename:?}: {t:?}"))?;
186
187 let mut header = tar::Header::new_gnu();
188 header.set_size(data.len() as u64);
189 header.set_mode(0o644);
190 header.set_mtime(timestamp.as_secs());
191 header.set_cksum();
192
193 ar.append_data(&mut header, filename, &mut &*data)
194 .with_context(|| format!("Appending tensor {filename:?}"))?;
195 }
196
197 let mut labels = proto_model.resources.keys().collect::<Vec<_>>();
198 labels.sort();
199 for label in labels {
200 let resource = proto_model.resources.get(label).unwrap();
201 if let Some(typed_model_resource) = resource.downcast_ref::<TypedModelResource>() {
202 let mut submodel_data = vec![];
203 let mut filename = std::path::PathBuf::from_str(label)?;
204 let typed_model = &typed_model_resource.0;
205
206 if compress_nested_models {
207 filename.set_extension("nnef.tgz");
208 let encoder = flate2::write::GzEncoder::new(
209 &mut submodel_data,
210 flate2::Compression::default(),
211 );
212 self.write(typed_model, encoder)?;
213 } else {
214 filename.set_extension("nnef.tar");
215 self.write(typed_model, &mut submodel_data)?;
216 }
217
218 let mut header = tar::Header::new_gnu();
219 header.set_size(submodel_data.len() as u64);
220 header.set_mode(0o644);
221 header.set_mtime(timestamp.as_secs());
222 header.set_cksum();
223
224 ar.append_data(&mut header, filename, &mut &*submodel_data)
225 .with_context(|| format!("Appending submodel {label:?}"))?;
226 }
227 }
228 Ok(())
229 }
230
231 pub fn write_to_dir(
232 &self,
233 model: &TypedModel,
234 path: impl AsRef<std::path::Path>,
235 ) -> TractResult<()> {
236 let path = path.as_ref();
237 if path.exists() {
238 bail!("{:?} already exists. Won't overwrite.", path);
239 }
240 let proto_model = crate::ser::to_proto_model(self, model)?;
241 std::fs::create_dir_all(path)?;
242 let mut graph_nnef = std::fs::File::create(path.join("graph.nnef"))?;
243 crate::ast::dump::Dumper::new(self, &mut graph_nnef).document(&proto_model.doc)?;
244
245 if let Some(quantization) = proto_model.quantization {
246 let mut graph_quant = std::fs::File::create(path.join("graph.quant"))?;
247 for (name, format) in quantization.into_iter().sorted_by_key(|(x, _)| x.clone()) {
248 write_quant_format(
249 &mut graph_quant,
250 &name,
251 format,
252 self.allow_extended_identifier_syntax,
253 )?;
254 }
255 }
256
257 for (label, t) in &proto_model.tensors {
258 let label = label.0.to_string() + ".dat";
259 let label = label.trim_start_matches('/');
260 let parent = path.join(label).parent().unwrap().to_owned();
261 std::fs::create_dir_all(&parent).with_context(|| format!("Creating dir {parent:?}"))?;
262 let filename = path.join(label).to_owned();
263 let mut file = std::fs::File::create(&filename)
264 .with_context(|| format!("Creating file {filename:?}"))?;
265
266 crate::tensors::write_tensor(&mut file, t)?;
267 }
268 Ok(())
269 }
270}
271
impl tract_core::prelude::Framework<ProtoModel, TypedModel> for Nnef {
    /// Load and translate a model from a file (tar / tar.gz) or directory.
    fn model_for_path(&self, p: impl AsRef<Path>) -> TractResult<TypedModel> {
        let proto = self.proto_model_for_path(p)?;
        self.model_for_proto_model(&proto)
    }

    /// Build a `ProtoModel` from a path: a regular file is treated as a tar
    /// stream, a directory is walked recursively and each file fed to the
    /// resource loaders.
    fn proto_model_for_path(&self, path: impl AsRef<Path>) -> TractResult<ProtoModel> {
        let path = path.as_ref();
        if path.is_file() {
            let mut f = std::fs::File::open(path)?;
            return self.proto_model_for_read(&mut f);
        }

        let mut resources: HashMap<String, Arc<dyn Resource>> = Default::default();

        for entry in walkdir::WalkDir::new(path).min_depth(1) {
            let entry =
                entry.map_err(|e| format_err!("Can not walk directory {:?}: {:?}", path, e))?;
            if entry.path().is_dir() {
                continue;
            }
            // Resource ids are relative to the model root directory.
            let subpath = entry
                .path()
                .components()
                .skip(path.components().count())
                .collect::<std::path::PathBuf>();
            let mut stream = std::fs::File::open(entry.path())?;
            read_stream(&subpath, &mut stream, &mut resources, self)?;
        }
        proto_model_from_resources(resources)
    }

    /// Build a `ProtoModel` from a tar (optionally gzip-compressed) stream.
    fn proto_model_for_read(&self, reader: &mut dyn std::io::Read) -> TractResult<ProtoModel> {
        let mut resources: HashMap<String, Arc<dyn Resource>> = Default::default();

        // Sniff the first two bytes for the gzip magic number (0x1f 0x8b),
        // then chain them back in front of the reader so the archive sees
        // the full stream.
        let mut buffer = vec![0u8; 2];
        reader.read_exact(&mut buffer)?;
        let header = std::io::Cursor::new(buffer.clone());
        let stream = header.chain(reader);
        let mut tar = if buffer == [0x1f, 0x8b] {
            #[cfg(feature = "flate2")]
            {
                let f = flate2::read::GzDecoder::new(stream);
                tar::Archive::new(Box::new(f) as Box<dyn Read>)
            }
            #[cfg(not(feature = "flate2"))]
            bail!("Cannot read gzip file without flate2 enabled.");
        } else {
            tar::Archive::new(Box::new(stream) as Box<dyn Read>)
        };
        for entry in tar.entries()? {
            let mut entry = entry?;
            let mut path = entry.path()?.to_path_buf();
            // Normalize "./foo" entries to "foo".
            if path.starts_with("./") {
                path = path.strip_prefix("./")?.to_path_buf();
            }
            read_stream(&path, &mut entry, &mut resources, self)?;
        }
        proto_model_from_resources(resources)
    }

    /// Translate a proto model starting from a pre-existing typed model
    /// template, dropping the partial model on error.
    fn model_for_proto_model_with_model_template(
        &self,
        proto: &ProtoModel,
        template: TypedModel,
    ) -> TractResult<TypedModel> {
        self.translate(proto, template).map_err(|e| e.1)
    }
}
344
345fn proto_model_from_resources(
346 resources: HashMap<String, Arc<dyn Resource>>,
347) -> TractResult<ProtoModel> {
348 let sub_models = resources
352 .keys()
353 .clone()
354 .filter_map(|id| {
355 let id_components = id.split('/').collect::<Vec<_>>();
356 if (id_components.last() == Some(&crate::resource::GRAPH_NNEF_FILENAME))
357 & (id_components.len() == 2)
358 {
359 id_components.first().map(|it| it.to_string())
360 } else {
361 None
362 }
363 })
364 .collect::<Vec<_>>();
365
366 let mut new_resources = if sub_models.len() > 0 {
369 sub_models.into_iter().try_fold(resources, |r, it| -> TractResult<HashMap<_, _>> {
370 let (submodel_resources, mut resources): (HashMap<String, Arc<dyn Resource>>, _) =
371 r.into_iter().partition(|(k, _v)| k.starts_with(&it));
372 let submodel_resources = submodel_resources
373 .into_iter()
374 .map(|(k, v)| (k.split('/').next_back().unwrap().to_string(), v))
375 .collect::<HashMap<String, Arc<dyn Resource>>>();
376 let typed_model = nnef()
377 .model_for_proto_model(&proto_model_from_resources(submodel_resources).unwrap())?;
378 resources.insert(it, Arc::new(TypedModelResource(typed_model)));
379 Ok(resources)
380 })?
381 } else {
382 resources
383 };
384
385 let doc = new_resources
387 .remove(crate::resource::GRAPH_NNEF_FILENAME)
388 .with_context(|| {
389 anyhow!("Resource {} was not found in the model", crate::resource::GRAPH_NNEF_FILENAME)
390 })?
391 .downcast_arc::<GraphNnef>()
392 .map_err(|_| anyhow!("Error while downcasting NNEF document resource"))?;
393
394 let doc = Arc::try_unwrap(doc)
395 .map_err(|_| anyhow!("Error while extracting NNEF Document from shared reference. Only one reference to the document is expected"))?;
396
397 let mut tensors: HashMap<_, _> = new_resources
399 .iter()
400 .filter_map(|(key, resource)| {
401 Arc::clone(resource)
402 .downcast_arc::<Tensor>()
403 .ok()
404 .map(|r| (Identifier::from(&**key), r))
405 })
406 .collect();
407
408 for r in new_resources.values() {
409 if let Some(safe_tensors) = r.downcast_ref::<Vec<(String, Arc<Tensor>)>>() {
410 for (name, t) in safe_tensors.iter() {
411 tensors.insert(Identifier::from(&**name), Arc::clone(t));
412 }
413 }
414 }
415 new_resources.retain(|_, r| !r.is::<Tensor>() && !r.is::<Vec<(String, Arc<Tensor>)>>());
416
417 let quantization = if let Some(q_r) =
419 new_resources.remove(crate::resource::GRAPH_QUANT_FILENAME)
420 {
421 let Ok(q_r) = q_r.downcast_arc::<HashMap<String, QuantFormat>>() else {
422 bail!("Error while downcasting quantization format resource")
423 };
424 let Ok(q_r) = Arc::try_unwrap(q_r) else {
425 bail!("Error while extracting quantization format resource from shared reference. Only one reference to it is expected")
426 };
427 Some(q_r.into_iter().map(|(k, v)| (Identifier(k), v)).collect())
428 } else {
429 None
430 };
431
432 let doc = crate::liquid::process_file(&doc.0, &new_resources)?;
433 let doc = crate::ast::parse::parse_document(&doc)?;
434
435 let proto = ProtoModel { doc, tensors, quantization, resources: new_resources };
436 proto.validate()?;
437 Ok(proto)
438}
439
440fn read_stream<R: std::io::Read>(
441 path: &Path,
442 reader: &mut R,
443 resources: &mut HashMap<String, Arc<dyn Resource>>,
444 framework: &Nnef,
445) -> TractResult<()> {
446 #[cfg(target_family = "unix")]
448 if path.components().any(|name| name.as_os_str().as_bytes().first() == Some(&b'.')) {
449 return Ok(());
450 }
451 let mut last_loader_name;
452 for loader in framework.resource_loaders.iter() {
453 last_loader_name = Some(loader.name());
454 let loaded = loader.try_load(path, reader, framework).with_context(|| {
455 anyhow!("Error while loading resource by {:?} at path {:?}", loader.name(), path)
456 })?;
457 if let Some((id, resource)) = loaded {
458 ensure!(
459 !resources.contains_key(&id),
460 "Loader {:?} succeeded to load {:?} which has been already loaded by {:?}",
461 loader.name(),
462 id,
463 last_loader_name
464 );
465 resources.insert(id, resource);
466 break;
467 }
468 }
469 Ok(())
470}