1use std::path::Path;
2
3use crate::ast::{Document, QuantFormat};
4use crate::internal::*;
5use tract_core::downcast_rs::{impl_downcast, DowncastSync};
6
/// Archive entry name recognized by `GraphNnefLoader` (the textual NNEF graph).
pub const GRAPH_NNEF_FILENAME: &'static str = "graph.nnef";
/// Archive entry name recognized by `GraphQuantLoader` (quantization annotations).
pub const GRAPH_QUANT_FILENAME: &'static str = "graph.quant";
9
10pub fn resource_path_to_id(path: impl AsRef<Path>) -> TractResult<String> {
11 let mut path = path.as_ref().to_path_buf();
12 path.set_extension("");
13 path.to_str()
14 .ok_or_else(|| format_err!("Badly encoded filename for path: {:?}", path))
15 .map(|s| s.to_string())
16}
17
18pub trait Resource: DowncastSync + std::fmt::Debug + Send + Sync {
19 fn get(&self, _key: &str) -> TractResult<Value> {
21 bail!("No key access supported by this resource");
22 }
23}
24
25impl_downcast!(sync Resource);
26
27pub trait ResourceLoader: Send + Sync {
28 fn name(&self) -> Cow<str>;
30 fn try_load(
33 &self,
34 path: &Path,
35 reader: &mut dyn std::io::Read,
36 framework: &Nnef,
37 ) -> TractResult<Option<(String, Arc<dyn Resource>)>>;
38
39 fn into_boxed(self) -> Box<dyn ResourceLoader>
40 where
41 Self: Sized + 'static,
42 {
43 Box::new(self)
44 }
45}
46
47impl Resource for Document {}
48
49#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
50pub struct GraphNnefLoader;
51
52impl ResourceLoader for GraphNnefLoader {
53 fn name(&self) -> Cow<str> {
54 "GraphNnefLoader".into()
55 }
56
57 fn try_load(
58 &self,
59 path: &Path,
60 reader: &mut dyn std::io::Read,
61 _framework: &Nnef,
62 ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
63 if path.ends_with(GRAPH_NNEF_FILENAME) {
64 let mut text = String::new();
65 reader.read_to_string(&mut text)?;
66 let document = crate::ast::parse::parse_document(&text)?;
67 Ok(Some((path.to_str().unwrap().to_string(), Arc::new(document))))
68 } else {
69 Ok(None)
70 }
71 }
72}
73
74impl Resource for Tensor {}
75
76#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
77pub struct DatLoader;
78
79impl ResourceLoader for DatLoader {
80 fn name(&self) -> Cow<str> {
81 "DatLoader".into()
82 }
83
84 fn try_load(
85 &self,
86 path: &Path,
87 reader: &mut dyn std::io::Read,
88 _framework: &Nnef,
89 ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
90 if path.extension().map(|e| e == "dat").unwrap_or(false) {
91 let tensor = crate::tensors::read_tensor(reader)
92 .with_context(|| format!("Error while reading tensor {path:?}"))?;
93 Ok(Some((resource_path_to_id(path)?, Arc::new(tensor))))
94 } else {
95 Ok(None)
96 }
97 }
98}
99
100impl Resource for HashMap<String, QuantFormat> {}
101
102#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
103pub struct GraphQuantLoader;
104
105impl ResourceLoader for GraphQuantLoader {
106 fn name(&self) -> Cow<str> {
107 "GraphQuantLoader".into()
108 }
109
110 fn try_load(
111 &self,
112 path: &Path,
113 reader: &mut dyn std::io::Read,
114 _framework: &Nnef,
115 ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
116 if path.ends_with(GRAPH_QUANT_FILENAME) {
117 let mut text = String::new();
118 reader.read_to_string(&mut text)?;
119 let quant = crate::ast::quant::parse_quantization(&text)?;
120 let quant: HashMap<String, QuantFormat> =
121 quant.into_iter().map(|(k, v)| (k.0, v)).collect();
122 Ok(Some((path.to_str().unwrap().to_string(), Arc::new(quant))))
123 } else {
124 Ok(None)
125 }
126 }
127}
128
/// Loads `.nnef.tgz` / `.nnef.tar` entries as [`TypedModelResource`]s.
pub struct TypedModelLoader {
    /// When `true`, each loaded model goes through `into_optimized()` before being stored.
    pub optimized_model: bool,
}

impl TypedModelLoader {
    /// Builds a loader; `optimized_model` chooses whether loaded models are optimized.
    pub fn new(optimized_model: bool) -> Self {
        TypedModelLoader { optimized_model }
    }
}
137}
138
139impl ResourceLoader for TypedModelLoader {
140 fn name(&self) -> Cow<str> {
141 "TypedModelLoader".into()
142 }
143
144 fn try_load(
145 &self,
146 path: &Path,
147 reader: &mut dyn std::io::Read,
148 framework: &Nnef,
149 ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
150 const NNEF_TGZ: &str = ".nnef.tgz";
151 const NNEF_TAR: &str = ".nnef.tar";
152 let path_str = path.to_str().unwrap_or("");
153 if path_str.ends_with(NNEF_TGZ) || path_str.ends_with(NNEF_TAR) {
154 let model = if self.optimized_model {
155 framework.model_for_read(reader)?.into_optimized()?
156 } else {
157 framework.model_for_read(reader)?
158 };
159
160 let label = if path_str.ends_with(NNEF_TGZ) {
161 path
162 .to_str()
163 .ok_or_else(|| anyhow!("invalid model resource path"))?
164 .trim_end_matches(NNEF_TGZ)
165 } else {
166 path
167 .to_str()
168 .ok_or_else(|| anyhow!("invalid model resource path"))?
169 .trim_end_matches(NNEF_TAR)
170 };
171 Ok(Some((resource_path_to_id(label)?, Arc::new(TypedModelResource(model)))))
172 } else {
173 Ok(None)
174 }
175 }
176}
177
178#[derive(Debug, Clone)]
179pub struct TypedModelResource(pub TypedModel);
180
181impl Resource for TypedModelResource {}