svod_model/jit/
recurrent.rs1use std::time::{Duration, Instant};
2
3use snafu::ResultExt;
4use svod_device::Buffer;
5
6use crate::jit::{DeviceSnafu, JitError, Result};
7
/// Host-side recurrent state carried between JIT steps.
///
/// Both vectors are read by the JIT before a step and overwritten with the
/// step's output afterwards; their lengths are fixed by the JIT's output layout.
#[derive(Debug, Clone, PartialEq)]
pub struct LstmState {
    /// Flattened `h` (hidden) state values.
    pub h: Vec<f32>,
    /// Flattened `c` (cell) state values.
    pub c: Vec<f32>,
}
16
17impl LstmState {
18 pub fn zeros(total: usize) -> Self {
19 Self { h: vec![0.0f32; total], c: vec![0.0f32; total] }
20 }
21
22 pub fn reset(&mut self) {
23 self.h.fill(0.0);
24 self.c.fill(0.0);
25 }
26}
27
/// Wall-clock breakdown of one `JitRecurrent::step` call.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
pub struct StepTiming {
    /// Time spent in the caller's input-packing closure plus packing the state.
    pub pack: Duration,
    /// Time spent in `RecurrentJit::execute_step`.
    pub exec: Duration,
    /// Time spent reading the output buffer back into the host-side buffers.
    pub read: Duration,
}

impl StepTiming {
    /// Sum of the pack, exec and read phases.
    pub fn total(&self) -> Duration {
        self.pack + self.exec + self.read
    }
}
37
38pub trait RecurrentJit {
45 fn pack_state(&mut self, state: &LstmState) -> Result<()>;
48
49 fn execute_step(&mut self) -> Result<()>;
51
52 fn output_buffer(&self) -> Result<&Buffer>;
54}
55
56pub struct JitRecurrent<J: RecurrentJit> {
59 pub jit: J,
60 state: LstmState,
61 head_buf: Vec<f32>,
62 head_len: usize,
63 pub last_timing: StepTiming,
64}
65
66impl<J: RecurrentJit> JitRecurrent<J> {
67 pub fn new(jit: J, state: LstmState, head_len: usize) -> Result<Self> {
73 let declared_state = state.h.len() + state.c.len();
74 let actual = jit.output_buffer()?.size() / std::mem::size_of::<f32>();
75 if actual != head_len + declared_state {
76 return Err(JitError::OutputLayoutMismatch { declared_head: head_len, declared_state, actual });
77 }
78 Ok(Self { jit, state, head_buf: vec![0.0; head_len], head_len, last_timing: StepTiming::default() })
79 }
80
81 pub fn state(&self) -> &LstmState {
82 &self.state
83 }
84 pub fn state_mut(&mut self) -> &mut LstmState {
85 &mut self.state
86 }
87
88 pub fn head_len(&self) -> usize {
89 self.head_len
90 }
91
92 pub fn step<F>(&mut self, pack_inputs: F) -> Result<&[f32]>
98 where
99 F: FnOnce(&mut J) -> Result<()>,
100 {
101 let t0 = Instant::now();
102 pack_inputs(&mut self.jit)?;
103 self.jit.pack_state(&self.state)?;
104 let t1 = Instant::now();
105 self.jit.execute_step()?;
106 let t2 = Instant::now();
107 {
108 let out = self.jit.output_buffer()?;
109 let arr = out.as_array::<f32>().context(DeviceSnafu)?;
110 let flat = arr.as_slice().expect("contiguous JIT output");
111 let head_len = self.head_len;
112 let h_len = self.state.h.len();
113 self.head_buf.copy_from_slice(&flat[..head_len]);
114 self.state.h.copy_from_slice(&flat[head_len..head_len + h_len]);
115 self.state.c.copy_from_slice(&flat[head_len + h_len..]);
116 }
117 let t3 = Instant::now();
118 self.last_timing = StepTiming { pack: t1 - t0, exec: t2 - t1, read: t3 - t2 };
119 Ok(&self.head_buf)
120 }
121
122 pub fn reset(&mut self) {
125 self.state.reset();
126 }
127}