tract_nnef/
resource.rs

1use std::path::Path;
2
3use crate::ast::{Document, QuantFormat};
4use crate::internal::*;
5use tract_core::downcast_rs::{impl_downcast, DowncastSync};
6
7pub const GRAPH_NNEF_FILENAME: &str = "graph.nnef";
8pub const GRAPH_QUANT_FILENAME: &str = "graph.quant";
9
10pub fn resource_path_to_id(path: impl AsRef<Path>) -> TractResult<String> {
11    let mut path = path.as_ref().to_path_buf();
12    path.set_extension("");
13    path.to_str()
14        .ok_or_else(|| format_err!("Badly encoded filename for path: {:?}", path))
15        .map(|s| s.to_string())
16}
17
18pub trait Resource: DowncastSync + std::fmt::Debug + Send + Sync {
19    /// Get value for a given key.
20    fn get(&self, _key: &str) -> TractResult<Value> {
21        bail!("No key access supported by this resource");
22    }
23}
24
25impl_downcast!(sync Resource);
26
27pub trait ResourceLoader: Send + Sync {
28    /// Name of the resource loader.
29    fn name(&self) -> Cow<str>;
30    /// Try to load a resource give a path and its corresponding reader.
31    /// None is returned if the path is not accepted by this loader.
32    fn try_load(
33        &self,
34        path: &Path,
35        reader: &mut dyn std::io::Read,
36        framework: &Nnef,
37    ) -> TractResult<Option<(String, Arc<dyn Resource>)>>;
38
39    fn into_boxed(self) -> Box<dyn ResourceLoader>
40    where
41        Self: Sized + 'static,
42    {
43        Box::new(self)
44    }
45}
46
47impl Resource for Document {}
48
49#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
50pub struct GraphNnefLoader;
51
52impl ResourceLoader for GraphNnefLoader {
53    fn name(&self) -> Cow<str> {
54        "GraphNnefLoader".into()
55    }
56
57    fn try_load(
58        &self,
59        path: &Path,
60        reader: &mut dyn std::io::Read,
61        _framework: &Nnef,
62    ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
63        if path.ends_with(GRAPH_NNEF_FILENAME) {
64            let mut text = String::new();
65            reader.read_to_string(&mut text)?;
66            let document = crate::ast::parse::parse_document(&text)?;
67            Ok(Some((path.to_str().unwrap().to_string(), Arc::new(document))))
68        } else {
69            Ok(None)
70        }
71    }
72}
73
74impl Resource for Tensor {}
75
76#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
77pub struct DatLoader;
78
79impl ResourceLoader for DatLoader {
80    fn name(&self) -> Cow<str> {
81        "DatLoader".into()
82    }
83
84    fn try_load(
85        &self,
86        path: &Path,
87        reader: &mut dyn std::io::Read,
88        _framework: &Nnef,
89    ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
90        if path.extension().map(|e| e == "dat").unwrap_or(false) {
91            let tensor = crate::tensors::read_tensor(reader)
92                .with_context(|| format!("Error while reading tensor {path:?}"))?;
93            Ok(Some((resource_path_to_id(path)?, Arc::new(tensor))))
94        } else {
95            Ok(None)
96        }
97    }
98}
99
100impl Resource for HashMap<String, QuantFormat> {}
101
102#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
103pub struct GraphQuantLoader;
104
105impl ResourceLoader for GraphQuantLoader {
106    fn name(&self) -> Cow<str> {
107        "GraphQuantLoader".into()
108    }
109
110    fn try_load(
111        &self,
112        path: &Path,
113        reader: &mut dyn std::io::Read,
114        _framework: &Nnef,
115    ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
116        if path.ends_with(GRAPH_QUANT_FILENAME) {
117            let mut text = String::new();
118            reader.read_to_string(&mut text)?;
119            let quant = crate::ast::quant::parse_quantization(&text)?;
120            let quant: HashMap<String, QuantFormat> =
121                quant.into_iter().map(|(k, v)| (k.0, v)).collect();
122            Ok(Some((path.to_str().unwrap().to_string(), Arc::new(quant))))
123        } else {
124            Ok(None)
125        }
126    }
127}
128
129pub struct TypedModelLoader {
130    pub optimized_model: bool,
131}
132
133impl TypedModelLoader {
134    pub fn new(optimized_model: bool) -> Self {
135        Self { optimized_model }
136    }
137}
138
139impl ResourceLoader for TypedModelLoader {
140    fn name(&self) -> Cow<str> {
141        "TypedModelLoader".into()
142    }
143
144    fn try_load(
145        &self,
146        path: &Path,
147        reader: &mut dyn std::io::Read,
148        framework: &Nnef,
149    ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
150        const NNEF_TGZ: &str = ".nnef.tgz";
151        const NNEF_TAR: &str = ".nnef.tar";
152        let path_str = path.to_str().unwrap_or("");
153        if path_str.ends_with(NNEF_TGZ) || path_str.ends_with(NNEF_TAR) {
154            let model = if self.optimized_model {
155                framework.model_for_read(reader)?.into_optimized()?
156            } else {
157                framework.model_for_read(reader)?
158            };
159
160            let label = if path_str.ends_with(NNEF_TGZ) {
161                path
162                    .to_str()
163                    .ok_or_else(|| anyhow!("invalid model resource path"))?
164                    .trim_end_matches(NNEF_TGZ)
165            } else {
166                path
167                    .to_str()
168                    .ok_or_else(|| anyhow!("invalid model resource path"))?
169                    .trim_end_matches(NNEF_TAR)
170            };
171            Ok(Some((resource_path_to_id(label)?, Arc::new(TypedModelResource(model)))))
172        } else {
173            Ok(None)
174        }
175    }
176}
177
178#[derive(Debug, Clone)]
179pub struct TypedModelResource(pub TypedModel);
180
181impl Resource for TypedModelResource {}