Skip to main content

zilla_muf/
complex_ops.rs

1// complex_ops.rs — complex-arithmetic helpers for S4-style models, which
2// diagonalize the state matrix into complex eigenvalues. Kept as a thin,
3// stable layer over `num-complex` so the SSM code has one canonical
4// spelling for each operation even if the backing implementation changes.
5use num_complex::Complex;
6use num_traits::Float;
7
8/// Complex exponential `e^z` for a complex argument.
9///
10/// Thin wrapper over `num_complex`'s own `exp`, exposed here so callers
11/// have a single, stable name for "exponentiate a complex eigenvalue"
12/// (used when discretizing or generating kernels for diagonalized state
13/// matrices). For `z = x + iy` this returns `e^x · (cos y + i·sin y)`.
14///
15/// # Example
16///
17/// ```
18/// use num_complex::Complex;
19/// use zilla_muf::complex_ops::complex_exp;
20/// use std::f64::consts::PI;
21/// // Euler's identity: e^(iπ) = -1
22/// let result = complex_exp(Complex::new(0.0_f64, PI));
23/// assert!((result.re + 1.0).abs() < 1e-10);
24/// assert!(result.im.abs() < 1e-10);
25/// ```
26pub fn complex_exp<T: Float>(z: Complex<T>) -> Complex<T> {
27	z.exp()
28}
29
30/// Element-wise exponential of a diagonal complex matrix's eigenvalues.
31///
32/// For a diagonal matrix `A = diag(λ_0, …, λ_{n-1})`, returns
33/// `[exp(λ_0), …, exp(λ_{n-1})]` — the diagonal of `exp(A)`.
34/// This is the S4D ZOH discretization step: given continuous-time
35/// eigenvalues `Λ`, the discrete-time transition is `diag(exp(Λ · Δt))`.
36///
37/// # Example
38///
39/// ```
40/// use num_complex::Complex;
41/// use zilla_muf::complex_ops::{complex_exp, diag_complex_matrix_exp};
42/// let lambdas = vec![Complex::new(0.0_f64, 1.0), Complex::new(-1.0_f64, 0.0)];
43/// let result = diag_complex_matrix_exp(&lambdas);
44/// for (r, &l) in result.iter().zip(lambdas.iter()) {
45///     let expected = complex_exp(l);
46///     assert!((r.re - expected.re).abs() < 1e-12);
47///     assert!((r.im - expected.im).abs() < 1e-12);
48/// }
49/// ```
50pub fn diag_complex_matrix_exp<T: Float>(lambdas: &[Complex<T>]) -> Vec<Complex<T>> {
51	lambdas.iter().map(|&z| z.exp()).collect()
52}
53
54/// Real output from a conjugate-pair-folded complex state.
55///
56/// In S4D and related models the state vector is stored as conjugate pairs
57/// `(h_i, conj(h_i))`. The real output projection then reduces to
58/// `2 · Re(Σ c_i · h_i)`, which avoids keeping the redundant conjugate
59/// half explicitly.
60///
61/// `states` and `c_coeffs` must have the same length; both hold one
62/// element per conjugate pair (the upper half).
63///
64/// # Example
65///
66/// ```
67/// use num_complex::Complex;
68/// use zilla_muf::complex_ops::conjugate_pair_output;
69/// let states = vec![Complex::new(1.0_f64, 0.5), Complex::new(0.0_f64, -1.0)];
70/// let c = vec![Complex::new(1.0_f64, 0.0), Complex::new(0.0_f64, 1.0)];
71/// let y = conjugate_pair_output(&states, &c);
72/// // brute-force: 2 * Re(c[0]*h[0] + c[1]*h[1])
73/// let dot = c[0] * states[0] + c[1] * states[1];
74/// assert!((y - 2.0 * dot.re).abs() < 1e-12);
75/// ```
76pub fn conjugate_pair_output<T: Float>(states: &[Complex<T>], c_coeffs: &[Complex<T>]) -> T {
77	assert_eq!(states.len(), c_coeffs.len(), "states and c_coeffs must be the same length");
78	let two = T::one() + T::one();
79	let re_sum = states
80		.iter()
81		.zip(c_coeffs.iter())
82		.fold(T::zero(), |acc, (&h, &c)| acc + (c * h).re);
83	two * re_sum
84}
85
#[cfg(test)]
mod tests {
	use super::*;
	use std::f64::consts::{FRAC_PI_2, PI};

	#[test]
	fn euler_identity() {
		// e^(iπ) = -1 + 0i
		let result = complex_exp(Complex::new(0.0_f64, PI));
		assert!((result.re + 1.0).abs() < 1e-10, "re={}", result.re);
		assert!(result.im.abs() < 1e-10, "im={}", result.im);
	}

	#[test]
	fn diag_matrix_exp_matches_elementwise() {
		// Consistency check: the slice version agrees with the scalar wrapper.
		let lambdas = vec![
			Complex::new(0.0_f64, PI),
			Complex::new(-1.0, 0.5),
			Complex::new(0.0, 0.0),
		];
		let result = diag_complex_matrix_exp(&lambdas);
		assert_eq!(result.len(), lambdas.len());
		for (&l, r) in lambdas.iter().zip(result.iter()) {
			let expected = complex_exp(l);
			assert!((r.re - expected.re).abs() < 1e-12);
			assert!((r.im - expected.im).abs() < 1e-12);
		}
	}

	#[test]
	fn diag_matrix_exp_known_values() {
		// Independent of `complex_exp`: exp(0) = 1, exp(iπ/2) = i, exp(-1) = 1/e.
		let lambdas = [
			Complex::new(0.0_f64, 0.0),
			Complex::new(0.0, FRAC_PI_2),
			Complex::new(-1.0, 0.0),
		];
		let expected = [
			Complex::new(1.0_f64, 0.0),
			Complex::new(0.0, 1.0),
			Complex::new((-1.0_f64).exp(), 0.0),
		];
		let result = diag_complex_matrix_exp(&lambdas);
		assert_eq!(result.len(), expected.len());
		for (r, e) in result.iter().zip(expected.iter()) {
			assert!((r.re - e.re).abs() < 1e-12, "re={}, expected={}", r.re, e.re);
			assert!((r.im - e.im).abs() < 1e-12, "im={}, expected={}", r.im, e.im);
		}
	}

	#[test]
	fn diag_matrix_exp_empty() {
		assert!(diag_complex_matrix_exp::<f64>(&[]).is_empty());
	}

	#[test]
	fn conjugate_pair_output_matches_brute_force() {
		let states = vec![
			Complex::new(1.0_f64, 0.5),
			Complex::new(-0.3, 0.7),
		];
		let c = vec![
			Complex::new(2.0_f64, -1.0),
			Complex::new(0.5, 0.5),
		];
		let y = conjugate_pair_output(&states, &c);
		let dot = c[0] * states[0] + c[1] * states[1];
		assert!((y - 2.0 * dot.re).abs() < 1e-12, "y={y}, expected={}", 2.0 * dot.re);
	}

	#[test]
	fn conjugate_pair_output_hand_computed() {
		// Re((2 - i)(1 + 0.5i)) = 2.5 and Re((0.5 + 0.5i)(-0.3 + 0.7i)) = -0.5,
		// so the output is 2 · (2.5 - 0.5) = 4.
		let states = [Complex::new(1.0_f64, 0.5), Complex::new(-0.3, 0.7)];
		let c = [Complex::new(2.0_f64, -1.0), Complex::new(0.5, 0.5)];
		let y = conjugate_pair_output(&states, &c);
		assert!((y - 4.0).abs() < 1e-12, "y={y}");
	}

	#[test]
	fn conjugate_pair_output_empty_is_zero() {
		// No pairs: the sum is empty, so the output is 2 · 0.
		let y = conjugate_pair_output::<f64>(&[], &[]);
		assert!(y.abs() < f64::EPSILON, "y={y}");
	}

	#[test]
	#[should_panic(expected = "states and c_coeffs must be the same length")]
	fn conjugate_pair_output_length_mismatch_panics() {
		let states = [Complex::new(1.0_f64, 0.0)];
		let _ = conjugate_pair_output::<f64>(&states, &[]);
	}
}