tract_onnx/ops/rec/
rnn.rs1use crate::model::ParsingContext;
2use crate::pb::*;
3use tract_hir::internal::*;
4use tract_hir::ops;
5use tract_hir::tract_core::ops::einsum::EinSum;
6
7use super::common::CommonRec;
8use super::common::WireBody;
9
10pub fn rnn(
11 _ctx: &ParsingContext,
12 pb: &NodeProto,
13) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
14 let rnn = RNN { fore: Box::new(ops::math::tanh()), back: Box::new(ops::math::tanh()) };
15 let common = CommonRec::from_node_and_options(pb, 3, 0, Box::new(rnn))?;
16 Ok((expand(common), vec![]))
17}
18
19#[derive(Debug, Clone, new)]
20pub struct RNN {
21 pub fore: Box<dyn TypedOp>,
22 pub back: Box<dyn TypedOp>,
23}
24
25impl WireBody for RNN {
26 fn name(&self) -> &'static str {
27 "RNN"
28 }
29
30 fn w_b_multipliers(&self) -> (usize, usize) {
31 (1, 2)
32 }
33
34 fn have_extra_c_state(&self) -> bool {
35 false
36 }
37
38 #[allow(non_snake_case)]
39 fn wire_body(&self, prefix: &str, body: &mut TypedModel) -> TractResult<()> {
40 use tract_hir::ops::{array, math};
41 macro_rules! wire {
42 ($name: ident = $op: expr, $($param: expr),*) => {
43 let $name = body.wire_node(
44 format!("{}.{}", prefix, stringify!($name)),
45 $op, [$($param),*].as_ref())?[0];
46 }
47 }
48
49 let Xt: OutletId = body.node_by_name("Xt").unwrap().id.into();
50 let W: OutletId = body.node_by_name("W").unwrap().id.into();
51 let R: OutletId = body.node_by_name("R").unwrap().id.into();
52 let Ht_1: OutletId = body.node_by_name("Ht_1").unwrap().id.into();
53 let b: Option<OutletId> = body.node_by_name("b").ok().map(|n| n.id.into());
54
55 let h_size = body.outlet_fact(R)?.shape[1].clone();
56
57 let bias = if let Some(b) = b {
58 wire!(Wbi = array::Slice::new(1, 0.to_dim() * &h_size, 1.to_dim() * &h_size), b);
59 wire!(Rbi = array::Slice::new(1, 1.to_dim() * &h_size, 2.to_dim() * &h_size), b);
60 wire!(bi = math::add(), Wbi, Rbi);
61 Some(bi)
62 } else {
63 None
64 };
65
66 let matmul_t = EinSum::new("mk,nk->mn".parse()?, f32::datum_type());
67
68 wire!(Xt_WiT = matmul_t.clone(), Xt, W);
70 wire!(Ht_1_RiT = matmul_t, Ht_1, R);
71
72 wire!(ht0 = math::add(), Xt_WiT, Ht_1_RiT);
73 let mut ht0 = ht0;
74 if let Some(bias) = bias {
75 wire!(ht_bias = math::add(), ht0, bias);
76 ht0 = ht_bias;
77 }
78 wire!(Ht = self.fore.clone(), ht0);
79
80 wire!(y_h = AxisOp::Add(1), Ht);
81 body.set_output_outlets(&[y_h])?;
82 Ok(())
83 }
84}