tract_nnef/
framework.rs

use tar::Builder;
use tract_core::tract_data::itertools::Itertools;
use tract_core::transform::ModelTransform;

use crate::ast::quant::write_quant_format;
use crate::ast::{Identifier, ProtoModel, QuantFormat};
use crate::resource::{GraphNnef, SafeTensorsLoader};
use crate::{internal::*, nnef};
use std::io::Read;
#[cfg(target_family = "unix")]
use std::os::unix::prelude::OsStrExt;
use std::path::Path;
use std::str::FromStr;

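/// Parse the NNEF standard library fragments bundled with this crate
/// (`stdlib.nnef`) into their AST form.
///
/// A minimal sketch, going through the public `nnef()` constructor, which
/// stores the parsed fragments in its `stdlib` field:
///
/// ```no_run
/// let framework = tract_nnef::nnef();
/// println!("{} stdlib fragments available", framework.stdlib.len());
/// ```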
pub fn stdlib() -> Vec<FragmentDef> {
    crate::ast::parse::parse_fragments(include_str!("../stdlib.nnef")).unwrap()
}

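/// Entry point for reading and writing NNEF models: bundles the parsed
/// stdlib fragments, the operator registries and the resource loaders
/// consulted while loading or dumping a model.
///
/// A minimal configuration sketch (enabling the tract-core registry is only
/// needed for models that use tract's NNEF extensions):
///
/// ```no_run
/// let framework = tract_nnef::nnef().with_tract_core();
/// ```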
pub struct Nnef {
    pub stdlib: Vec<FragmentDef>,
    pub registries: Vec<Registry>,
    pub resource_loaders: Vec<Box<dyn ResourceLoader + 'static>>,
    pub allow_extended_identifier_syntax: bool,
}

impl Default for Nnef {
    fn default() -> Nnef {
        Nnef {
            stdlib: stdlib(),
            registries: vec![crate::ops::tract_nnef()],
            resource_loaders: vec![
                GraphNnefLoader.into_boxed(),
                DatLoader.into_boxed(),
                GraphQuantLoader.into_boxed(),
                TypedModelLoader::new(false).into_boxed(),
                SafeTensorsLoader.into_boxed(),
            ],
            allow_extended_identifier_syntax: false,
        }
    }
}

impl Nnef {
    pub fn with_registry(mut self, registry: Registry) -> Nnef {
        self.registries.push(registry);
        self
    }

    pub fn with_resource_loader(mut self, loader: impl ResourceLoader + 'static) -> Nnef {
        self.resource_loaders.push(Box::new(loader));
        self
    }

    pub fn enable_tract_core(&mut self) {
        self.registries.push(crate::ops::tract_core());
    }

    pub fn with_tract_core(mut self) -> Self {
        self.registries.push(crate::ops::tract_core());
        self
    }

    pub fn enable_tract_resource(&mut self) {
        self.registries.push(crate::ops::tract_resource());
    }

    pub fn with_tract_resource(mut self) -> Self {
        self.registries.push(crate::ops::tract_resource());
        self
    }

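    /// Look up a named model transform in the enabled registries.
    ///
    /// A hedged sketch: the `"some-transform"` spec string below is purely
    /// hypothetical, and actual names depend on the enabled registries.
    ///
    /// ```ignore
    /// let framework = tract_nnef::nnef().with_tract_core();
    /// if let Some(transform) = framework.get_transform("some-transform")? {
    ///     transform.transform(&mut model)?;
    /// }
    /// ```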
    pub fn get_transform(&self, spec: &str) -> TractResult<Option<Box<dyn ModelTransform>>> {
        for reg in &self.registries {
            if let Some(transform) = (reg.transforms)(spec)? {
                return Ok(Some(transform));
            }
        }
        Ok(None)
    }

    pub fn allow_extended_identifier_syntax(&mut self, allow_extended_identifier_syntax: bool) {
        self.allow_extended_identifier_syntax = allow_extended_identifier_syntax;
    }

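    /// Build a `TypedModel` from an NNEF proto model, extending the given
    /// template model. On failure, the partially built model is returned
    /// along with the error.
    ///
    /// Sketch (path and error handling are illustrative only):
    ///
    /// ```ignore
    /// use tract_nnef::prelude::*;
    ///
    /// let framework = tract_nnef::nnef();
    /// let proto = framework.proto_model_for_path("model.nnef.tar")?;
    /// let typed = framework
    ///     .translate(&proto, TypedModel::default())
    ///     .map_err(|(_partial, e)| e)?;
    /// ```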
    #[allow(clippy::result_large_err)]
    pub fn translate(
        &self,
        proto_model: &ProtoModel,
        template: TypedModel,
    ) -> Result<TypedModel, (TypedModel, TractError)> {
        debug!("Build TypedModel from NNEF proto model (translate)");
        ModelBuilder::new(self, proto_model, template).into_typed_model()
    }

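    /// Serialize a model as an uncompressed NNEF tar stream into any writer.
    ///
    /// A minimal sketch, assuming a previously built `model: TypedModel`
    /// (the file name is illustrative):
    ///
    /// ```ignore
    /// let file = std::fs::File::create("model.nnef.tar")?;
    /// tract_nnef::nnef().write(&model, file)?;
    /// ```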
    pub fn write(&self, model: &TypedModel, w: impl std::io::Write) -> TractResult<()> {
        self.write_to_tar(model, w)?;
        Ok(())
    }

    pub fn write_to_tar<W: std::io::Write>(&self, model: &TypedModel, w: W) -> TractResult<W> {
        let mut ar = tar::Builder::new(w);
        let timestamp =
            std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap();
        self._write_to_tar(model, &mut ar, false, timestamp)?;
        ar.into_inner().context("Finalizing tar")
    }

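    /// Same as [`Nnef::write_to_tar`], with extra control over the archive:
    /// nested submodels can be gzip-compressed, and `deterministic` swaps the
    /// current time for a fixed timestamp so repeated exports of the same
    /// model produce stable archives.
    ///
    /// Sketch (file name illustrative, `model` assumed to exist):
    ///
    /// ```ignore
    /// let file = std::fs::File::create("model.nnef.tar")?;
    /// tract_nnef::nnef().write_to_tar_with_config(&model, file, false, true)?;
    /// ```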
    pub fn write_to_tar_with_config<W: std::io::Write>(
        &self,
        model: &TypedModel,
        w: W,
        compress_nested_models: bool,
        deterministic: bool,
    ) -> TractResult<W> {
        let mut ar = tar::Builder::new(w);
        let timestamp = if deterministic {
            // 1 Jan 1980, MS-DOS epoch. Some tools have issues with 0 timestamps.
            std::time::Duration::from_secs(315532800)
        } else {
            std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap()
        };

        self._write_to_tar(model, &mut ar, compress_nested_models, timestamp)?;
        ar.into_inner().context("Finalizing tar")
    }

    fn _write_to_tar<W: std::io::Write>(
        &self,
        model: &TypedModel,
        ar: &mut Builder<W>,
        compress_nested_models: bool,
        timestamp: std::time::Duration,
    ) -> TractResult<()> {
        let proto_model =
            crate::ser::to_proto_model(self, model).context("Translating model to proto_model")?;

        // graph.nnef: the textual graph description.
        let mut graph_data = vec![];
        crate::ast::dump::Dumper::new(self, &mut graph_data)
            .document(&proto_model.doc)
            .context("Serializing graph.nnef")?;

        let mut header = tar::Header::new_gnu();
        header.set_path("graph.nnef").context("Setting graph.nnef path")?;
        header.set_size(graph_data.len() as u64);
        header.set_mode(0o644);
        header.set_mtime(timestamp.as_secs());
        header.set_cksum();
        ar.append(&header, &mut &*graph_data).context("Appending graph.nnef")?;

        // graph.quant: per-tensor quantization formats, if any.
        if let Some(mut quantization) = proto_model.quantization {
            let mut quant_data = vec![];

            let mut keys = quantization.keys().cloned().collect::<Vec<_>>();
            keys.sort();
            for name in keys {
                let format = quantization.remove(&name).unwrap();
                write_quant_format(
                    &mut quant_data,
                    &name,
                    format,
                    self.allow_extended_identifier_syntax,
                )
                .context("Serializing graph.quant")?;
            }

            header.set_path("graph.quant").context("Setting graph.quant path")?;
            header.set_size(quant_data.len() as u64);
            header.set_mode(0o644);
            header.set_mtime(timestamp.as_secs());
            header.set_cksum();
            ar.append(&header, &mut &*quant_data).context("Appending graph.quant")?;
        }

        // One .dat file per tensor, in a stable (sorted) order.
        let mut labels = proto_model.tensors.keys().collect::<Vec<_>>();
        labels.sort();
        for label in labels {
            let t = proto_model.tensors.get(label).unwrap();
            let mut label = label.0.to_string() + ".dat";
            if label.starts_with('/') {
                // Tar entries should not be absolute paths.
                label.insert(0, '.');
            }
            let filename = std::path::Path::new(&label);
            let mut data = vec![];
            crate::tensors::write_tensor(&mut data, t)
                .with_context(|| format!("Serializing tensor {filename:?}: {t:?}"))?;

            let mut header = tar::Header::new_gnu();
            header.set_size(data.len() as u64);
            header.set_mode(0o644);
            header.set_mtime(timestamp.as_secs());
            header.set_cksum();

            ar.append_data(&mut header, filename, &mut &*data)
                .with_context(|| format!("Appending tensor {filename:?}"))?;
        }

        // Nested submodels are serialized as .nnef.tar (or .nnef.tgz) entries.
        let mut labels = proto_model.resources.keys().collect::<Vec<_>>();
        labels.sort();
        for label in labels {
            let resource = proto_model.resources.get(label).unwrap();
            if let Some(typed_model_resource) = resource.downcast_ref::<TypedModelResource>() {
                let mut submodel_data = vec![];
                let mut filename = std::path::PathBuf::from_str(label)?;
                let typed_model = &typed_model_resource.0;

                if compress_nested_models {
                    filename.set_extension("nnef.tgz");
                    let encoder = flate2::write::GzEncoder::new(
                        &mut submodel_data,
                        flate2::Compression::default(),
                    );
                    self.write(typed_model, encoder)?;
                } else {
                    filename.set_extension("nnef.tar");
                    self.write(typed_model, &mut submodel_data)?;
                }

                let mut header = tar::Header::new_gnu();
                header.set_size(submodel_data.len() as u64);
                header.set_mode(0o644);
                header.set_mtime(timestamp.as_secs());
                header.set_cksum();

                ar.append_data(&mut header, filename, &mut &*submodel_data)
                    .with_context(|| format!("Appending submodel {label:?}"))?;
            }
        }
        Ok(())
    }

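    /// Serialize a model as an on-disk NNEF directory: `graph.nnef`, an
    /// optional `graph.quant`, and one `.dat` file per tensor. Refuses to
    /// overwrite an existing path.
    ///
    /// Sketch (directory name illustrative, `model` assumed to exist):
    ///
    /// ```ignore
    /// tract_nnef::nnef().write_to_dir(&model, "exported.nnef")?;
    /// ```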
    pub fn write_to_dir(
        &self,
        model: &TypedModel,
        path: impl AsRef<std::path::Path>,
    ) -> TractResult<()> {
        let path = path.as_ref();
        if path.exists() {
            bail!("{:?} already exists. Won't overwrite.", path);
        }
        let proto_model = crate::ser::to_proto_model(self, model)?;
        std::fs::create_dir_all(path)?;
        let mut graph_nnef = std::fs::File::create(path.join("graph.nnef"))?;
        crate::ast::dump::Dumper::new(self, &mut graph_nnef).document(&proto_model.doc)?;

        if let Some(quantization) = proto_model.quantization {
            let mut graph_quant = std::fs::File::create(path.join("graph.quant"))?;
            for (name, format) in quantization.into_iter().sorted_by_key(|(x, _)| x.clone()) {
                write_quant_format(
                    &mut graph_quant,
                    &name,
                    format,
                    self.allow_extended_identifier_syntax,
                )?;
            }
        }

        for (label, t) in &proto_model.tensors {
            let label = label.0.to_string() + ".dat";
            let label = label.trim_start_matches('/');
            let parent = path.join(label).parent().unwrap().to_owned();
            std::fs::create_dir_all(&parent).with_context(|| format!("Creating dir {parent:?}"))?;
            let filename = path.join(label).to_owned();
            let mut file = std::fs::File::create(&filename)
                .with_context(|| format!("Creating file {filename:?}"))?;

            crate::tensors::write_tensor(&mut file, t)?;
        }
        Ok(())
    }
}

impl tract_core::prelude::Framework<ProtoModel, TypedModel> for Nnef {
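    /// Load a `TypedModel` from an NNEF directory, tar archive or gzipped
    /// tar archive.
    ///
    /// The usual loading sketch (model path and post-processing steps are
    /// illustrative):
    ///
    /// ```ignore
    /// use tract_nnef::prelude::*;
    ///
    /// let model = tract_nnef::nnef()
    ///     .with_tract_core()
    ///     .model_for_path("mobilenet.nnef.tar")?
    ///     .into_optimized()?
    ///     .into_runnable()?;
    /// ```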
    fn model_for_path(&self, p: impl AsRef<Path>) -> TractResult<TypedModel> {
        let proto = self.proto_model_for_path(p)?;
        self.model_for_proto_model(&proto)
    }

    fn proto_model_for_path(&self, path: impl AsRef<Path>) -> TractResult<ProtoModel> {
        let path = path.as_ref();
        if path.is_file() {
            let mut f = std::fs::File::open(path)?;
            return self.proto_model_for_read(&mut f);
        }

        let mut resources: HashMap<String, Arc<dyn Resource>> = Default::default();

        // `WalkDir::new` first yields the given path itself at depth 0; we only
        // want to load its descendants, at depth >= 1.
        for entry in walkdir::WalkDir::new(path).min_depth(1) {
            let entry =
                entry.map_err(|e| format_err!("Can not walk directory {:?}: {:?}", path, e))?;
            // We don't want to load sub-directories themselves either.
            if entry.path().is_dir() {
                continue;
            }
            let subpath = entry
                .path()
                .components()
                .skip(path.components().count())
                .collect::<std::path::PathBuf>();
            let mut stream = std::fs::File::open(entry.path())?;
            read_stream(&subpath, &mut stream, &mut resources, self)?;
        }
        proto_model_from_resources(resources)
    }

    fn proto_model_for_read(&self, reader: &mut dyn std::io::Read) -> TractResult<ProtoModel> {
        let mut resources: HashMap<String, Arc<dyn Resource>> = Default::default();

        // Sniff the first two bytes: 0x1f 0x8b is the gzip magic number. The bytes
        // are chained back in front of the reader so the archive sees the full stream.
        let mut buffer = vec![0u8; 2];
        reader.read_exact(&mut buffer)?;
        let header = std::io::Cursor::new(buffer.clone());
        let stream = header.chain(reader);
        let mut tar = if buffer == [0x1f, 0x8b] {
            #[cfg(feature = "flate2")]
            {
                let f = flate2::read::GzDecoder::new(stream);
                tar::Archive::new(Box::new(f) as Box<dyn Read>)
            }
            #[cfg(not(feature = "flate2"))]
            bail!("Cannot read gzip file without flate2 enabled.");
        } else {
            tar::Archive::new(Box::new(stream) as Box<dyn Read>)
        };
        for entry in tar.entries()? {
            let mut entry = entry?;
            let mut path = entry.path()?.to_path_buf();
            if path.starts_with("./") {
                path = path.strip_prefix("./")?.to_path_buf();
            }
            read_stream(&path, &mut entry, &mut resources, self)?;
        }
        proto_model_from_resources(resources)
    }

    fn model_for_proto_model_with_model_template(
        &self,
        proto: &ProtoModel,
        template: TypedModel,
    ) -> TractResult<TypedModel> {
        self.translate(proto, template).map_err(|e| e.1)
    }
}

fn proto_model_from_resources(
    resources: HashMap<String, Arc<dyn Resource>>,
) -> TractResult<ProtoModel> {
    // Iterate over resource IDs to detect submodels. A submodel ID has:
    // - two path components (ex: XXX/file)
    // - a graph.nnef file as its filename
    let sub_models = resources
        .keys()
        .filter_map(|id| {
            let id_components = id.split('/').collect::<Vec<_>>();
            if id_components.last() == Some(&crate::resource::GRAPH_NNEF_FILENAME)
                && id_components.len() == 2
            {
                id_components.first().map(|it| it.to_string())
            } else {
                None
            }
        })
        .collect::<Vec<_>>();

    // If there are submodels, build a TypedModel from each one's resources and
    // insert it back as a new TypedModelResource.
    let mut new_resources = if !sub_models.is_empty() {
        sub_models.into_iter().try_fold(resources, |r, it| -> TractResult<HashMap<_, _>> {
            let (submodel_resources, mut resources): (HashMap<String, Arc<dyn Resource>>, _) =
                r.into_iter().partition(|(k, _v)| k.starts_with(&it));
            let submodel_resources = submodel_resources
                .into_iter()
                .map(|(k, v)| (k.split('/').next_back().unwrap().to_string(), v))
                .collect::<HashMap<String, Arc<dyn Resource>>>();
            let typed_model = nnef()
                .model_for_proto_model(&proto_model_from_resources(submodel_resources).unwrap())?;
            resources.insert(it, Arc::new(TypedModelResource(typed_model)));
            Ok(resources)
        })?
    } else {
        resources
    };


    // NNEF document extraction
    let doc = new_resources
        .remove(crate::resource::GRAPH_NNEF_FILENAME)
        .with_context(|| {
            anyhow!("Resource {} was not found in the model", crate::resource::GRAPH_NNEF_FILENAME)
        })?
        .downcast_arc::<GraphNnef>()
        .map_err(|_| anyhow!("Error while downcasting NNEF document resource"))?;

    let doc = Arc::try_unwrap(doc)
        .map_err(|_| anyhow!("Error while extracting NNEF Document from shared reference. Only one reference to the document is expected"))?;

    // Collect all resources that can be downcast to Arc<Tensor>.
    let mut tensors: HashMap<_, _> = new_resources
        .iter()
        .filter_map(|(key, resource)| {
            Arc::clone(resource)
                .downcast_arc::<Tensor>()
                .ok()
                .map(|r| (Identifier::from(&**key), r))
        })
        .collect();

    // SafeTensors resources arrive as a list of named tensors: merge them in as well.
    for r in new_resources.values() {
        if let Some(safe_tensors) = r.downcast_ref::<Vec<(String, Arc<Tensor>)>>() {
            for (name, t) in safe_tensors.iter() {
                tensors.insert(Identifier::from(&**name), Arc::clone(t));
            }
        }
    }
    new_resources.retain(|_, r| !r.is::<Tensor>() && !r.is::<Vec<(String, Arc<Tensor>)>>());

    // Extract the quantization format resource, if present.
    let quantization = if let Some(q_r) =
        new_resources.remove(crate::resource::GRAPH_QUANT_FILENAME)
    {
        let Ok(q_r) = q_r.downcast_arc::<HashMap<String, QuantFormat>>() else {
            bail!("Error while downcasting quantization format resource")
        };
        let Ok(q_r) = Arc::try_unwrap(q_r) else {
            bail!("Error while extracting quantization format resource from shared reference. Only one reference to it is expected")
        };
        Some(q_r.into_iter().map(|(k, v)| (Identifier(k), v)).collect())
    } else {
        None
    };

    let doc = crate::liquid::process_file(&doc.0, &new_resources)?;
    let doc = crate::ast::parse::parse_document(&doc)?;

    let proto = ProtoModel { doc, tensors, quantization, resources: new_resources };
    proto.validate()?;
    Ok(proto)
}

fn read_stream<R: std::io::Read>(
    path: &Path,
    reader: &mut R,
    resources: &mut HashMap<String, Arc<dyn Resource>>,
    framework: &Nnef,
) -> TractResult<()> {
    // Ignore paths with any component starting with "." (because OSX's tar is weird).
    #[cfg(target_family = "unix")]
    if path.components().any(|name| name.as_os_str().as_bytes().first() == Some(&b'.')) {
        return Ok(());
    }
    let mut last_loader_name;
    for loader in framework.resource_loaders.iter() {
        last_loader_name = Some(loader.name());
        let loaded = loader.try_load(path, reader, framework).with_context(|| {
            anyhow!("Error while loading resource by {:?} at path {:?}", loader.name(), path)
        })?;
        if let Some((id, resource)) = loaded {
            ensure!(
                !resources.contains_key(&id),
                "Loader {:?} succeeded in loading {:?}, which was already loaded by {:?}",
                loader.name(),
                id,
                last_loader_name
            );
            resources.insert(id, resource);
            break;
        }
    }
    Ok(())
}