// svod_model/jit/mod.rs

1use snafu::Snafu;
2use svod_dtype::DType;
3
4mod recurrent;
5pub use recurrent::{JitRecurrent, LstmState, RecurrentJit, StepTiming};
6
7/// Shape + dtype descriptor for a single JIT input. Used by
8/// `jit_wrapper!`-generated `prepare()` calls to allocate zero-initialized
9/// placeholder buffers internally — callers no longer construct fake
10/// `Tensor::zeros(..).realize()` placeholders.
11#[derive(Clone, Debug)]
12pub struct InputSpec {
13    pub shape: Vec<usize>,
14    pub dtype: DType,
15}
16
17impl InputSpec {
18    pub fn new(shape: &[usize], dtype: DType) -> Self {
19        Self { shape: shape.to_vec(), dtype }
20    }
21
22    pub fn f32(shape: &[usize]) -> Self {
23        Self::new(shape, DType::Float32)
24    }
25
26    pub fn i32(shape: &[usize]) -> Self {
27        Self::new(shape, DType::Int32)
28    }
29
30    pub fn i64(shape: &[usize]) -> Self {
31        Self::new(shape, DType::Int64)
32    }
33}
34
35#[derive(Debug, Snafu)]
36#[snafu(visibility(pub))]
37pub enum JitError {
38    #[snafu(display("JIT not prepared: call prepare() first"))]
39    NotPrepared,
40
41    #[snafu(display("input buffer not found: {name}"))]
42    InputBufferNotFound { name: &'static str },
43
44    #[snafu(display("duplicate JIT input buffer: {name} aliases {duplicate_of} with {buffer_id:?}"))]
45    DuplicateInputBuffer { name: &'static str, duplicate_of: &'static str, buffer_id: svod_device::BufferId },
46
47    /// Wraps the user-supplied error type returned by a `jit_wrapper!` build
48    /// closure. Genuine `Box<dyn>` because the closure's `E` is arbitrary.
49    #[snafu(display("{source}"))]
50    Build { source: Box<dyn std::error::Error + Send + Sync> },
51
52    #[snafu(display("{source}"))]
53    Tensor {
54        #[snafu(source(from(svod_tensor::error::Error, Box::new)))]
55        source: Box<svod_tensor::error::Error>,
56    },
57
58    #[snafu(display("{source}"))]
59    Device {
60        #[snafu(source(from(svod_device::error::Error, Box::new)))]
61        source: Box<svod_device::error::Error>,
62    },
63
64    /// `JitRecurrent::new` rejected a JIT whose output element count does not
65    /// match the declared `head_len + |h| + |c|`. Typically means the `build`
66    /// closure was changed and now emits a different layout than the wrapper
67    /// expects.
68    #[snafu(display(
69        "JIT output layout mismatch: declared {declared_head} head + {declared_state} state elements \
70         ({}), actual {actual} elements. Check that the `build` closure returns `cat([head, h, c], -1)` \
71         with the declared shapes.",
72        declared_head + declared_state
73    ))]
74    OutputLayoutMismatch { declared_head: usize, declared_state: usize, actual: usize },
75
76    #[snafu(display("{source}"))]
77    Runtime { source: svod_runtime::Error },
78}
79
80pub type Result<T> = std::result::Result<T, JitError>;