tensorlogic_sklears_kernels/learned_composition/
mixture.rs1use std::fmt;
8use std::sync::Arc;
9
10use crate::error::{KernelError, Result};
11use crate::types::Kernel;
12
13#[derive(Clone)]
27pub struct LearnedMixtureKernel {
28 base_kernels: Vec<Arc<dyn Kernel>>,
29 logits: Vec<f64>,
30}
31
32impl LearnedMixtureKernel {
33 pub fn new(base_kernels: Vec<Arc<dyn Kernel>>, logits: Vec<f64>) -> Result<Self> {
38 if base_kernels.is_empty() {
39 return Err(KernelError::InvalidParameter {
40 parameter: "base_kernels".to_string(),
41 value: "[]".to_string(),
42 reason: "learned mixture requires at least one base kernel".to_string(),
43 });
44 }
45 if base_kernels.len() != logits.len() {
46 return Err(KernelError::DimensionMismatch {
47 expected: vec![base_kernels.len()],
48 got: vec![logits.len()],
49 context: "LearnedMixtureKernel logits length".to_string(),
50 });
51 }
52 for (i, &w) in logits.iter().enumerate() {
53 if !w.is_finite() {
54 return Err(KernelError::InvalidParameter {
55 parameter: format!("logits[{}]", i),
56 value: w.to_string(),
57 reason: "logits must be finite".to_string(),
58 });
59 }
60 }
61 Ok(Self {
62 base_kernels,
63 logits,
64 })
65 }
66
67 pub fn uniform(base_kernels: Vec<Arc<dyn Kernel>>) -> Result<Self> {
69 let n = base_kernels.len();
70 Self::new(base_kernels, vec![0.0; n])
71 }
72
73 pub fn num_kernels(&self) -> usize {
75 self.base_kernels.len()
76 }
77
78 pub fn logits(&self) -> &[f64] {
80 &self.logits
81 }
82
83 pub fn weights(&self) -> Vec<f64> {
86 softmax(&self.logits)
87 }
88
89 pub fn set_logits(&mut self, new_logits: Vec<f64>) -> Result<()> {
92 if new_logits.len() != self.base_kernels.len() {
93 return Err(KernelError::DimensionMismatch {
94 expected: vec![self.base_kernels.len()],
95 got: vec![new_logits.len()],
96 context: "LearnedMixtureKernel::set_logits".to_string(),
97 });
98 }
99 for (i, &w) in new_logits.iter().enumerate() {
100 if !w.is_finite() {
101 return Err(KernelError::InvalidParameter {
102 parameter: format!("logits[{}]", i),
103 value: w.to_string(),
104 reason: "logits must be finite".to_string(),
105 });
106 }
107 }
108 self.logits = new_logits;
109 Ok(())
110 }
111
112 pub fn apply_gradient_step(&mut self, gradient: &[f64], learning_rate: f64) -> Result<()> {
115 if gradient.len() != self.logits.len() {
116 return Err(KernelError::DimensionMismatch {
117 expected: vec![self.logits.len()],
118 got: vec![gradient.len()],
119 context: "LearnedMixtureKernel::apply_gradient_step".to_string(),
120 });
121 }
122 if !learning_rate.is_finite() {
123 return Err(KernelError::InvalidParameter {
124 parameter: "learning_rate".to_string(),
125 value: learning_rate.to_string(),
126 reason: "must be finite".to_string(),
127 });
128 }
129 for (w, &g) in self.logits.iter_mut().zip(gradient.iter()) {
130 if !g.is_finite() {
131 return Err(KernelError::InvalidParameter {
132 parameter: "gradient".to_string(),
133 value: g.to_string(),
134 reason: "gradient entries must be finite".to_string(),
135 });
136 }
137 *w -= learning_rate * g;
138 }
139 Ok(())
140 }
141
142 fn per_kernel_values(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
144 let mut values = Vec::with_capacity(self.base_kernels.len());
145 for kernel in &self.base_kernels {
146 values.push(kernel.compute(x, y)?);
147 }
148 Ok(values)
149 }
150
151 pub fn evaluate(&self, x: &[f64], y: &[f64]) -> Result<f64> {
153 let weights = self.weights();
154 let mut acc = 0.0;
155 for (kernel, &w) in self.base_kernels.iter().zip(weights.iter()) {
156 acc += w * kernel.compute(x, y)?;
157 }
158 Ok(acc)
159 }
160
161 pub fn gradient_wrt_logits(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
166 let weights = self.weights();
167 let k_vals = self.per_kernel_values(x, y)?;
168 let k_mix: f64 = weights.iter().zip(k_vals.iter()).map(|(p, k)| p * k).sum();
169 Ok(weights
170 .iter()
171 .zip(k_vals.iter())
172 .map(|(p, k)| p * (k - k_mix))
173 .collect())
174 }
175
176 pub fn evaluate_with_gradient(&self, x: &[f64], y: &[f64]) -> Result<(f64, Vec<f64>)> {
179 let weights = self.weights();
180 let k_vals = self.per_kernel_values(x, y)?;
181 let k_mix: f64 = weights.iter().zip(k_vals.iter()).map(|(p, k)| p * k).sum();
182 let grad: Vec<f64> = weights
183 .iter()
184 .zip(k_vals.iter())
185 .map(|(p, k)| p * (k - k_mix))
186 .collect();
187 Ok((k_mix, grad))
188 }
189
190 pub fn compute_gram(&self, xs: &[&[f64]], ys: &[&[f64]]) -> Result<Vec<Vec<f64>>> {
194 let mut matrix = vec![vec![0.0; ys.len()]; xs.len()];
195 for (i, &xi) in xs.iter().enumerate() {
196 for (j, &yj) in ys.iter().enumerate() {
197 matrix[i][j] = self.evaluate(xi, yj)?;
198 }
199 }
200 Ok(matrix)
201 }
202}
203
204impl fmt::Debug for LearnedMixtureKernel {
205 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206 let names: Vec<&str> = self.base_kernels.iter().map(|k| k.name()).collect();
207 f.debug_struct("LearnedMixtureKernel")
208 .field("base_kernels", &names)
209 .field("logits", &self.logits)
210 .finish()
211 }
212}
213
214impl Kernel for LearnedMixtureKernel {
215 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
216 self.evaluate(x, y)
217 }
218
219 fn name(&self) -> &str {
220 "LearnedMixture"
221 }
222
223 fn is_psd(&self) -> bool {
224 self.base_kernels.iter().all(|k| k.is_psd())
227 }
228}
229
/// Numerically stable softmax: `out[i] = exp(l_i − max) / Σ_j exp(l_j − max)`.
///
/// Returns an empty vector for empty input. If the normaliser is not a
/// positive finite number, it returns uniform weights `1/n` instead. That
/// happens only for non-finite logits: any NaN, any +inf, or all −inf.
pub(crate) fn softmax(logits: &[f64]) -> Vec<f64> {
    let count = logits.len();
    if count == 0 {
        return Vec::new();
    }
    // Shift by the largest logit so the biggest exponent is exp(0) = 1 and no
    // term can overflow.
    let peak = logits.iter().fold(f64::NEG_INFINITY, |acc, &v| acc.max(v));
    let exps = logits
        .iter()
        .map(|&v| (v - peak).exp())
        .collect::<Vec<f64>>();
    let total = exps.iter().sum::<f64>();
    if total > 0.0 && total.is_finite() {
        exps.into_iter().map(|e| e / total).collect()
    } else {
        // NaN from a non-finite logit poisoned the sum; fall back to uniform.
        vec![1.0 / count as f64; count]
    }
}