zilla_muf/scan/sequential.rs
1use num_traits::Float;
2
3/// Sequential (recurrent) scan — the correctness baseline that every
4/// other scan in the crate is tested against.
5///
6/// Evaluates the first-order linear recurrence
7/// `h_t = a_t * h_{t-1} + b_t`, with `h_{-1} = h0`
8/// left-to-right, returning the full state sequence `[h_0, ..., h_{n-1}]`.
9///
10/// - `a`: per-step decay / transition coefficients
11/// - `b`: per-step inputs (must be the same length as `a`)
12/// - `h0`: the initial state fed in before the first step
13///
14/// Setting `a_t = 1` everywhere collapses this into a plain prefix sum
15/// (running total) of `b` — the easiest case to eyeball, which is exactly
16/// what the test below checks.
17///
18/// Cost: O(n), one multiply-add per element. Inherently sequential: each
19/// step needs the previous `h`, so it can't be parallelized as-is.
20/// `chunked_scan` exists to break that dependency for large inputs; this
21/// function is the reference it must match. Generic over the float type
22/// `T` (f32 for speed, f64 for numerical testing).
23///
24/// # Example
25///
26/// ```
27/// use zilla_muf::scan::sequential_scan;
28/// // a = 1 everywhere → plain prefix sum of b
29/// let h = sequential_scan(&[1.0, 1.0, 1.0], &[1.0, 2.0, 3.0], 0.0);
30/// assert_eq!(h, vec![1.0, 3.0, 6.0]);
31/// ```
32pub fn sequential_scan<T: Float>(a: &[T], b: &[T], h0: T) -> Vec<T> {
33 // A length mismatch is a caller bug, not a recoverable runtime state —
34 // fail loudly rather than silently truncating to the shorter slice.
35 assert_eq!(a.len(), b.len(), "a and b must be the same length");
36
37 let mut h = h0; // running state, seeded with h0
38 let mut out = Vec::with_capacity(a.len()); // exact size is known up front
39 for i in 0..a.len() {
40 h = a[i] * h + b[i]; // one recurrence step: decay the state, add input
41 out.push(h); // every intermediate state is part of the output
42 }
43 out
44}
45
#[cfg(test)]
mod tests {
    use super::*;

    /// With `a_t = 1` the recurrence degenerates to a running sum of `b`.
    #[test]
    fn cumulative_sum_case() {
        // a = 1 everywhere -> plain running sum
        let a = [1.0, 1.0, 1.0, 1.0];
        let b = [1.0, 2.0, 3.0, 4.0];
        assert_eq!(sequential_scan(&a, &b, 0.0), vec![1.0, 3.0, 6.0, 10.0]);
    }

    /// Exercises a non-trivial coefficient and a non-zero seed, which the
    /// prefix-sum case cannot distinguish from a plain addition.
    #[test]
    fn decay_with_nonzero_initial_state() {
        let a = [0.5_f64, 0.5, 0.5];
        let b = [1.0_f64, 1.0, 1.0];
        // 0.5*4+1 = 3, 0.5*3+1 = 2.5, 0.5*2.5+1 = 2.25 — all exactly
        // representable, so exact equality is safe here.
        assert_eq!(sequential_scan(&a, &b, 4.0), vec![3.0, 2.5, 2.25]);
    }

    /// Zero steps means zero output states; the seed is not emitted.
    #[test]
    fn empty_input_yields_empty_output() {
        let empty: [f64; 0] = [];
        assert!(sequential_scan(&empty, &empty, 7.0).is_empty());
    }

    /// A length mismatch is treated as a caller bug and must panic rather
    /// than truncate.
    #[test]
    #[should_panic(expected = "a and b must be the same length")]
    fn mismatched_lengths_panic() {
        let a = [1.0_f64, 1.0];
        let b = [1.0_f64];
        let _ = sequential_scan(&a, &b, 0.0);
    }
}
57}