1use std::fmt::Debug;
2
3use crate::pb::*;
4use tract_hir::internal::*;
5use tract_hir::tract_core::dyn_clone::{clone_trait_object, DynClone};
6use tract_hir::tract_core::ops::scan::ScanInfo;
7
/// Cell-specific part of a recurrent operator, plugged into [`CommonRec`].
///
/// `CommonRec` builds the scan loop, its outer inputs and the body sources, then delegates the
/// per-step computation to an implementor of this trait.
pub trait WireBody: Debug + DynClone + Send + Sync {
    /// Operator name; `CommonRec` reports it as its own `Expansion` name.
    fn name(&self) -> &'static str;

    /// Wires the per-step computation into `body`, the model `CommonRec` hands to the scan op.
    ///
    /// `CommonRec` calls this once per direction, after it has added the body sources
    /// (`x_source`, `W`, `R`, optionally `b`, `h_source`, `c_source`, `peepholes`) and the
    /// `Xt` / `Ht_1` / `Ct_1` nodes. `prefix` is the outer node name prefix for that direction.
    fn wire_body(&self, prefix: &str, body: &mut TypedModel) -> TractResult<()>;

    /// Returns `(w_mul, b_mul)`.
    ///
    /// `CommonRec::rules` requires the second dimension of W and R to be `w_mul` times the last
    /// dimension of R, and the second dimension of the bias to be `b_mul` times that same size.
    fn w_b_multipliers(&self) -> (usize, usize);

    /// Whether the cell carries a second state next to H.
    ///
    /// When true, `CommonRec` wires an additional state input (`c_source` / `Ct_1`) and an
    /// additional state output mapped to the optional Y_c output.
    fn have_extra_c_state(&self) -> bool;
}

// Lets `Box<dyn WireBody>` be cloned, which `CommonRec`'s `#[derive(Clone)]` relies on.
clone_trait_object!(WireBody);
16
/// Layout-independent description of a recurrent node: which optional inputs and outputs are
/// actually connected, the batch layout, and the cell body to wire at each step.
///
/// Every `optional_*_input` is the index of that tensor in the node's wired inputs, and every
/// `optional_*_output` the index in its wired outputs; `None` means the tensor is not connected.
#[derive(Debug, Clone)]
pub struct CommonRec {
    /// Slot of the bias input, sliced per direction and exposed to the body as source `b`.
    pub optional_bias_input: Option<usize>,
    /// Slot of the sequence lengths input; only accepted when it is a uniform constant equal to
    /// the sequence dimension of X.
    pub optional_sequence_lens_input: Option<usize>,
    /// Slot of the initial H state; a zero state is used when absent.
    pub optional_initial_h_input: Option<usize>,
    /// Slot of the initial C state; only read when the body has an extra C state.
    pub optional_initial_c_input: Option<usize>,
    /// Slot of the input exposed to the body as source `peepholes`.
    pub optional_p_input: Option<usize>,
    /// Slot of the Y output (rank 4: the scanned hidden state with a direction axis).
    pub optional_y_output: Option<usize>,
    /// Slot of the Y_h output (rank 3: last value of the H state).
    pub optional_y_h_output: Option<usize>,
    /// Slot of the Y_c output (rank 3: last value of the C state).
    pub optional_y_c_output: Option<usize>,
    /// True when the node's `layout` attribute is 1, i.e. batch is the leading axis of X.
    pub batch_first: bool,
    /// Cell-specific wiring of one time step.
    pub body: Box<dyn WireBody>,
}
30
31impl CommonRec {
32 pub fn from_node_and_options(
33 pb: &NodeProto,
34 fixed_input: usize,
35 fixed_outputs: usize,
36 body: Box<dyn WireBody>,
37 ) -> TractResult<Self> {
38 let mut inputs = crate::model::optional_inputs(pb).skip(fixed_input);
39 let mut outputs = crate::model::optional_outputs(pb).skip(fixed_outputs);
40 Ok(Self {
41 optional_bias_input: inputs.next().unwrap(),
42 optional_sequence_lens_input: inputs.next().unwrap(),
43 optional_initial_h_input: inputs.next().unwrap(),
44 optional_initial_c_input: inputs.next().unwrap(),
45 optional_p_input: inputs.next().unwrap(),
46
47 optional_y_output: outputs.next().unwrap(),
48 optional_y_h_output: outputs.next().unwrap(),
49 optional_y_c_output: outputs.next().unwrap(),
50
51 batch_first: pb.get_attr_opt("layout")?.unwrap_or(0) == 1,
52 body,
53 })
54 }
55
56 #[allow(non_snake_case)]
57 fn wire_one_side(
58 &self,
59 prefix: &str,
60 target: &mut TypedModel,
61 inputs: &[OutletId],
62 dir: usize,
63 ) -> TractResult<TVec<OutletId>> {
64 use tract_hir::ops::{array, scan};
65
66 let x_fact = target.outlet_fact(inputs[0])?.clone();
67 let r_fact = target.outlet_fact(inputs[2])?.clone();
68
69 if let Some(seqlen) = self.optional_sequence_lens_input {
70 let Some(seqlen) = &target.outlet_fact(inputs[seqlen])?.konst else {
71 bail!("Non constant seq_len is not supported");
72 };
73 let Some(seqlen) = seqlen.as_uniform() else {
74 bail!("Non uniform seq_len is not supported");
75 };
76 let seqlen = seqlen.cast_to::<TDim>()?;
77 if seqlen.to_scalar::<TDim>()? != &x_fact.shape[self.batch_first as usize] {
78 bail!("seq_len only supported for trivial noop case");
79 };
80 }
81
82 let b_size = &x_fact.shape[1 - self.batch_first as usize];
83 let h_size = &r_fact.shape[2];
84
85 let chunk = if dir == 0 { 1 } else { -1 };
86
87 let mut body = TypedModel::default();
88 let mut outer_inputs = vec![];
89 let mut input_mapping = vec![];
90
91 macro_rules! target_wire {
92 ($name: ident = $op: expr, $($param: expr),*) => {
93 let $name = target.wire_node(
94 format!("{}.{}", prefix, stringify!($name)),
95 $op, [$($param),*].as_ref())?[0];
96 }
97 }
98
99 macro_rules! wire {
100 ($name: ident = $op: expr, $($param: expr),*) => {
101 #[allow(unused_variables)]
102 let $name = body.wire_node(
103 stringify!($name),
104 $op, [$($param),*].as_ref())?[0];
105 }
106 }
107
108 let x_batch_first = if self.batch_first {
111 inputs[0]
112 } else {
113 target_wire!(x_batch_first = AxisOp::Move(1, 0), inputs[0]);
114 x_batch_first
115 };
116 outer_inputs.push(x_batch_first);
120 input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 1, chunk }));
121 let mut x_source_fact = target.outlet_fact(x_batch_first)?.without_value();
122 x_source_fact.shape.set(1, 1.to_dim());
123 let x_source = body.add_source("x_source", x_source_fact)?;
124 wire!(Xt = AxisOp::Rm(1), x_source);
125
126 target_wire!(w_dir = array::Slice::new(0, dir, dir + 1), inputs[1]);
129 target_wire!(w = AxisOp::Rm(0), w_dir);
130 outer_inputs.push(w);
131 input_mapping.push(scan::InputMapping::Full);
132 body.add_source("W", target.outlet_fact(w)?.clone())?;
133
134 target_wire!(r_dir = array::Slice::new(0, dir, dir + 1), inputs[2]);
137 target_wire!(r = AxisOp::Rm(0), r_dir);
138 outer_inputs.push(r);
139 input_mapping.push(scan::InputMapping::Full);
140 body.add_source("R", target.outlet_fact(r)?.clone())?;
141
142 if let Some(slot) = self.optional_bias_input {
144 target_wire!(b_dir = array::Slice::new(0, dir, dir + 1), inputs[slot]);
145 outer_inputs.push(b_dir);
146 input_mapping.push(scan::InputMapping::Full);
147 let b = body.add_source("b", target.outlet_fact(b_dir)?.clone())?;
148 Some(b)
149 } else {
150 None
151 };
152
153 let initializer = if let Some(initial_h_input) = self.optional_initial_h_input {
158 let mut input = inputs[initial_h_input];
159 if self.batch_first {
160 target_wire!(h_batch_first = AxisOp::Move(1, 0), input);
161 input = h_batch_first;
162 };
163 target_wire!(h_dir = array::Slice::new(0, dir, dir + 1), input);
164 target_wire!(h = AxisOp::Rm(0), h_dir);
165 target_wire!(h_chunk_ = AxisOp::Add(0), h);
166 target_wire!(h_chunk = AxisOp::Move(1, 0), h_chunk_);
167 h_chunk
168 } else {
169 target.add_const(
170 format!("{prefix}.h0"),
171 tensor0(0.0f32)
172 .broadcast_scalar_to_shape(&[
173 b_size.to_usize().unwrap(),
174 1,
175 h_size.to_usize().unwrap(),
176 ])?
177 .into_arc_tensor(),
178 )?
179 };
180 outer_inputs.push(initializer);
181 input_mapping.push(scan::InputMapping::State);
182
183 let h_source = body.add_source(
184 "h_source",
185 x_fact.datum_type.fact(&[b_size.clone(), 1.to_dim(), h_size.clone()]),
186 )?;
187 wire!(Ht_1 = AxisOp::Rm(1), h_source);
188
189 if self.body.have_extra_c_state() {
190 let initializer = if let Some(initial_c_input) = self.optional_initial_c_input {
191 let mut input = inputs[initial_c_input];
192 if self.batch_first {
193 target_wire!(c_batch_first = AxisOp::Move(1, 0), input);
194 input = c_batch_first;
195 };
196 target_wire!(c_dir = array::Slice::new(0, dir, dir + 1), input);
197 target_wire!(c = AxisOp::Rm(0), c_dir);
198 target_wire!(c_chunk_ = AxisOp::Add(0), c);
199 target_wire!(c_chunk = AxisOp::Move(1, 0), c_chunk_);
200 c_chunk
201 } else {
202 target.add_const(
203 format!("{prefix}.c0"),
204 tensor0(0.0f32)
205 .broadcast_scalar_to_shape(&[
206 b_size.to_usize().unwrap(),
207 1,
208 h_size.to_usize().unwrap(),
209 ])?
210 .into_arc_tensor(),
211 )?
212 };
213 outer_inputs.push(initializer);
214 input_mapping.push(scan::InputMapping::State);
215 let c_source = body.add_source(
216 "c_source",
217 x_fact.datum_type.fact(&[b_size.clone(), 1.to_dim(), h_size.clone()]),
218 )?;
219 wire!(Ct_1 = AxisOp::Rm(1), c_source);
220 }
221
222 if let Some(slot) = self.optional_p_input {
224 target_wire!(p = array::Slice::new(0, dir, dir + 1), inputs[slot]);
225 outer_inputs.push(p);
226 input_mapping.push(scan::InputMapping::Full);
227 body.add_source("peepholes", target.outlet_fact(p)?.clone())?;
228 };
229
230 self.body.wire_body(prefix, &mut body).context("Wiring body")?;
231
232 let mut output_mapping = vec![scan::OutputMapping {
233 state: true,
234 full_dim_hint: None,
235 last_value_slot: self.optional_y_h_output,
236 scan: self.optional_y_output.map(|slot| (slot, ScanInfo { axis: 1, chunk })),
237 }];
238 if self.body.have_extra_c_state() {
239 output_mapping.push(scan::OutputMapping {
240 state: true,
241 full_dim_hint: None,
242 last_value_slot: self.optional_y_c_output,
243 scan: None,
244 });
245 }
246
247 let scan_outputs = target.wire_node(
248 prefix,
249 tract_core::ops::scan::Scan::new(body, input_mapping, output_mapping, 0)?,
250 &outer_inputs,
251 )?;
252
253 let mut result = tvec!();
254 if let Some(slot) = self.optional_y_output {
255 if self.batch_first {
257 target_wire!(y = AxisOp::Add(2), scan_outputs[slot]);
259 result.push(y);
260 } else {
261 target_wire!(y_batch_middle = AxisOp::Move(1, 0), scan_outputs[slot]);
263 target_wire!(y = AxisOp::Add(1), y_batch_middle);
264 result.push(y);
265 }
266 }
267 if let Some(slot) = self.optional_y_h_output {
268 if self.batch_first {
269 result.push(scan_outputs[slot]);
270 } else {
271 target_wire!(y_h_batch_middle = AxisOp::Move(1, 0), scan_outputs[slot]);
272 result.push(y_h_batch_middle);
273 }
274 }
275 if let Some(slot) = self.optional_y_c_output {
276 if self.batch_first {
277 result.push(scan_outputs[slot]);
278 } else {
279 target_wire!(y_c_batch_middle = AxisOp::Move(1, 0), scan_outputs[slot]);
280 result.push(y_c_batch_middle);
281 }
282 }
283
284 Ok(result)
285 }
286}
287
288impl Expansion for CommonRec {
    /// Reports the wrapped cell body's name as the name of this expansion.
    fn name(&self) -> StaticName {
        self.body.name().into()
    }
292
293 fn info(&self) -> TractResult<Vec<String>> {
294 Ok(vec![format!("batch_first: {:?}", self.batch_first)])
295 }
296
    /// Declares `Validation::Rounding` for this operator.
    // NOTE(review): presumably this tells result checks to tolerate rounding differences —
    // confirm against the `Validation` documentation.
    fn validation(&self) -> Validation {
        Validation::Rounding
    }
300
301 fn nboutputs(&self) -> TractResult<usize> {
302 Ok(self.optional_y_output.is_some() as usize
303 + self.optional_y_h_output.is_some() as usize
304 + self.optional_y_c_output.is_some() as usize)
305 }
306
307 fn rules<'r, 'p: 'r, 's: 'r>(
308 &'s self,
309 s: &mut Solver<'r>,
310 inputs: &'p [TensorProxy],
311 outputs: &'p [TensorProxy],
312 ) -> TractResult<()> {
313 let input_count = 3
314 + self.optional_bias_input.is_some() as usize
315 + self.optional_sequence_lens_input.is_some() as usize
316 + self.optional_initial_h_input.is_some() as usize
317 + self.optional_initial_c_input.is_some() as usize
318 + self.optional_p_input.is_some() as usize;
319 check_input_arity(inputs, input_count)?;
320 let output_count = self.optional_y_output.is_some() as usize
321 + self.optional_y_h_output.is_some() as usize
322 + self.optional_y_c_output.is_some() as usize;
323 check_output_arity(outputs, output_count)?;
324 s.equals(&inputs[0].datum_type, &inputs[1].datum_type)?;
325 s.equals(&inputs[0].datum_type, &inputs[2].datum_type)?;
326 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
327 s.equals(&inputs[0].rank, 3)?;
328 s.equals(&inputs[1].rank, 3)?;
329 s.equals(&inputs[2].rank, 3)?;
330
331 let b = if self.batch_first { 0 } else { 1 };
342 let b_in_y = if self.batch_first { 0 } else { 2 };
343 let seq_len = if self.batch_first { 1 } else { 0 };
344 let dirs = if self.batch_first { 1 } else { 0 };
345 let dirs_in_y = if self.batch_first { 2 } else { 1 };
346
347 let (w_mul, b_mul) = self.body.w_b_multipliers();
348
349 s.equals(&inputs[1].shape[0], &inputs[2].shape[0])?; s.equals(&inputs[1].shape[1], (w_mul as i64) * inputs[2].shape[2].bex())?; s.equals(&inputs[2].shape[1], (w_mul as i64) * inputs[2].shape[2].bex())?; if let Some(bias) = self.optional_bias_input {
353 s.equals(&inputs[bias].datum_type, &inputs[0].datum_type)?;
354 s.equals(&inputs[bias].rank, 2)?;
355 s.equals(&inputs[bias].shape[0], &inputs[2].shape[0])?; s.equals(&inputs[bias].shape[1], (b_mul as i64) * inputs[2].shape[2].bex())?;
357 }
359 if let Some(seq_len) = self.optional_sequence_lens_input {
360 s.equals(&inputs[seq_len].rank, 1)?;
361 s.equals(&inputs[seq_len].shape[0], &inputs[0].shape[b])?; }
363 if let Some(initial_h) = self.optional_initial_h_input {
364 s.equals(&inputs[initial_h].datum_type, &inputs[0].datum_type)?;
365 s.equals(&inputs[initial_h].rank, 3)?;
366 s.equals(&inputs[initial_h].shape[dirs], &inputs[1].shape[0])?; s.equals(&inputs[initial_h].shape[b], &inputs[0].shape[b])?; s.equals(&inputs[initial_h].shape[2], &inputs[2].shape[2])?; }
370 if let Some(initial_c) = self.optional_initial_c_input {
371 s.equals(&inputs[initial_c].datum_type, &inputs[0].datum_type)?;
372 s.equals(&inputs[initial_c].rank, 3)?;
373 s.equals(&inputs[initial_c].shape[dirs], &inputs[1].shape[0])?; s.equals(&inputs[initial_c].shape[b], &inputs[0].shape[b])?; s.equals(&inputs[initial_c].shape[2], &inputs[2].shape[2])?; }
377 if let Some(p) = self.optional_p_input {
378 s.equals(&inputs[p].datum_type, &inputs[0].datum_type)?;
379 s.equals(&inputs[p].rank, 2)?;
380 s.equals(&inputs[p].shape[0], &inputs[1].shape[0])?; s.equals(&inputs[p].shape[1], 3 * inputs[2].shape[2].bex())?; }
383 if let Some(y) = self.optional_y_output {
384 s.equals(&outputs[y].datum_type, &inputs[0].datum_type)?;
385 s.equals(&outputs[y].rank, 4)?;
386 s.equals(&outputs[y].shape[seq_len], &inputs[0].shape[seq_len])?; s.equals(&outputs[y].shape[dirs_in_y], &inputs[1].shape[0])?; s.equals(&outputs[y].shape[b_in_y], &inputs[0].shape[b])?; s.equals(&outputs[y].shape[3], &inputs[2].shape[2])?; }
391 if let Some(y_h) = self.optional_y_h_output {
392 s.equals(&outputs[y_h].datum_type, &inputs[0].datum_type)?;
393 s.equals(&outputs[y_h].rank, 3)?;
394 s.equals(&outputs[y_h].shape[dirs], &inputs[1].shape[0])?; s.equals(&outputs[y_h].shape[b], &inputs[0].shape[b])?; s.equals(&outputs[y_h].shape[2], &inputs[2].shape[2])?; }
398 if let Some(y_c) = self.optional_y_c_output {
399 s.equals(&outputs[y_c].datum_type, &inputs[0].datum_type)?;
400 s.equals(&outputs[y_c].rank, 3)?;
401 s.equals(&outputs[y_c].shape[dirs], &inputs[1].shape[0])?; s.equals(&outputs[y_c].shape[b], &inputs[0].shape[b])?; s.equals(&outputs[y_c].shape[2], &inputs[2].shape[2])?; }
405 Ok(())
406 }
407
408 fn wire(
409 &self,
410 prefix: &str,
411 target: &mut TypedModel,
412 inputs: &[OutletId],
413 ) -> TractResult<TVec<OutletId>> {
414 use tract_hir::tract_core::ops::array::TypedConcat;
415 let fore = self.wire_one_side(prefix, target, inputs, 0)?;
416 let w_fact = target.outlet_fact(inputs[1])?;
417 if w_fact.shape[0] == 2.into() {
418 let back = self.wire_one_side(&format!("{prefix}.back"), target, inputs, 1)?;
419 let mut outputs = tvec!(0.into(); self.nboutputs()?);
420 if let Some(ix) = self.optional_y_output {
421 outputs[ix] = target.wire_node(
422 format!("{prefix}.merge_y_output"),
423 TypedConcat::new(1),
424 &[fore[ix], back[ix]],
425 )?[0];
426 }
427 if let Some(ix) = self.optional_y_h_output {
428 outputs[ix] = target.wire_node(
429 format!("{prefix}.merge_y_h_output"),
430 TypedConcat::new(0),
431 &[fore[ix], back[ix]],
432 )?[0];
433 }
434 if let Some(ix) = self.optional_y_c_output {
435 outputs[ix] = target.wire_node(
436 format!("{prefix}.merge_y_c_output"),
437 TypedConcat::new(0),
438 &[fore[ix], back[ix]],
439 )?[0];
440 }
441 Ok(outputs)
442 } else {
443 Ok(fore)
444 }
445 }
446}