1use snafu::Snafu;
2use svod_dtype::DType;
3
4mod recurrent;
5pub use recurrent::{JitRecurrent, LstmState, RecurrentJit, StepTiming};
6
7#[derive(Clone, Debug)]
12pub struct InputSpec {
13 pub shape: Vec<usize>,
14 pub dtype: DType,
15}
16
17impl InputSpec {
18 pub fn new(shape: &[usize], dtype: DType) -> Self {
19 Self { shape: shape.to_vec(), dtype }
20 }
21
22 pub fn f32(shape: &[usize]) -> Self {
23 Self::new(shape, DType::Float32)
24 }
25
26 pub fn i32(shape: &[usize]) -> Self {
27 Self::new(shape, DType::Int32)
28 }
29
30 pub fn i64(shape: &[usize]) -> Self {
31 Self::new(shape, DType::Int64)
32 }
33}
34
35#[derive(Debug, Snafu)]
36#[snafu(visibility(pub))]
37pub enum JitError {
38 #[snafu(display("JIT not prepared: call prepare() first"))]
39 NotPrepared,
40
41 #[snafu(display("input buffer not found: {name}"))]
42 InputBufferNotFound { name: &'static str },
43
44 #[snafu(display("duplicate JIT input buffer: {name} aliases {duplicate_of} with {buffer_id:?}"))]
45 DuplicateInputBuffer { name: &'static str, duplicate_of: &'static str, buffer_id: svod_device::BufferId },
46
47 #[snafu(display("{source}"))]
50 Build { source: Box<dyn std::error::Error + Send + Sync> },
51
52 #[snafu(display("{source}"))]
53 Tensor {
54 #[snafu(source(from(svod_tensor::error::Error, Box::new)))]
55 source: Box<svod_tensor::error::Error>,
56 },
57
58 #[snafu(display("{source}"))]
59 Device {
60 #[snafu(source(from(svod_device::error::Error, Box::new)))]
61 source: Box<svod_device::error::Error>,
62 },
63
64 #[snafu(display(
69 "JIT output layout mismatch: declared {declared_head} head + {declared_state} state elements \
70 ({}), actual {actual} elements. Check that the `build` closure returns `cat([head, h, c], -1)` \
71 with the declared shapes.",
72 declared_head + declared_state
73 ))]
74 OutputLayoutMismatch { declared_head: usize, declared_state: usize, actual: usize },
75
76 #[snafu(display("{source}"))]
77 Runtime { source: svod_runtime::Error },
78}
79
80pub type Result<T> = std::result::Result<T, JitError>;