tract_onnx/ops/rec/
common.rs

1use std::fmt::Debug;
2
3use crate::pb::*;
4use tract_hir::internal::*;
5use tract_hir::tract_core::dyn_clone::{clone_trait_object, DynClone};
6use tract_hir::tract_core::ops::scan::ScanInfo;
7
/// Cell-specific half of a recurrent operator, plugged into `CommonRec`.
///
/// `CommonRec` takes care of slicing the operator inputs per direction and of building the
/// scan around the cell; the implementor only wires the computation of one time step.
pub trait WireBody: Debug + DynClone + Send + Sync {
    /// Static name of the operator; `CommonRec` reports it as its own `Expansion` name.
    fn name(&self) -> &'static str;
    /// Wires the cell computation inside the scan `body`.
    ///
    /// When this is called, `CommonRec` has already added to `body` the sources `x_source`,
    /// `W`, `R`, `h_source` (plus `b`, `c_source` and `peepholes` when applicable) and the
    /// nodes `Xt`, `Ht_1` (and `Ct_1` when there is a `c` state).
    /// NOTE(review): the implementor is presumably expected to declare the body outputs;
    /// this is not visible from this file — confirm against the implementations.
    fn wire_body(&self, prefix: &str, body: &mut TypedModel) -> TractResult<()>;
    /// Returns `(w_mul, b_mul)`: the second dimension of W and R must be
    /// `w_mul * hidden_size`, and the second dimension of the bias `b_mul * hidden_size`.
    fn w_b_multipliers(&self) -> (usize, usize);
    /// Whether the cell carries a second recurrent state `c` in addition to `h`.
    fn have_extra_c_state(&self) -> bool;
}

// Lets `Box<dyn WireBody>` be cloned, which the `Clone` derive on `CommonRec` needs.
clone_trait_object!(WireBody);
16
/// Shared front-end for ONNX recurrent operators: it records which optional inputs and
/// outputs are present and delegates the per-time-step computation to a `WireBody`.
///
/// Every `optional_*` field is `None` when the corresponding tensor is absent; otherwise it
/// holds the index used to look the tensor up in the `inputs` / `outputs` slices.
#[derive(Debug, Clone)]
pub struct CommonRec {
    /// Slot of the bias input `B`.
    pub optional_bias_input: Option<usize>,
    /// Slot of the `sequence_lens` input.
    pub optional_sequence_lens_input: Option<usize>,
    /// Slot of the initial hidden state input.
    pub optional_initial_h_input: Option<usize>,
    /// Slot of the initial `c` state input; only read when the body has an extra `c` state.
    pub optional_initial_c_input: Option<usize>,
    /// Slot of the peephole weights input `P`.
    pub optional_p_input: Option<usize>,
    /// Slot of the full-sequence output `Y`.
    pub optional_y_output: Option<usize>,
    /// Slot of the last hidden state output `Y_h`.
    pub optional_y_h_output: Option<usize>,
    /// Slot of the last `c` state output `Y_c`.
    pub optional_y_c_output: Option<usize>,
    /// `true` when the node's `layout` attribute is 1, i.e. tensors carry the batch axis first.
    pub batch_first: bool,
    /// Cell implementation wired inside the scan body.
    pub body: Box<dyn WireBody>,
}
30
31impl CommonRec {
32    pub fn from_node_and_options(
33        pb: &NodeProto,
34        fixed_input: usize,
35        fixed_outputs: usize,
36        body: Box<dyn WireBody>,
37    ) -> TractResult<Self> {
38        let mut inputs = crate::model::optional_inputs(pb).skip(fixed_input);
39        let mut outputs = crate::model::optional_outputs(pb).skip(fixed_outputs);
40        Ok(Self {
41            optional_bias_input: inputs.next().unwrap(),
42            optional_sequence_lens_input: inputs.next().unwrap(),
43            optional_initial_h_input: inputs.next().unwrap(),
44            optional_initial_c_input: inputs.next().unwrap(),
45            optional_p_input: inputs.next().unwrap(),
46
47            optional_y_output: outputs.next().unwrap(),
48            optional_y_h_output: outputs.next().unwrap(),
49            optional_y_c_output: outputs.next().unwrap(),
50
51            batch_first: pb.get_attr_opt("layout")?.unwrap_or(0) == 1,
52            body,
53        })
54    }
55
56    #[allow(non_snake_case)]
57    fn wire_one_side(
58        &self,
59        prefix: &str,
60        target: &mut TypedModel,
61        inputs: &[OutletId],
62        dir: usize,
63    ) -> TractResult<TVec<OutletId>> {
64        use tract_hir::ops::{array, scan};
65
66        let x_fact = target.outlet_fact(inputs[0])?.clone();
67        let r_fact = target.outlet_fact(inputs[2])?.clone();
68
69        if let Some(seqlen) = self.optional_sequence_lens_input {
70            let Some(seqlen) = &target.outlet_fact(inputs[seqlen])?.konst else {
71                bail!("Non constant seq_len is not supported");
72            };
73            let Some(seqlen) = seqlen.as_uniform() else {
74                bail!("Non uniform seq_len is not supported");
75            };
76            let seqlen = seqlen.cast_to::<TDim>()?;
77            if seqlen.to_scalar::<TDim>()? != &x_fact.shape[self.batch_first as usize] {
78                bail!("seq_len only supported for trivial noop case");
79            };
80        }
81
82        let b_size = &x_fact.shape[1 - self.batch_first as usize];
83        let h_size = &r_fact.shape[2];
84
85        let chunk = if dir == 0 { 1 } else { -1 };
86
87        let mut body = TypedModel::default();
88        let mut outer_inputs = vec![];
89        let mut input_mapping = vec![];
90
91        macro_rules! target_wire {
92            ($name: ident = $op: expr, $($param: expr),*) => {
93                let $name = target.wire_node(
94                    format!("{}.{}", prefix, stringify!($name)),
95                    $op, [$($param),*].as_ref())?[0];
96            }
97        }
98
99        macro_rules! wire {
100            ($name: ident = $op: expr, $($param: expr),*) => {
101                #[allow(unused_variables)]
102                let $name = body.wire_node(
103                    stringify!($name),
104                    $op, [$($param),*].as_ref())?[0];
105            }
106        }
107
108        // X: onnx interface: [batch_size, seq_length, input_size]
109        // move batch first
110        let x_batch_first = if self.batch_first {
111            inputs[0]
112        } else {
113            target_wire!(x_batch_first = AxisOp::Move(1, 0), inputs[0]);
114            x_batch_first
115        };
116        // scan outer interface: idem
117        // scann inner interface: [chunk=1, batch_size, input_size]
118        // onnx inner interface: [batch_size, input_size]
119        outer_inputs.push(x_batch_first);
120        input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 1, chunk }));
121        let mut x_source_fact = target.outlet_fact(x_batch_first)?.without_value();
122        x_source_fact.shape.set(1, 1.to_dim());
123        let x_source = body.add_source("x_source", x_source_fact)?;
124        wire!(Xt = AxisOp::Rm(1), x_source);
125
126        // W: onnx interface: [num_directions, 3*hidden_size, input_size]
127        // scan interfaces: [3*hidden_size, input_size]
128        target_wire!(w_dir = array::Slice::new(0, dir, dir + 1), inputs[1]);
129        target_wire!(w = AxisOp::Rm(0), w_dir);
130        outer_inputs.push(w);
131        input_mapping.push(scan::InputMapping::Full);
132        body.add_source("W", target.outlet_fact(w)?.clone())?;
133
134        // R: onnx interface: [num_directions, 3*hidden_size, hidden_size]
135        // scan interfaces: [3*hidden_size, hidden_size]
136        target_wire!(r_dir = array::Slice::new(0, dir, dir + 1), inputs[2]);
137        target_wire!(r = AxisOp::Rm(0), r_dir);
138        outer_inputs.push(r);
139        input_mapping.push(scan::InputMapping::Full);
140        body.add_source("R", target.outlet_fact(r)?.clone())?;
141
142        // B: onnx interface: [num_directions, 6*hidden_size]
143        if let Some(slot) = self.optional_bias_input {
144            target_wire!(b_dir = array::Slice::new(0, dir, dir + 1), inputs[slot]);
145            outer_inputs.push(b_dir);
146            input_mapping.push(scan::InputMapping::Full);
147            let b = body.add_source("b", target.outlet_fact(b_dir)?.clone())?;
148            Some(b)
149        } else {
150            None
151        };
152
153        // initial h, optional: onnx: [num_directions, batch_size, hidden_size]
154        // scan outer: [batch_size, chunk=1, hidden_size]
155        // scan inner: [batch_size, chunk=1, hidden_size]
156        // onnx inner: [batch_size, hidden_size]
157        let initializer = if let Some(initial_h_input) = self.optional_initial_h_input {
158            let mut input = inputs[initial_h_input];
159            if self.batch_first {
160                target_wire!(h_batch_first = AxisOp::Move(1, 0), input);
161                input = h_batch_first;
162            };
163            target_wire!(h_dir = array::Slice::new(0, dir, dir + 1), input);
164            target_wire!(h = AxisOp::Rm(0), h_dir);
165            target_wire!(h_chunk_ = AxisOp::Add(0), h);
166            target_wire!(h_chunk = AxisOp::Move(1, 0), h_chunk_);
167            h_chunk
168        } else {
169            target.add_const(
170                format!("{prefix}.h0"),
171                tensor0(0.0f32)
172                    .broadcast_scalar_to_shape(&[
173                        b_size.to_usize().unwrap(),
174                        1,
175                        h_size.to_usize().unwrap(),
176                    ])?
177                    .into_arc_tensor(),
178            )?
179        };
180        outer_inputs.push(initializer);
181        input_mapping.push(scan::InputMapping::State);
182
183        let h_source = body.add_source(
184            "h_source",
185            x_fact.datum_type.fact(&[b_size.clone(), 1.to_dim(), h_size.clone()]),
186        )?;
187        wire!(Ht_1 = AxisOp::Rm(1), h_source);
188
189        if self.body.have_extra_c_state() {
190            let initializer = if let Some(initial_c_input) = self.optional_initial_c_input {
191                let mut input = inputs[initial_c_input];
192                if self.batch_first {
193                    target_wire!(c_batch_first = AxisOp::Move(1, 0), input);
194                    input = c_batch_first;
195                };
196                target_wire!(c_dir = array::Slice::new(0, dir, dir + 1), input);
197                target_wire!(c = AxisOp::Rm(0), c_dir);
198                target_wire!(c_chunk_ = AxisOp::Add(0), c);
199                target_wire!(c_chunk = AxisOp::Move(1, 0), c_chunk_);
200                c_chunk
201            } else {
202                target.add_const(
203                    format!("{prefix}.c0"),
204                    tensor0(0.0f32)
205                        .broadcast_scalar_to_shape(&[
206                            b_size.to_usize().unwrap(),
207                            1,
208                            h_size.to_usize().unwrap(),
209                        ])?
210                        .into_arc_tensor(),
211                )?
212            };
213            outer_inputs.push(initializer);
214            input_mapping.push(scan::InputMapping::State);
215            let c_source = body.add_source(
216                "c_source",
217                x_fact.datum_type.fact(&[b_size.clone(), 1.to_dim(), h_size.clone()]),
218            )?;
219            wire!(Ct_1 = AxisOp::Rm(1), c_source);
220        }
221
222        // P: onnx [num_directions, 3*hidde_size]
223        if let Some(slot) = self.optional_p_input {
224            target_wire!(p = array::Slice::new(0, dir, dir + 1), inputs[slot]);
225            outer_inputs.push(p);
226            input_mapping.push(scan::InputMapping::Full);
227            body.add_source("peepholes", target.outlet_fact(p)?.clone())?;
228        };
229
230        self.body.wire_body(prefix, &mut body).context("Wiring body")?;
231
232        let mut output_mapping = vec![scan::OutputMapping {
233            state: true,
234            full_dim_hint: None,
235            last_value_slot: self.optional_y_h_output,
236            scan: self.optional_y_output.map(|slot| (slot, ScanInfo { axis: 1, chunk })),
237        }];
238        if self.body.have_extra_c_state() {
239            output_mapping.push(scan::OutputMapping {
240                state: true,
241                full_dim_hint: None,
242                last_value_slot: self.optional_y_c_output,
243                scan: None,
244            });
245        }
246
247        let scan_outputs = target.wire_node(
248            prefix,
249            tract_core::ops::scan::Scan::new(body, input_mapping, output_mapping, 0)?,
250            &outer_inputs,
251        )?;
252
253        let mut result = tvec!();
254        if let Some(slot) = self.optional_y_output {
255            // scan: [batch_size, seq_len, hidden_size]
256            if self.batch_first {
257                // onnx: Y.shape = [batch_size, seq_length, num_directions, hidden_size]
258                target_wire!(y = AxisOp::Add(2), scan_outputs[slot]);
259                result.push(y);
260            } else {
261                // onnx: Y.shape = [seq_length, num_directions, batch_size, hidden_size]
262                target_wire!(y_batch_middle = AxisOp::Move(1, 0), scan_outputs[slot]);
263                target_wire!(y = AxisOp::Add(1), y_batch_middle);
264                result.push(y);
265            }
266        }
267        if let Some(slot) = self.optional_y_h_output {
268            if self.batch_first {
269                result.push(scan_outputs[slot]);
270            } else {
271                target_wire!(y_h_batch_middle = AxisOp::Move(1, 0), scan_outputs[slot]);
272                result.push(y_h_batch_middle);
273            }
274        }
275        if let Some(slot) = self.optional_y_c_output {
276            if self.batch_first {
277                result.push(scan_outputs[slot]);
278            } else {
279                target_wire!(y_c_batch_middle = AxisOp::Move(1, 0), scan_outputs[slot]);
280                result.push(y_c_batch_middle);
281            }
282        }
283
284        Ok(result)
285    }
286}
287
288impl Expansion for CommonRec {
289    fn name(&self) -> StaticName {
290        self.body.name().into()
291    }
292
293    fn info(&self) -> TractResult<Vec<String>> {
294        Ok(vec![format!("batch_first: {:?}", self.batch_first)])
295    }
296
    /// Declares `Validation::Rounding` as the validation mode of this operator.
    // NOTE(review): presumably means results are compared up to rounding error;
    // confirm against the definition of `Validation`.
    fn validation(&self) -> Validation {
        Validation::Rounding
    }
300
301    fn nboutputs(&self) -> TractResult<usize> {
302        Ok(self.optional_y_output.is_some() as usize
303            + self.optional_y_h_output.is_some() as usize
304            + self.optional_y_c_output.is_some() as usize)
305    }
306
307    fn rules<'r, 'p: 'r, 's: 'r>(
308        &'s self,
309        s: &mut Solver<'r>,
310        inputs: &'p [TensorProxy],
311        outputs: &'p [TensorProxy],
312    ) -> TractResult<()> {
313        let input_count = 3
314            + self.optional_bias_input.is_some() as usize
315            + self.optional_sequence_lens_input.is_some() as usize
316            + self.optional_initial_h_input.is_some() as usize
317            + self.optional_initial_c_input.is_some() as usize
318            + self.optional_p_input.is_some() as usize;
319        check_input_arity(inputs, input_count)?;
320        let output_count = self.optional_y_output.is_some() as usize
321            + self.optional_y_h_output.is_some() as usize
322            + self.optional_y_c_output.is_some() as usize;
323        check_output_arity(outputs, output_count)?;
324        s.equals(&inputs[0].datum_type, &inputs[1].datum_type)?;
325        s.equals(&inputs[0].datum_type, &inputs[2].datum_type)?;
326        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
327        s.equals(&inputs[0].rank, 3)?;
328        s.equals(&inputs[1].rank, 3)?;
329        s.equals(&inputs[2].rank, 3)?;
330
331        /* If 0
332         *      X.shape = [seq_length, batch_size, input_size],
333         *      Y.shape = [seq_length, num_directions, batch_size, hidden_size],
334         *      initial_h.shape = Y_h.shape = [num_directions, batch_size, hidden_size].
335         *  If 1,
336         *      X.shape = [batch_size, seq_length, input_size],
337         *      Y.shape = [batch_size, seq_length, num_directions, hidden_size],
338         *      initial_h.shape = Y_h.shape = [batch_size, num_directions, hidden_size].
339         */
340
341        let b = if self.batch_first { 0 } else { 1 };
342        let b_in_y = if self.batch_first { 0 } else { 2 };
343        let seq_len = if self.batch_first { 1 } else { 0 };
344        let dirs = if self.batch_first { 1 } else { 0 };
345        let dirs_in_y = if self.batch_first { 2 } else { 1 };
346
347        let (w_mul, b_mul) = self.body.w_b_multipliers();
348
349        s.equals(&inputs[1].shape[0], &inputs[2].shape[0])?; // num_directions
350        s.equals(&inputs[1].shape[1], (w_mul as i64) * inputs[2].shape[2].bex())?; // hidden_size
351        s.equals(&inputs[2].shape[1], (w_mul as i64) * inputs[2].shape[2].bex())?; // hidden_size
352        if let Some(bias) = self.optional_bias_input {
353            s.equals(&inputs[bias].datum_type, &inputs[0].datum_type)?;
354            s.equals(&inputs[bias].rank, 2)?;
355            s.equals(&inputs[bias].shape[0], &inputs[2].shape[0])?; // num_directions
356            s.equals(&inputs[bias].shape[1], (b_mul as i64) * inputs[2].shape[2].bex())?;
357            // 6 * hidden_size
358        }
359        if let Some(seq_len) = self.optional_sequence_lens_input {
360            s.equals(&inputs[seq_len].rank, 1)?;
361            s.equals(&inputs[seq_len].shape[0], &inputs[0].shape[b])?; // batch_size
362        }
363        if let Some(initial_h) = self.optional_initial_h_input {
364            s.equals(&inputs[initial_h].datum_type, &inputs[0].datum_type)?;
365            s.equals(&inputs[initial_h].rank, 3)?;
366            s.equals(&inputs[initial_h].shape[dirs], &inputs[1].shape[0])?; // num_directions
367            s.equals(&inputs[initial_h].shape[b], &inputs[0].shape[b])?; // batch_size
368            s.equals(&inputs[initial_h].shape[2], &inputs[2].shape[2])?; // hidden_size
369        }
370        if let Some(initial_c) = self.optional_initial_c_input {
371            s.equals(&inputs[initial_c].datum_type, &inputs[0].datum_type)?;
372            s.equals(&inputs[initial_c].rank, 3)?;
373            s.equals(&inputs[initial_c].shape[dirs], &inputs[1].shape[0])?; // num_directions
374            s.equals(&inputs[initial_c].shape[b], &inputs[0].shape[b])?; // batch_size
375            s.equals(&inputs[initial_c].shape[2], &inputs[2].shape[2])?; // hidden_size
376        }
377        if let Some(p) = self.optional_p_input {
378            s.equals(&inputs[p].datum_type, &inputs[0].datum_type)?;
379            s.equals(&inputs[p].rank, 2)?;
380            s.equals(&inputs[p].shape[0], &inputs[1].shape[0])?; // num_directions
381            s.equals(&inputs[p].shape[1], 3 * inputs[2].shape[2].bex())?; // hidden_size
382        }
383        if let Some(y) = self.optional_y_output {
384            s.equals(&outputs[y].datum_type, &inputs[0].datum_type)?;
385            s.equals(&outputs[y].rank, 4)?;
386            s.equals(&outputs[y].shape[seq_len], &inputs[0].shape[seq_len])?; // seq_lenght
387            s.equals(&outputs[y].shape[dirs_in_y], &inputs[1].shape[0])?; // num_directions
388            s.equals(&outputs[y].shape[b_in_y], &inputs[0].shape[b])?; // batch_size
389            s.equals(&outputs[y].shape[3], &inputs[2].shape[2])?; // hidden_size
390        }
391        if let Some(y_h) = self.optional_y_h_output {
392            s.equals(&outputs[y_h].datum_type, &inputs[0].datum_type)?;
393            s.equals(&outputs[y_h].rank, 3)?;
394            s.equals(&outputs[y_h].shape[dirs], &inputs[1].shape[0])?; // num_directions
395            s.equals(&outputs[y_h].shape[b], &inputs[0].shape[b])?; // batch_size
396            s.equals(&outputs[y_h].shape[2], &inputs[2].shape[2])?; // hidden_size
397        }
398        if let Some(y_c) = self.optional_y_c_output {
399            s.equals(&outputs[y_c].datum_type, &inputs[0].datum_type)?;
400            s.equals(&outputs[y_c].rank, 3)?;
401            s.equals(&outputs[y_c].shape[dirs], &inputs[1].shape[0])?; // num_directions
402            s.equals(&outputs[y_c].shape[b], &inputs[0].shape[b])?; // batch_size
403            s.equals(&outputs[y_c].shape[2], &inputs[2].shape[2])?; // hidden_size
404        }
405        Ok(())
406    }
407
408    fn wire(
409        &self,
410        prefix: &str,
411        target: &mut TypedModel,
412        inputs: &[OutletId],
413    ) -> TractResult<TVec<OutletId>> {
414        use tract_hir::tract_core::ops::array::TypedConcat;
415        let fore = self.wire_one_side(prefix, target, inputs, 0)?;
416        let w_fact = target.outlet_fact(inputs[1])?;
417        if w_fact.shape[0] == 2.into() {
418            let back = self.wire_one_side(&format!("{prefix}.back"), target, inputs, 1)?;
419            let mut outputs = tvec!(0.into(); self.nboutputs()?);
420            if let Some(ix) = self.optional_y_output {
421                outputs[ix] = target.wire_node(
422                    format!("{prefix}.merge_y_output"),
423                    TypedConcat::new(1),
424                    &[fore[ix], back[ix]],
425                )?[0];
426            }
427            if let Some(ix) = self.optional_y_h_output {
428                outputs[ix] = target.wire_node(
429                    format!("{prefix}.merge_y_h_output"),
430                    TypedConcat::new(0),
431                    &[fore[ix], back[ix]],
432                )?[0];
433            }
434            if let Some(ix) = self.optional_y_c_output {
435                outputs[ix] = target.wire_node(
436                    format!("{prefix}.merge_y_c_output"),
437                    TypedConcat::new(0),
438                    &[fore[ix], back[ix]],
439                )?[0];
440            }
441            Ok(outputs)
442        } else {
443            Ok(fore)
444        }
445    }
446}