Skip to main content

scirs2_interpolate/lie_group/
kernel.rs

1//! Geometric kernels for Lie group interpolation.
2//!
3//! Provides distance functions and kernel evaluations for:
4//! - S² (2-sphere): great-circle (geodesic) distance
5//! - SO(3) rotation group: quaternion-based geodesic distance
6//! - SE(3) rigid motions: product-metric distance
7//!
8//! The kernel functions map geodesic distances to kernel values suitable
9//! for radial basis function interpolation on curved spaces.
10
11use std::f64::consts::PI;
12
/// Kernel type for geometric (manifold) RBF interpolation.
///
/// Every variant is a function of the geodesic distance `d` between two
/// points on the manifold; evaluate it with [`eval_kernel`].
#[derive(Clone, Debug, PartialEq)]
pub enum GeometricKernel {
    /// Gaussian of the geodesic distance: `k(d) = exp(-d² / σ²)`.
    ///
    /// This has the shape of the Euclidean heat kernel evaluated at the
    /// geodesic distance.  It is *not* the exact heat kernel of a curved
    /// manifold (on S² that is the [`GeometricKernel::SphericalHarmonic`]
    /// series), and a Gaussian of geodesic distance is not guaranteed to be
    /// positive definite on S² or SO(3) for every σ.
    Heat {
        /// Bandwidth parameter σ > 0.
        sigma: f64,
    },

    /// Zonal kernel via truncated spherical harmonic expansion.
    ///
    /// `k(θ) = Σ_{l=0}^{L} (2l+1)/(4π) · exp(-l(l+1)σ²) · P_l(cos θ)`
    ///
    /// where `P_l` are Legendre polynomials and `L = bandwidth`.  As `L → ∞`
    /// this is the heat kernel of the unit sphere at diffusion time `σ²`.
    /// Every coefficient is positive, so the truncated sum is positive
    /// semi-definite on S².  It is not normalised (`k(0) ≠ 1` in general)
    /// and can be negative at large angles.  Only meaningful on S².
    SphericalHarmonic {
        /// Maximum angular frequency `L` (bandwidth).
        bandwidth: usize,
        /// Diffusion time / decay parameter σ > 0.
        sigma: f64,
    },

    /// Matérn kernel of the geodesic distance, with `r = d / ℓ`.
    ///
    /// Closed forms are used for ν = 0.5 (exponential), 1.5 and 2.5.  Other
    /// ν fall back to the squared-exponential `exp(-r²/2)`, which is the
    /// ν → ∞ limit.
    Matern {
        /// Smoothness parameter ν (typically 0.5, 1.5, or 2.5).
        nu: f64,
        /// Length-scale parameter ℓ > 0.
        length_scale: f64,
    },
}
46
/// Compute the geodesic (great-circle) distance on S² between two unit vectors.
///
/// `u` and `v` are normally unit vectors in ℝ³ (‖u‖ = ‖v‖ = 1), but only
/// their directions matter.  The result is the angle between them, in [0, π].
/// If either vector is zero the result is 0.
///
/// The angle is computed as `atan2(‖u × v‖, u · v)` rather than
/// `acos(u · v)`.  `acos` is ill-conditioned near ±1, so for nearly
/// coincident or nearly antipodal points the direct formula loses about half
/// of the significant digits.  For example, an angle of 1e-9 rad evaluates to
/// exactly 0 because `cos(1e-9)` rounds to 1.0.  The `atan2` form stays
/// accurate over the whole range and tolerates inputs that are slightly off
/// unit length.
///
/// # Examples
///
/// ```
/// use scirs2_interpolate::lie_group::kernel::sphere_geodesic_dist;
/// let north = [0.0_f64, 0.0, 1.0];
/// let south = [0.0_f64, 0.0, -1.0];
/// let d = sphere_geodesic_dist(&north, &south);
/// assert!((d - std::f64::consts::PI).abs() < 1e-12);
/// ```
pub fn sphere_geodesic_dist(u: &[f64; 3], v: &[f64; 3]) -> f64 {
    // ‖u × v‖ = ‖u‖‖v‖ sin θ and u · v = ‖u‖‖v‖ cos θ; atan2 of the pair
    // recovers θ ∈ [0, π] without the conditioning problem of acos.
    let cross = [
        u[1] * v[2] - u[2] * v[1],
        u[2] * v[0] - u[0] * v[2],
        u[0] * v[1] - u[1] * v[0],
    ];
    let sin_part = (cross[0] * cross[0] + cross[1] * cross[1] + cross[2] * cross[2]).sqrt();
    let cos_part = u[0] * v[0] + u[1] * v[1] + u[2] * v[2];
    sin_part.atan2(cos_part)
}
65
/// Compute the geodesic distance on SO(3) between two unit quaternions.
///
/// Quaternions are given as `[w, x, y, z]` (scalar-first convention).
/// Returns the rotation angle of the relative rotation between `q1` and `q2`,
/// in [0, π].  If either quaternion is zero the result is 0.
///
/// Mathematically this is `d(q1, q2) = 2 · arccos(|q1·q2|)`.  The absolute
/// value accounts for the double cover SO(3) ≅ S³/±1, under which `q` and `-q`
/// are the same rotation.  It is evaluated as
/// `2 · atan2(‖vec(q1* ⊗ q2)‖, |q1·q2|)` instead, because `arccos` is
/// ill-conditioned near 1.  With the direct formula, a rotation of 2e-9 rad
/// evaluates to exactly 0, and even `q` against `-q` need not give exactly 0.
/// The `atan2` form is accurate across the whole range.  It also depends only
/// on the directions of `q1` and `q2`, so slightly non-normalised quaternions
/// are handled gracefully.
///
/// # Examples
///
/// ```
/// use scirs2_interpolate::lie_group::kernel::so3_geodesic_dist;
/// let identity = [1.0_f64, 0.0, 0.0, 0.0];
/// let d = so3_geodesic_dist(&identity, &identity);
/// assert!(d.abs() < 1e-12);
/// ```
pub fn so3_geodesic_dist(q1: &[f64; 4], q2: &[f64; 4]) -> f64 {
    let [w1, x1, y1, z1] = *q1;
    let [w2, x2, y2, z2] = *q2;

    // Scalar part of q1* ⊗ q2, i.e. the 4-D dot product q1·q2.
    let dot = w1 * w2 + x1 * x2 + y1 * y2 + z1 * z2;

    // Vector part of q1* ⊗ q2: w1·v2 − w2·v1 − v1 × v2.
    let vx = w1 * x2 - w2 * x1 - (y1 * z2 - z1 * y2);
    let vy = w1 * y2 - w2 * y1 - (z1 * x2 - x1 * z2);
    let vz = w1 * z2 - w2 * z1 - (x1 * y2 - y1 * x2);
    let vec_norm = (vx * vx + vy * vy + vz * vz).sqrt();

    // |dot| folds q ~ -q together; both atan2 arguments are ≥ 0, so the
    // half-angle lies in [0, π/2] and the result in [0, π].
    2.0 * vec_norm.atan2(dot.abs())
}
89
90/// Compute a weighted product-metric distance on SE(3).
91///
92/// SE(3) = ℝ³ ⋊ SO(3).  The distance combines Euclidean translation distance
93/// and the SO(3) geodesic:
94/// `d((t1,q1),(t2,q2)) = ‖t1-t2‖ + w_rot · d_{SO(3)}(q1,q2)`.
95///
96/// # Arguments
97///
98/// * `t1`, `t2` — translation vectors in ℝ³.
99/// * `q1`, `q2` — unit quaternions `[w, x, y, z]` for the rotation components.
100/// * `w_rot` — non-negative weight balancing rotational vs translational distance.
101pub fn se3_geodesic_dist(
102    t1: &[f64; 3],
103    q1: &[f64; 4],
104    t2: &[f64; 3],
105    q2: &[f64; 4],
106    w_rot: f64,
107) -> f64 {
108    let dt = ((t1[0] - t2[0]).powi(2) + (t1[1] - t2[1]).powi(2) + (t1[2] - t2[2]).powi(2)).sqrt();
109    let dr = so3_geodesic_dist(q1, q2);
110    dt + w_rot * dr
111}
112
113/// Evaluate a [`GeometricKernel`] at a given geodesic distance.
114///
115/// Returns a non-negative kernel value.  The functions are all positive (semi-)
116/// definite on their respective manifolds when the parameters are in range.
117pub fn eval_kernel(dist: f64, kernel: &GeometricKernel) -> f64 {
118    match kernel {
119        GeometricKernel::Heat { sigma } => {
120            let s2 = sigma * sigma;
121            if s2 < f64::EPSILON {
122                return if dist < f64::EPSILON { 1.0 } else { 0.0 };
123            }
124            (-dist * dist / s2).exp()
125        }
126
127        GeometricKernel::Matern { nu, length_scale } => {
128            let r = dist / length_scale.max(f64::EPSILON);
129            if r < 1e-12 {
130                return 1.0;
131            }
132            // Match on 2*nu rounded to nearest integer to avoid floating-point equality.
133            let two_nu = (nu * 2.0).round() as u32;
134            match two_nu {
135                1 => (-r).exp(),                                             // ν = 0.5
136                3 => (1.0 + 3_f64.sqrt() * r) * (-(3_f64.sqrt() * r)).exp(), // ν = 1.5
137                5 => (1.0 + 5_f64.sqrt() * r + 5.0 / 3.0 * r * r) * (-(5_f64.sqrt() * r)).exp(), // ν = 2.5
138                _ => (-r * r / 2.0).exp(), // fallback: squared-exponential (smooth)
139            }
140        }
141
142        GeometricKernel::SphericalHarmonic { bandwidth, sigma } => {
143            // Truncated zonal harmonic expansion using Bonnet recurrence for P_l.
144            // k(θ) = Σ_{l=0}^{L} (2l+1)/(4π) · exp(-l(l+1)σ²) · P_l(cos θ)
145            let cos_theta = if dist < 1e-12 {
146                1.0
147            } else {
148                dist.cos().clamp(-1.0, 1.0)
149            };
150            // Legendre recurrence: P_0=1, P_1=x, P_l = ((2l-1)x·P_{l-1} - (l-1)P_{l-2}) / l
151            let mut p_prev = 1.0_f64; // P_{l-2}
152            let mut p_curr = cos_theta; // P_{l-1}
153            let mut sum = 0.0_f64;
154            for l in 0usize..=*bandwidth {
155                let pl = if l == 0 {
156                    1.0
157                } else if l == 1 {
158                    cos_theta
159                } else {
160                    let lf = l as f64;
161                    let next = ((2.0 * lf - 1.0) * cos_theta * p_curr - (lf - 1.0) * p_prev) / lf;
162                    p_prev = p_curr;
163                    p_curr = next;
164                    next
165                };
166                let weight = (2 * l + 1) as f64 / (4.0 * PI)
167                    * (-((l * (l + 1)) as f64) * sigma * sigma).exp();
168                sum += weight * pl;
169            }
170            sum
171        }
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sphere_geodesic_dist_poles() {
        let north = [0.0_f64, 0.0, 1.0];
        let south = [0.0_f64, 0.0, -1.0];
        let d = sphere_geodesic_dist(&north, &south);
        assert!(
            (d - PI).abs() < 1e-12,
            "antipodal distance should be π, got {d}"
        );
    }

    #[test]
    fn test_sphere_geodesic_dist_same_point() {
        let p = [1.0_f64, 0.0, 0.0];
        let d = sphere_geodesic_dist(&p, &p);
        assert!(d.abs() < 1e-12, "self-distance should be 0, got {d}");
    }

    #[test]
    fn test_sphere_geodesic_dist_quarter_circle() {
        let p1 = [1.0_f64, 0.0, 0.0];
        let p2 = [0.0_f64, 1.0, 0.0];
        let d = sphere_geodesic_dist(&p1, &p2);
        assert!(
            (d - PI / 2.0).abs() < 1e-12,
            "quarter circle should be π/2, got {d}"
        );
    }

    #[test]
    fn test_so3_geodesic_dist_identity() {
        let id = [1.0_f64, 0.0, 0.0, 0.0];
        let d = so3_geodesic_dist(&id, &id);
        assert!(d.abs() < 1e-12, "identity distance should be 0, got {d}");
    }

    #[test]
    fn test_so3_geodesic_dist_half_cover() {
        // q and -q represent the same rotation; distance should be 0.
        // Use exact 1/√2 to ensure the quaternion is precisely unit length.
        let s = 1.0_f64 / 2.0_f64.sqrt();
        let q = [s, s, 0.0_f64, 0.0];
        let neg_q = [-s, -s, 0.0_f64, 0.0];
        let d = so3_geodesic_dist(&q, &neg_q);
        // Loose tolerance: an arccos-based evaluation loses precision near 1.0.
        assert!(d.abs() < 1e-6, "q and -q should have distance 0, got {d}");
    }

    #[test]
    fn test_so3_geodesic_dist_half_turn() {
        // Identity vs. a rotation by π about the x axis.
        let id = [1.0_f64, 0.0, 0.0, 0.0];
        let half_turn = [0.0_f64, 1.0, 0.0, 0.0];
        let d = so3_geodesic_dist(&id, &half_turn);
        assert!((d - PI).abs() < 1e-12, "half turn should be π, got {d}");
    }

    #[test]
    fn test_eval_kernel_heat_at_zero() {
        let k = eval_kernel(0.0, &GeometricKernel::Heat { sigma: 1.0 });
        assert!(
            (k - 1.0).abs() < 1e-12,
            "heat kernel at 0 should be 1, got {k}"
        );
    }

    #[test]
    fn test_eval_kernel_heat_at_sigma() {
        let k = eval_kernel(1.0, &GeometricKernel::Heat { sigma: 1.0 });
        assert!(
            (k - (-1.0_f64).exp()).abs() < 1e-12,
            "heat kernel at d = σ should be e⁻¹, got {k}"
        );
    }

    #[test]
    fn test_eval_kernel_matern_at_zero() {
        for nu in [0.5, 1.5, 2.5, 3.0] {
            let k = eval_kernel(
                0.0,
                &GeometricKernel::Matern {
                    nu,
                    length_scale: 1.0,
                },
            );
            assert!(
                (k - 1.0).abs() < 1e-12,
                "Matern(nu={nu}) at 0 should be 1, got {k}"
            );
        }
    }

    #[test]
    fn test_eval_kernel_matern_closed_forms() {
        let at = |nu: f64| {
            eval_kernel(
                1.0,
                &GeometricKernel::Matern {
                    nu,
                    length_scale: 1.0,
                },
            )
        };
        let s3 = 3.0_f64.sqrt();
        let s5 = 5.0_f64.sqrt();
        assert!((at(0.5) - (-1.0_f64).exp()).abs() < 1e-12);
        assert!((at(1.5) - (1.0 + s3) * (-s3).exp()).abs() < 1e-12);
        assert!((at(2.5) - (1.0 + s5 + 5.0 / 3.0) * (-s5).exp()).abs() < 1e-12);
    }

    #[test]
    fn test_eval_kernel_spherical_harmonic_finite() {
        // The truncated series can be negative at large angles, so only
        // finiteness is checked across the range.
        let kernel = GeometricKernel::SphericalHarmonic {
            bandwidth: 10,
            sigma: 0.3,
        };
        for dist_deg in [0, 30, 60, 90, 120, 150, 180] {
            let d = (dist_deg as f64) * PI / 180.0;
            let k = eval_kernel(d, &kernel);
            assert!(
                k.is_finite(),
                "SphericalHarmonic kernel at {dist_deg}° should be finite, got {k}"
            );
        }
    }

    #[test]
    fn test_eval_kernel_spherical_harmonic_matches_legendre() {
        // bandwidth 2: P_0 = 1, P_1 = x, P_2 = (3x² − 1)/2.
        let sigma = 0.3_f64;
        let theta = 1.0_f64;
        let x = theta.cos();
        let s2 = sigma * sigma;
        let expected = (1.0
            + 3.0 * (-2.0 * s2).exp() * x
            + 5.0 * (-6.0 * s2).exp() * (3.0 * x * x - 1.0) / 2.0)
            / (4.0 * PI);
        let k = eval_kernel(
            theta,
            &GeometricKernel::SphericalHarmonic {
                bandwidth: 2,
                sigma,
            },
        );
        assert!(
            (k - expected).abs() < 1e-12,
            "SphericalHarmonic(L=2) mismatch: got {k}, expected {expected}"
        );
    }

    #[test]
    fn test_se3_geodesic_dist_same_pose() {
        let t = [1.0_f64, 2.0, 3.0];
        let q = [1.0_f64, 0.0, 0.0, 0.0];
        let d = se3_geodesic_dist(&t, &q, &t, &q, 1.0);
        assert!(d.abs() < 1e-12, "same pose distance should be 0, got {d}");
    }

    #[test]
    fn test_se3_geodesic_dist_components() {
        let id = [1.0_f64, 0.0, 0.0, 0.0];
        let origin = [0.0_f64; 3];
        let t = [3.0_f64, 4.0, 0.0];

        let d = se3_geodesic_dist(&origin, &id, &t, &id, 1.0);
        assert!((d - 5.0).abs() < 1e-12, "pure translation should be 5, got {d}");

        let half_turn = [0.0_f64, 1.0, 0.0, 0.0];
        let d = se3_geodesic_dist(&origin, &id, &t, &half_turn, 2.0);
        assert!(
            (d - (5.0 + 2.0 * PI)).abs() < 1e-12,
            "translation + weighted rotation mismatch, got {d}"
        );
    }
}