1use std::collections::BTreeMap;
2
3use tract_hir::internal::*;
4
/// In-memory representation of a parsed Kaldi nnet3 model, as returned by
/// `crate::parser::nnet3` and turned into an `InferenceModel` by [`Kaldi`].
#[derive(Clone, Debug)]
pub struct KaldiProtoModel {
    /// The config section: input declaration, node lines and output lines.
    pub config_lines: ConfigLines,
    /// Components keyed by name; `ComponentNode::component` is looked up here.
    pub components: HashMap<String, Component>,
    /// Added to every positive `Offset` found in the output descriptors when they
    /// are wired; the resulting crop amount must not be negative.
    pub adjust_final_offset: isize,
}

/// The config-lines section of an nnet3 model.
#[derive(Clone, Debug)]
pub struct ConfigLines {
    /// Name of the model input; the source node of the built model gets this name.
    pub input_name: String,
    /// Input feature dimension: the source fact is f32 with shape `(S, input_dim)`.
    pub input_dim: usize,
    /// Node lines as `(node name, definition)`; nodes are added to the model in
    /// this order.
    pub nodes: Vec<(String, NodeLine)>,
    /// Output lines; each one becomes an output of the built model.
    pub outputs: Vec<OutputLine>,
}

/// A non-output node line from the config section.
#[derive(Clone, Debug)]
pub enum NodeLine {
    /// Applies a named component to an input descriptor.
    Component(ComponentNode),
    /// Slices a range of the input along axis 1.
    DimRange(DimRangeNode),
}

/// An output line: names a model output and the descriptor that feeds it.
#[derive(Clone, Debug)]
pub struct OutputLine {
    /// Output name, used both as the output node name and as its outlet label.
    pub output_alias: String,
    /// Descriptor wired into the output node.
    pub descriptor: GeneralDescriptor,
}

/// An nnet3 input descriptor expression.
#[derive(Clone, Debug, PartialEq)]
pub enum GeneralDescriptor {
    /// Concatenation of several descriptors (wired as a `Concat` on axis 1).
    Append(Vec<GeneralDescriptor>),
    /// Optional input. Only `IfDefined(Offset(Name(_), _))` is currently wired
    /// (as a memory node).
    IfDefined(Box<GeneralDescriptor>),
    /// Reference to another node (or to the model input) by name.
    Name(String),
    /// A descriptor shifted by an offset (presumably in frames — TODO confirm).
    /// Outside `IfDefined`, only positive offsets can currently be wired.
    Offset(Box<GeneralDescriptor>, isize),
}
39
40impl GeneralDescriptor {
41 pub fn inputs(&self) -> TVec<&str> {
42 match self {
43 GeneralDescriptor::Append(ref gds) => gds.iter().fold(tvec!(), |mut acc, gd| {
44 gd.inputs().iter().for_each(|i| {
45 if !acc.contains(i) {
46 acc.push(i)
47 }
48 });
49 acc
50 }),
51 GeneralDescriptor::IfDefined(ref gd) => gd.inputs(),
52 GeneralDescriptor::Name(ref s) => tvec!(&**s),
53 GeneralDescriptor::Offset(ref gd, _) => gd.inputs(),
54 }
55 }
56
57 pub fn as_conv_shape_dilation(&self) -> Option<(usize, usize)> {
58 if let GeneralDescriptor::Name(_) = self {
59 return Some((1, 1));
60 }
61 if let GeneralDescriptor::Append(ref appendees) = self {
62 let mut offsets = vec![];
63 for app in appendees {
64 match app {
65 GeneralDescriptor::Name(_) => offsets.push(0),
66 GeneralDescriptor::Offset(_, offset) => offsets.push(*offset),
67 _ => return None,
68 }
69 }
70 let dilation = offsets[1] - offsets[0];
71 if offsets.windows(2).all(|pair| pair[1] - pair[0] == dilation) {
72 return Some((offsets.len(), dilation as usize));
73 }
74 }
75 None
76 }
77
78 fn wire(
79 &self,
80 inlet: InletId,
81 name: &str,
82 model: &mut InferenceModel,
83 deferred: &mut BTreeMap<InletId, String>,
84 adjust_final_offset: Option<isize>,
85 ) -> TractResult<()> {
86 use GeneralDescriptor::*;
87 match self {
88 Name(n) => {
89 deferred.insert(inlet, n.to_string());
90 return Ok(());
91 }
92 Append(appendees) => {
93 let name = format!("{name}.Append");
94 let id = model.add_node(
95 &*name,
96 expand(tract_hir::ops::array::Concat::new(1)),
97 tvec!(InferenceFact::default()),
98 )?;
99 model.add_edge(OutletId::new(id, 0), inlet)?;
100 for (ix, appendee) in appendees.iter().enumerate() {
101 let name = format!("{name}-{ix}");
102 appendee.wire(
103 InletId::new(id, ix),
104 &name,
105 model,
106 deferred,
107 adjust_final_offset,
108 )?;
109 }
110 return Ok(());
111 }
112 IfDefined(ref o) => {
113 if let Offset(n, o) = &**o {
114 if let Name(n) = &**n {
115 let name = format!("{name}.memory");
116 model.add_node(
117 &*name,
118 crate::ops::memory::Memory::new(n.to_string(), *o),
119 tvec!(InferenceFact::default()),
120 )?;
121 deferred.insert(inlet, name);
122 return Ok(());
123 }
124 }
125 }
126 Offset(ref n, o) if *o > 0 => {
127 let name = format!("{name}-Delay");
128 let crop = *o + adjust_final_offset.unwrap_or(0);
129 if crop < 0 {
130 bail!("Invalid offset adjustment (network as {}, adjustment is {}", o, crop)
131 }
132 let id = model.add_node(
133 &*name,
134 expand(tract_hir::ops::array::Crop::new(0, crop as usize, 0)),
135 tvec!(InferenceFact::default()),
136 )?;
137 model.add_edge(OutletId::new(id, 0), inlet)?;
138 n.wire(InletId::new(id, 0), &name, model, deferred, adjust_final_offset)?;
139 return Ok(());
140 }
141 _ => (),
142 }
143 bail!("Unhandled input descriptor: {:?}", self)
144 }
145}
146
/// A dim-range node: keeps `dim` features starting at `offset` from its input
/// (wired as a `Slice` on axis 1 over `offset..offset + dim`).
#[derive(Clone, Debug)]
pub struct DimRangeNode {
    /// Descriptor of the sliced input.
    pub input: GeneralDescriptor,
    /// First kept index along axis 1.
    pub offset: usize,
    /// Number of kept indices along axis 1.
    pub dim: usize,
}

/// A component node: applies the component named `component` to `input`.
#[derive(Clone, Debug)]
pub struct ComponentNode {
    /// Descriptor of the component's input.
    pub input: GeneralDescriptor,
    /// Key into `KaldiProtoModel::components`.
    pub component: String,
}

/// A Kaldi component definition.
#[derive(Clone, Debug, Default)]
pub struct Component {
    /// Component class name; selects the op builder (and the affine fast path).
    pub klass: String,
    /// Named tensor attributes of the component.
    pub attributes: HashMap<String, Arc<Tensor>>,
}

/// State shared with op builders while translating a model.
pub struct ParsingContext<'a> {
    /// The model being translated.
    pub proto_model: &'a KaldiProtoModel,
}
169
170type OpBuilder = fn(&ParsingContext, node: &str) -> TractResult<Box<dyn InferenceOp>>;
171
172#[derive(Clone, Default)]
173pub struct KaldiOpRegister(pub HashMap<String, OpBuilder>);
174
175impl KaldiOpRegister {
176 pub fn insert(&mut self, s: &'static str, builder: OpBuilder) {
177 self.0.insert(s.into(), builder);
178 }
179}
180
/// Kaldi nnet3 front-end for tract.
///
/// Components whose class has no builder in `op_register` are translated into
/// unimplemented placeholder ops.
#[derive(Clone, Default)]
pub struct Kaldi {
    /// Op builders, keyed by component class name.
    pub op_register: KaldiOpRegister,
}
185
186impl Framework<KaldiProtoModel, InferenceModel> for Kaldi {
187 fn proto_model_for_read(&self, r: &mut dyn std::io::Read) -> TractResult<KaldiProtoModel> {
188 use crate::parser;
189 let mut v = vec![];
190 r.read_to_end(&mut v)?;
191 parser::nnet3(&v)
192 }
193
194 fn model_for_proto_model_with_symbols(
195 &self,
196 proto_model: &KaldiProtoModel,
197 symbols: &SymbolTable,
198 ) -> TractResult<InferenceModel> {
199 let ctx = ParsingContext { proto_model };
200 let mut model =
201 InferenceModel { symbol_table: symbols.to_owned(), ..InferenceModel::default() };
202
203 let s = model.symbol_table.sym("S");
204 model.add_source(
205 proto_model.config_lines.input_name.clone(),
206 f32::fact(dims!(s, proto_model.config_lines.input_dim)).into(),
207 )?;
208 let mut inputs_to_wire: BTreeMap<InletId, String> = Default::default();
209 for (name, node) in &proto_model.config_lines.nodes {
210 match node {
211 NodeLine::Component(line) => {
212 let component = &proto_model.components[&line.component];
213 if crate::ops::AFFINE.contains(&&*component.klass)
214 && line.input.as_conv_shape_dilation().is_some()
215 {
216 let op = crate::ops::affine::affine_component(&ctx, name)?;
217 let id = model.add_node(
218 name.to_string(),
219 op,
220 tvec!(InferenceFact::default()),
221 )?;
222 inputs_to_wire
223 .insert(InletId::new(id, 0), line.input.inputs()[0].to_owned());
224 } else {
225 let op = match self.op_register.0.get(&*component.klass) {
226 Some(builder) => (builder)(&ctx, name)?,
227 None => Box::new(tract_hir::ops::unimpl::UnimplementedOp::new(
228 1,
229 &component.klass,
230 format!("{line:?}"),
231 )),
232 };
233 let id = model.add_node(
234 name.to_string(),
235 op,
236 tvec!(InferenceFact::default()),
237 )?;
238 line.input.wire(
239 InletId::new(id, 0),
240 name,
241 &mut model,
242 &mut inputs_to_wire,
243 None,
244 )?
245 }
246 }
247 NodeLine::DimRange(line) => {
248 let op =
249 tract_hir::ops::array::Slice::new(1, line.offset, line.offset + line.dim);
250 let id =
251 model.add_node(name.to_string(), op, tvec!(InferenceFact::default()))?;
252 line.input.wire(
253 InletId::new(id, 0),
254 name,
255 &mut model,
256 &mut inputs_to_wire,
257 None,
258 )?
259 }
260 }
261 }
262 let mut outputs = vec![];
263 for o in &proto_model.config_lines.outputs {
264 let output = model.add_node(
265 &*o.output_alias,
266 tract_hir::ops::identity::Identity::default(),
267 tvec!(InferenceFact::default()),
268 )?;
269 model.set_outlet_label(output.into(), o.output_alias.to_string())?;
270 o.descriptor.wire(
271 InletId::new(output, 0),
272 "output",
273 &mut model,
274 &mut inputs_to_wire,
275 Some(proto_model.adjust_final_offset),
276 )?;
277 outputs.push(OutletId::new(output, 0));
278 }
279 for (inlet, name) in inputs_to_wire {
280 let src = OutletId::new(model.node_by_name(&*name)?.id, 0);
281 model.add_edge(src, inlet)?;
282 }
283 model.set_output_outlets(&outputs)?;
284 Ok(model)
285 }
286}