Skip to main content

tensorlogic_sklears_kernels/learned_composition/
mixture.rs

1//! Core [`LearnedMixtureKernel`] type.
2//!
3//! Implements the forward pass `K_mix = sum_i p_i * K_i` and the analytical
4//! gradient `dK_mix/dw_i = p_i * (K_i - K_mix)` with numerically stable
5//! softmax (max subtraction).
6
7use std::fmt;
8use std::sync::Arc;
9
10use crate::error::{KernelError, Result};
11use crate::types::Kernel;
12
13/// A differentiable mixture over a library of base kernels.
14///
15/// The mixture is parameterised by a vector of logits `w`. Weights
16/// `p = softmax(w)` are always strictly positive and sum to 1. The
17/// evaluation is
18///
19/// ```text
20/// K_mix(x, y) = sum_i p_i * K_i(x, y).
21/// ```
22///
23/// Logits are unconstrained real numbers; the softmax parameterisation
24/// guarantees a valid convex combination on the simplex, which keeps the
25/// mixture positive semi-definite when every base kernel is PSD.
26#[derive(Clone)]
27pub struct LearnedMixtureKernel {
28    base_kernels: Vec<Arc<dyn Kernel>>,
29    logits: Vec<f64>,
30}
31
32impl LearnedMixtureKernel {
33    /// Build a mixture from a non-empty library and matching logits.
34    ///
35    /// Errors when the library is empty, the vectors disagree in length,
36    /// or any logit is non-finite.
37    pub fn new(base_kernels: Vec<Arc<dyn Kernel>>, logits: Vec<f64>) -> Result<Self> {
38        if base_kernels.is_empty() {
39            return Err(KernelError::InvalidParameter {
40                parameter: "base_kernels".to_string(),
41                value: "[]".to_string(),
42                reason: "learned mixture requires at least one base kernel".to_string(),
43            });
44        }
45        if base_kernels.len() != logits.len() {
46            return Err(KernelError::DimensionMismatch {
47                expected: vec![base_kernels.len()],
48                got: vec![logits.len()],
49                context: "LearnedMixtureKernel logits length".to_string(),
50            });
51        }
52        for (i, &w) in logits.iter().enumerate() {
53            if !w.is_finite() {
54                return Err(KernelError::InvalidParameter {
55                    parameter: format!("logits[{}]", i),
56                    value: w.to_string(),
57                    reason: "logits must be finite".to_string(),
58                });
59            }
60        }
61        Ok(Self {
62            base_kernels,
63            logits,
64        })
65    }
66
67    /// Build a mixture with uniform logits (all zeros → equal weights).
68    pub fn uniform(base_kernels: Vec<Arc<dyn Kernel>>) -> Result<Self> {
69        let n = base_kernels.len();
70        Self::new(base_kernels, vec![0.0; n])
71    }
72
73    /// Number of base kernels in the library.
74    pub fn num_kernels(&self) -> usize {
75        self.base_kernels.len()
76    }
77
78    /// Immutable view of the raw logits.
79    pub fn logits(&self) -> &[f64] {
80        &self.logits
81    }
82
83    /// Softmax weights `p_i = softmax(w)_i`. Always strictly positive,
84    /// always sums to 1 in exact arithmetic.
85    pub fn weights(&self) -> Vec<f64> {
86        softmax(&self.logits)
87    }
88
89    /// Replace the logits. The new vector must match `num_kernels()` and
90    /// every element must be finite.
91    pub fn set_logits(&mut self, new_logits: Vec<f64>) -> Result<()> {
92        if new_logits.len() != self.base_kernels.len() {
93            return Err(KernelError::DimensionMismatch {
94                expected: vec![self.base_kernels.len()],
95                got: vec![new_logits.len()],
96                context: "LearnedMixtureKernel::set_logits".to_string(),
97            });
98        }
99        for (i, &w) in new_logits.iter().enumerate() {
100            if !w.is_finite() {
101                return Err(KernelError::InvalidParameter {
102                    parameter: format!("logits[{}]", i),
103                    value: w.to_string(),
104                    reason: "logits must be finite".to_string(),
105                });
106            }
107        }
108        self.logits = new_logits;
109        Ok(())
110    }
111
112    /// Apply a raw gradient update `w_i <- w_i - lr * g_i` in place.
113    /// Used by [`crate::learned_composition::TrainableKernelMixture`].
114    pub fn apply_gradient_step(&mut self, gradient: &[f64], learning_rate: f64) -> Result<()> {
115        if gradient.len() != self.logits.len() {
116            return Err(KernelError::DimensionMismatch {
117                expected: vec![self.logits.len()],
118                got: vec![gradient.len()],
119                context: "LearnedMixtureKernel::apply_gradient_step".to_string(),
120            });
121        }
122        if !learning_rate.is_finite() {
123            return Err(KernelError::InvalidParameter {
124                parameter: "learning_rate".to_string(),
125                value: learning_rate.to_string(),
126                reason: "must be finite".to_string(),
127            });
128        }
129        for (w, &g) in self.logits.iter_mut().zip(gradient.iter()) {
130            if !g.is_finite() {
131                return Err(KernelError::InvalidParameter {
132                    parameter: "gradient".to_string(),
133                    value: g.to_string(),
134                    reason: "gradient entries must be finite".to_string(),
135                });
136            }
137            *w -= learning_rate * g;
138        }
139        Ok(())
140    }
141
142    /// Compute base-kernel values `[K_1(x,y), ..., K_n(x,y)]`.
143    fn per_kernel_values(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
144        let mut values = Vec::with_capacity(self.base_kernels.len());
145        for kernel in &self.base_kernels {
146            values.push(kernel.compute(x, y)?);
147        }
148        Ok(values)
149    }
150
151    /// Evaluate the mixture on a single input pair.
152    pub fn evaluate(&self, x: &[f64], y: &[f64]) -> Result<f64> {
153        let weights = self.weights();
154        let mut acc = 0.0;
155        for (kernel, &w) in self.base_kernels.iter().zip(weights.iter()) {
156            acc += w * kernel.compute(x, y)?;
157        }
158        Ok(acc)
159    }
160
161    /// Return the analytical gradient `dK_mix/dw_i = p_i * (K_i - K_mix)`.
162    ///
163    /// This form is numerically cleaner than routing through the full
164    /// softmax Jacobian (it stays bounded as `p_i` concentrates mass).
165    pub fn gradient_wrt_logits(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
166        let weights = self.weights();
167        let k_vals = self.per_kernel_values(x, y)?;
168        let k_mix: f64 = weights.iter().zip(k_vals.iter()).map(|(p, k)| p * k).sum();
169        Ok(weights
170            .iter()
171            .zip(k_vals.iter())
172            .map(|(p, k)| p * (k - k_mix))
173            .collect())
174    }
175
176    /// Return the forward value and the full gradient in one pass — the
177    /// preferred API for optimizer steps (avoids redundant evaluations).
178    pub fn evaluate_with_gradient(&self, x: &[f64], y: &[f64]) -> Result<(f64, Vec<f64>)> {
179        let weights = self.weights();
180        let k_vals = self.per_kernel_values(x, y)?;
181        let k_mix: f64 = weights.iter().zip(k_vals.iter()).map(|(p, k)| p * k).sum();
182        let grad: Vec<f64> = weights
183            .iter()
184            .zip(k_vals.iter())
185            .map(|(p, k)| p * (k - k_mix))
186            .collect();
187        Ok((k_mix, grad))
188    }
189
190    /// Compute a Gram matrix `G[i,j] = K_mix(xs[i], ys[j])` over two sets
191    /// of raw slices. Works for square `xs == ys` and rectangular cross-
192    /// evaluation alike.
193    pub fn compute_gram(&self, xs: &[&[f64]], ys: &[&[f64]]) -> Result<Vec<Vec<f64>>> {
194        let mut matrix = vec![vec![0.0; ys.len()]; xs.len()];
195        for (i, &xi) in xs.iter().enumerate() {
196            for (j, &yj) in ys.iter().enumerate() {
197                matrix[i][j] = self.evaluate(xi, yj)?;
198            }
199        }
200        Ok(matrix)
201    }
202}
203
204impl fmt::Debug for LearnedMixtureKernel {
205    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206        let names: Vec<&str> = self.base_kernels.iter().map(|k| k.name()).collect();
207        f.debug_struct("LearnedMixtureKernel")
208            .field("base_kernels", &names)
209            .field("logits", &self.logits)
210            .finish()
211    }
212}
213
214impl Kernel for LearnedMixtureKernel {
215    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
216        self.evaluate(x, y)
217    }
218
219    fn name(&self) -> &str {
220        "LearnedMixture"
221    }
222
223    fn is_psd(&self) -> bool {
224        // Softmax weights are strictly positive and sum to 1; the mixture
225        // is PSD whenever every base kernel is PSD.
226        self.base_kernels.iter().all(|k| k.is_psd())
227    }
228}
229
/// Numerically stable softmax: subtract the max before exponentiating.
///
/// Shifting every logit by the maximum makes the largest term `exp(0) = 1`.
/// For finite input the denominator therefore lies in `[1, n]` and can
/// neither overflow nor underflow. Individual weights may still underflow
/// to exactly `0.0` when a logit lies far (hundreds of units) below the
/// maximum.
///
/// Edge cases:
/// * empty input returns an empty vector;
/// * a `-inf` logit alongside finite ones receives weight `0.0`;
/// * if any logit is NaN or `+inf`, or every logit is `-inf`, the shifted
///   exponentials contain NaN and the uniform distribution is returned.
pub(crate) fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    // `f64::max` returns the other operand when one is NaN, so NaN logits
    // never become the maximum.
    let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let shifted: Vec<f64> = logits.iter().map(|&w| (w - max).exp()).collect();
    let denom: f64 = shifted.iter().sum();
    // Unreachable for finite logits (denom >= 1). This catches the NaN
    // produced by non-finite input and falls back to uniform weights.
    if denom <= 0.0 || !denom.is_finite() {
        let n = logits.len() as f64;
        return vec![1.0 / n; logits.len()];
    }
    shifted.iter().map(|&e| e / denom).collect()
}