// tract_nnef/framework.rs

1use tar::Builder;
2use tract_core::tract_data::itertools::Itertools;
3
4use crate::ast::quant::write_quant_format;
5use crate::ast::{Document, Identifier, ProtoModel, QuantFormat};
6use crate::{internal::*, nnef};
7use std::io::Read;
8#[cfg(target_family = "unix")]
9use std::os::unix::prelude::OsStrExt;
10use std::path::Path;
11use std::str::FromStr;
12
13pub fn stdlib() -> Vec<FragmentDef> {
14    crate::ast::parse::parse_fragments(include_str!("../stdlib.nnef")).unwrap()
15}
16
/// NNEF framework: entry point for loading, translating and dumping NNEF models.
pub struct Nnef {
    /// Fragment definitions from the embedded NNEF standard library.
    pub stdlib: Vec<FragmentDef>,
    /// Operator registries used to translate primitives to and from tract ops.
    pub registries: Vec<Registry>,
    /// Loaders tried in order against each file found in an archive or directory.
    pub resource_loaders: Vec<Box<dyn ResourceLoader + 'static>>,
    /// When true, identifiers beyond the strict NNEF grammar are accepted when
    /// serializing (forwarded to the dumper and the quant-format writer).
    pub allow_extended_identifier_syntax: bool,
}
23
24impl Default for Nnef {
25    fn default() -> Nnef {
26        Nnef {
27            stdlib: stdlib(),
28            registries: vec![crate::ops::tract_nnef()],
29            resource_loaders: vec![
30                GraphNnefLoader.into_boxed(),
31                DatLoader.into_boxed(),
32                GraphQuantLoader.into_boxed(),
33                TypedModelLoader::new(false).into_boxed(),
34            ],
35            allow_extended_identifier_syntax: false,
36        }
37    }
38}
39
impl Nnef {
    /// Add an operator registry (builder-style).
    pub fn with_registry(mut self, registry: Registry) -> Nnef {
        self.registries.push(registry);
        self
    }

    /// Add a resource loader (builder-style). Loaders are tried in insertion order.
    pub fn with_resource_loader(mut self, loader: impl ResourceLoader + 'static) -> Nnef {
        self.resource_loaders.push(Box::new(loader));
        self
    }

    /// Enable the tract-core extension registry, in place.
    pub fn enable_tract_core(&mut self) {
        self.registries.push(crate::ops::tract_core());
    }

    /// Enable the tract-core extension registry (builder-style).
    pub fn with_tract_core(mut self) -> Self {
        self.registries.push(crate::ops::tract_core());
        self
    }

    /// Enable the tract resource registry, in place.
    pub fn enable_tract_resource(&mut self) {
        self.registries.push(crate::ops::tract_resource());
    }

    /// Enable the tract resource registry (builder-style).
    pub fn with_tract_resource(mut self) -> Self {
        self.registries.push(crate::ops::tract_resource());
        self
    }

    /// Toggle serialization of identifiers outside the strict NNEF grammar.
    pub fn allow_extended_identifier_syntax(&mut self, allow_extended_identifier_syntax: bool) {
        self.allow_extended_identifier_syntax = allow_extended_identifier_syntax;
    }

    /// Translate a parsed `ProtoModel` into a `TypedModel`, seeding the build
    /// with `template`. On failure, returns the partially-built model along
    /// with the error.
    #[allow(clippy::result_large_err)]
    pub fn translate(
        &self,
        proto_model: &ProtoModel,
        template: TypedModel,
    ) -> Result<TypedModel, (TypedModel, TractError)> {
        ModelBuilder::new(self, proto_model, template).into_typed_model()
    }

    /// Serialize `model` as an NNEF tar archive to `w`. Convenience wrapper
    /// around `write_to_tar` that discards the returned writer.
    pub fn write(&self, model: &TypedModel, w: impl std::io::Write) -> TractResult<()> {
        self.write_to_tar(model, w)?;
        Ok(())
    }

    /// Serialize `model` as an NNEF tar archive, stamping entries with the
    /// current time. Returns the inner writer once the archive is finalized.
    pub fn write_to_tar<W: std::io::Write>(&self, model: &TypedModel, w: W) -> TractResult<W> {
        let mut ar = tar::Builder::new(w);
        let timestamp =
            std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap();
        self._write_to_tar(model, &mut ar, false, timestamp)?;
        ar.into_inner().context("Finalizing tar")
    }

    /// Like `write_to_tar`, with extra knobs: gzip-compress nested submodels,
    /// and/or use a fixed timestamp so output archives are byte-reproducible.
    pub fn write_to_tar_with_config<W: std::io::Write>(
        &self,
        model: &TypedModel,
        w: W,
        compress_nested_models: bool,
        deterministic: bool,
    ) -> TractResult<W> {
        let mut ar = tar::Builder::new(w);
        let timestamp = if deterministic {
            // 1 Jan 1980, MS-DOS epoch. Some tools have issues with 0 timestamps.
            std::time::Duration::from_secs(315532800)
        } else {
            std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap()
        };

        self._write_to_tar(model, &mut ar, compress_nested_models, timestamp)?;
        ar.into_inner().context("Finalizing tar")
    }

    /// Shared implementation for the tar writers: emits graph.nnef, then
    /// graph.quant (if any), tensor `.dat` files, and nested submodels.
    fn _write_to_tar<W: std::io::Write>(
        &self,
        model: &TypedModel,
        ar: &mut Builder<W>,
        compress_nested_models: bool,
        timestamp: std::time::Duration,
    ) -> TractResult<()> {
        let proto_model =
            crate::ser::to_proto_model(self, model).context("Translating model to proto_model")?;

        // graph.nnef: the textual graph definition, dumped to memory first so
        // the tar header can carry the exact size.
        let mut graph_data = vec![];
        crate::ast::dump::Dumper::new(self, &mut graph_data)
            .document(&proto_model.doc)
            .context("Serializing graph.nnef")?;

        let mut header = tar::Header::new_gnu();
        header.set_path("graph.nnef").context("Setting graph.nnef path")?;
        header.set_size(graph_data.len() as u64);
        header.set_mode(0o644);
        header.set_mtime(timestamp.as_secs());
        header.set_cksum();
        ar.append(&header, &mut &*graph_data).context("Appending graph.nnef")?;

        // graph.quant: quantization formats, written in sorted key order so
        // the archive content is deterministic.
        if let Some(mut quantization) = proto_model.quantization {
            let mut quant_data = vec![];

            let mut keys = quantization.keys().cloned().collect::<Vec<_>>();
            keys.sort();
            for name in keys {
                let format = quantization.remove(&name).unwrap();
                write_quant_format(
                    &mut quant_data,
                    &name,
                    format,
                    self.allow_extended_identifier_syntax,
                )
                .context("Serializing graph.quant")?;
            }

            // Reuses the graph.nnef header, overriding path/size/checksum.
            header.set_path("graph.quant").context("Setting graph.quant path")?;
            header.set_size(quant_data.len() as u64);
            header.set_mode(0o644);
            header.set_mtime(timestamp.as_secs());
            header.set_cksum();
            ar.append(&header, &mut &*quant_data).context("Appending graph.quant")?;
        }

        // Tensor payloads, one ".dat" file each, in sorted label order.
        let mut labels = proto_model.tensors.keys().collect::<Vec<_>>();
        labels.sort();
        for label in labels {
            let t = proto_model.tensors.get(label).unwrap();
            let mut label = label.0.to_string() + ".dat";
            // Avoid absolute paths inside the archive: "/x.dat" becomes "./x.dat".
            if label.starts_with('/') {
                label.insert(0, '.');
            }
            let filename = std::path::Path::new(&label);
            let mut data = vec![];
            crate::tensors::write_tensor(&mut data, t)
                .with_context(|| format!("Serializing tensor {filename:?}: {t:?}"))?;

            let mut header = tar::Header::new_gnu();
            header.set_size(data.len() as u64);
            header.set_mode(0o644);
            header.set_mtime(timestamp.as_secs());
            header.set_cksum();

            // append_data sets the (possibly long) path itself.
            ar.append_data(&mut header, filename, &mut &*data)
                .with_context(|| format!("Appending tensor {filename:?}"))?;
        }

        // Nested typed-model resources, serialized recursively as .nnef.tar
        // (or .nnef.tgz when compression is requested), in sorted label order.
        let mut labels = proto_model.resources.keys().collect::<Vec<_>>();
        labels.sort();
        for label in labels {
            let resource = proto_model.resources.get(label).unwrap();
            if let Some(typed_model_resource) = resource.downcast_ref::<TypedModelResource>() {
                let mut submodel_data = vec![];
                let mut filename = std::path::PathBuf::from_str(label)?;
                let typed_model = &typed_model_resource.0;

                if compress_nested_models {
                    filename.set_extension("nnef.tgz");
                    let encoder = flate2::write::GzEncoder::new(
                        &mut submodel_data,
                        flate2::Compression::default(),
                    );
                    self.write(typed_model, encoder)?;
                } else {
                    filename.set_extension("nnef.tar");
                    self.write(typed_model, &mut submodel_data)?;
                }

                let mut header = tar::Header::new_gnu();
                header.set_size(submodel_data.len() as u64);
                header.set_mode(0o644);
                header.set_mtime(timestamp.as_secs());
                header.set_cksum();

                ar.append_data(&mut header, filename, &mut &*submodel_data)
                    .with_context(|| format!("Appending submodel {label:?}"))?;
            }
        }
        Ok(())
    }

    /// Serialize `model` as an NNEF directory tree at `path`. Fails if `path`
    /// already exists (never overwrites).
    ///
    /// NOTE(review): unlike `_write_to_tar`, this does not serialize nested
    /// submodel resources — confirm whether that is intentional.
    pub fn write_to_dir(
        &self,
        model: &TypedModel,
        path: impl AsRef<std::path::Path>,
    ) -> TractResult<()> {
        let path = path.as_ref();
        if path.exists() {
            bail!("{:?} already exists. Won't overwrite.", path);
        }
        let proto_model = crate::ser::to_proto_model(self, model)?;
        std::fs::create_dir_all(path)?;
        let mut graph_nnef = std::fs::File::create(path.join("graph.nnef"))?;
        crate::ast::dump::Dumper::new(self, &mut graph_nnef).document(&proto_model.doc)?;

        if let Some(quantization) = proto_model.quantization {
            let mut graph_quant = std::fs::File::create(path.join("graph.quant"))?;
            // Sorted for deterministic output.
            for (name, format) in quantization.into_iter().sorted_by_key(|(x, _)| x.clone()) {
                write_quant_format(
                    &mut graph_quant,
                    &name,
                    format,
                    self.allow_extended_identifier_syntax,
                )?;
            }
        }

        for (label, t) in &proto_model.tensors {
            let label = label.0.to_string() + ".dat";
            // Keep tensor files inside `path`: strip any leading slashes.
            let label = label.trim_start_matches('/');
            let parent = path.join(label).parent().unwrap().to_owned();
            std::fs::create_dir_all(&parent).with_context(|| format!("Creating dir {parent:?}"))?;
            let filename = path.join(label).to_owned();
            let mut file = std::fs::File::create(&filename)
                .with_context(|| format!("Creating file {filename:?}"))?;

            crate::tensors::write_tensor(&mut file, t)?;
        }
        Ok(())
    }
}
258
impl tract_core::prelude::Framework<ProtoModel, TypedModel> for Nnef {
    /// Load and translate a model from a path (archive file or directory).
    fn model_for_path(&self, p: impl AsRef<Path>) -> TractResult<TypedModel> {
        let proto = self.proto_model_for_path(p)?;
        self.model_for_proto_model(&proto)
    }

    /// Read a `ProtoModel` from either a single archive file or an NNEF
    /// directory tree.
    fn proto_model_for_path(&self, path: impl AsRef<Path>) -> TractResult<ProtoModel> {
        let path = path.as_ref();
        if path.is_file() {
            let mut f = std::fs::File::open(path)?;
            return self.proto_model_for_read(&mut f);
        }

        let mut resources: HashMap<String, Arc<dyn Resource>> = Default::default();

        // `walkdir::new` will first yield the given path at depth 0, but we don't want to load this
        // entry here: only its descendants at depth >= 1.
        for entry in walkdir::WalkDir::new(path).min_depth(1) {
            let entry =
                entry.map_err(|e| format_err!("Can not walk directory {:?}: {:?}", path, e))?;
            // We don't want to load sub-directories themselves either.
            if entry.path().is_dir() {
                continue;
            }
            // Resource ids are paths relative to the model root directory.
            let subpath = entry
                .path()
                .components()
                .skip(path.components().count())
                .collect::<std::path::PathBuf>();
            let mut stream = std::fs::File::open(entry.path())?;
            read_stream(&subpath, &mut stream, &mut resources, self)?;
        }
        proto_model_from_resources(resources)
    }

    /// Read a `ProtoModel` from a stream holding a tar archive, transparently
    /// un-gzipping it when the stream starts with the gzip magic bytes.
    fn proto_model_for_read(&self, reader: &mut dyn std::io::Read) -> TractResult<ProtoModel> {
        let mut resources: HashMap<String, Arc<dyn Resource>> = Default::default();

        // Sniff the first two bytes for the gzip magic (0x1f 0x8b), then stitch
        // them back in front of the remaining stream with `chain`.
        let mut buffer = vec![0u8; 2];
        reader.read_exact(&mut buffer)?;
        let header = std::io::Cursor::new(buffer.clone());
        let stream = header.chain(reader);
        let mut tar = if buffer == [0x1f, 0x8b] {
            // Gzip support is compile-time optional.
            #[cfg(feature = "flate2")]
            {
                let f = flate2::read::GzDecoder::new(stream);
                tar::Archive::new(Box::new(f) as Box<dyn Read>)
            }
            #[cfg(not(feature = "flate2"))]
            bail!("Cannot read gzip file without flate2 enabled.");
        } else {
            tar::Archive::new(Box::new(stream) as Box<dyn Read>)
        };
        for entry in tar.entries()? {
            let mut entry = entry?;
            let mut path = entry.path()?.to_path_buf();
            // Normalize "./foo" entries to "foo".
            if path.starts_with("./") {
                path = path.strip_prefix("./")?.to_path_buf();
            }
            read_stream(&path, &mut entry, &mut resources, self)?;
        }
        proto_model_from_resources(resources)
    }

    /// Translate an already-parsed `ProtoModel`, seeding the build with `template`.
    fn model_for_proto_model_with_model_template(
        &self,
        proto: &ProtoModel,
        template: TypedModel,
    ) -> TractResult<TypedModel> {
        self.translate(proto, template).map_err(|e| e.1)
    }
}
331
332fn proto_model_from_resources(
333    resources: HashMap<String, Arc<dyn Resource>>,
334) -> TractResult<ProtoModel> {
335    // Iter resources IDs to detect submodels. Submodels are IDs with
336    // - two path compoents (ex: XXX/file)
337    // - a graph.nnef file as filename
338    let sub_models = resources
339        .keys()
340        .clone()
341        .filter_map(|id| {
342            let id_components = id.split('/').collect::<Vec<_>>();
343            if (id_components.last() == Some(&crate::resource::GRAPH_NNEF_FILENAME))
344                & (id_components.len() == 2)
345            {
346                id_components.first().map(|it| it.to_string())
347            } else {
348                None
349            }
350        })
351        .collect::<Vec<_>>();
352
353    // If there are submodels, we use the associated resources to create a TypedModel resource and add
354    // it as a new resource.
355    let mut new_resources = if sub_models.len() > 0 {
356        sub_models.into_iter().try_fold(resources, |r, it| -> TractResult<HashMap<_, _>> {
357            let (submodel_resources, mut resources): (HashMap<String, Arc<dyn Resource>>, _) =
358                r.into_iter().partition(|(k, _v)| k.starts_with(&it));
359            let submodel_resources = submodel_resources
360                .into_iter()
361                .map(|(k, v)| (k.split('/').next_back().unwrap().to_string(), v))
362                .collect::<HashMap<String, Arc<dyn Resource>>>();
363            let typed_model = nnef()
364                .model_for_proto_model(&proto_model_from_resources(submodel_resources).unwrap())?;
365            resources.insert(it, Arc::new(TypedModelResource(typed_model)));
366            Ok(resources)
367        })?
368    } else {
369        resources
370    };
371
372    // NNEF document extraction
373    let doc = new_resources
374        .remove(crate::resource::GRAPH_NNEF_FILENAME)
375        .with_context(|| {
376            anyhow!("Resource {} was not found in the model", crate::resource::GRAPH_NNEF_FILENAME)
377        })?
378        .downcast_arc::<Document>()
379        .map_err(|_| anyhow!("Error while downcasting NNEF document resource"))?;
380
381    let doc = Arc::try_unwrap(doc)
382            .map_err(|_| anyhow!("Error while extracting NNEF Document from shared reference. Only one reference to the document is expected"))?;
383
384    // Collect all resources that can be downcastable to Arc<Tensor>.
385    let tensors: HashMap<_, _> = new_resources
386        .iter()
387        .filter_map(|(key, resource)| {
388            Arc::clone(resource)
389                .downcast_arc::<Tensor>()
390                .ok()
391                .map(|r| (Identifier::from(&**key), r))
392        })
393        .collect();
394    // Iterate over tensors keys to remove them from the global resources hash map.
395    tensors.keys().for_each(|k| {
396        new_resources.remove(&*k.0);
397    });
398
399    // Quantization format resources extraction if present.
400    let quantization = if let Some(q_r) =
401        new_resources.remove(crate::resource::GRAPH_QUANT_FILENAME)
402    {
403        let Ok(q_r) = q_r.downcast_arc::<HashMap<String, QuantFormat>>() else {
404            bail!("Error while downcasting quantization format resource")
405        };
406        let Ok(q_r) = Arc::try_unwrap(q_r) else {
407            bail!("Error while extracting quantization format resource from shared reference. Only one reference to it is expected")
408        };
409        Some(q_r.into_iter().map(|(k, v)| (Identifier(k), v)).collect())
410    } else {
411        None
412    };
413
414    let proto = ProtoModel { doc, tensors, quantization, resources: new_resources };
415    proto.validate()?;
416    Ok(proto)
417}
418
419fn read_stream<R: std::io::Read>(
420    path: &Path,
421    reader: &mut R,
422    resources: &mut HashMap<String, Arc<dyn Resource>>,
423    framework: &Nnef,
424) -> TractResult<()> {
425    // ignore path with any component starting with "." (because OSX's tar is weird)
426    #[cfg(target_family = "unix")]
427    if path.components().any(|name| name.as_os_str().as_bytes().first() == Some(&b'.')) {
428        return Ok(());
429    }
430    let mut last_loader_name;
431    for loader in framework.resource_loaders.iter() {
432        last_loader_name = Some(loader.name());
433        let loaded = loader.try_load(path, reader, framework).with_context(|| {
434            anyhow!("Error while loading resource by {:?} at path {:?}", loader.name(), path)
435        })?;
436        if let Some((id, resource)) = loaded {
437            ensure!(
438                !resources.contains_key(&id),
439                "Loader {:?} succeeded to load {:?} which has been already loaded by {:?}",
440                loader.name(),
441                id,
442                last_loader_name
443            );
444            resources.insert(id, resource);
445            break;
446        }
447    }
448    Ok(())
449}