1#![allow(clippy::needless_range_loop)]
2use crate::error::{KernelError, Result};
41use crate::tensor_kernels::{
42 LaplacianKernel, MaternKernel, PolynomialKernel, RationalQuadraticKernel, RbfKernel,
43};
44
/// A kernel (Gram) matrix bundled with its gradient matrices, one per
/// named hyperparameter. Produced by [`compute_generic_gradient_matrix`].
#[derive(Debug, Clone)]
pub struct KernelGradientMatrix {
    /// The n×n kernel matrix K where K[i][j] = k(x_i, x_j).
    pub kernel_matrix: Vec<Vec<f64>>,
    /// One gradient matrix dK/dθ per hyperparameter θ, in construction order.
    pub gradients: Vec<GradientComponent>,
}
53
/// A single hyperparameter gradient: the parameter's name together with the
/// matrix of partial derivatives dK[i][j]/dθ.
#[derive(Debug, Clone)]
pub struct GradientComponent {
    /// Hyperparameter name (e.g. "gamma", "length_scale").
    pub name: String,
    /// The n×n gradient matrix for this parameter.
    pub matrix: Vec<Vec<f64>>,
}
62
63impl KernelGradientMatrix {
64 pub fn new(kernel_matrix: Vec<Vec<f64>>, gradients: Vec<GradientComponent>) -> Self {
66 Self {
67 kernel_matrix,
68 gradients,
69 }
70 }
71
72 pub fn n_samples(&self) -> usize {
74 self.kernel_matrix.len()
75 }
76
77 pub fn get_gradient(&self, name: &str) -> Option<&Vec<Vec<f64>>> {
79 self.gradients
80 .iter()
81 .find(|g| g.name == name)
82 .map(|g| &g.matrix)
83 }
84
85 pub fn gradient_names(&self) -> Vec<&str> {
87 self.gradients.iter().map(|g| g.name.as_str()).collect()
88 }
89}
90
/// Result of [`compute_rbf_gradient_matrix`]: the RBF kernel matrix plus its
/// gradients with respect to gamma and the length scale.
#[derive(Debug, Clone)]
pub struct RbfGradientResult {
    /// Symmetric n×n RBF kernel matrix.
    pub kernel_matrix: Vec<Vec<f64>>,
    /// dK/dgamma, symmetric n×n.
    pub gradient_gamma: Vec<Vec<f64>>,
    /// dK/d(length_scale), symmetric n×n.
    pub gradient_length_scale: Vec<Vec<f64>>,
}
101
102pub fn compute_rbf_gradient_matrix(
120 kernel: &RbfKernel,
121 data: &[Vec<f64>],
122) -> Result<RbfGradientResult> {
123 let n = data.len();
124 if n == 0 {
125 return Err(KernelError::ComputationError("Empty data".to_string()));
126 }
127
128 let mut kernel_matrix = vec![vec![0.0; n]; n];
129 let mut gradient_gamma = vec![vec![0.0; n]; n];
130 let mut gradient_length_scale = vec![vec![0.0; n]; n];
131
132 for i in 0..n {
133 for j in i..n {
134 let (k, grad_gamma) = kernel.compute_with_gradient(&data[i], &data[j])?;
135 let (_, grad_ls) = kernel.compute_with_length_scale_gradient(&data[i], &data[j])?;
136
137 kernel_matrix[i][j] = k;
138 kernel_matrix[j][i] = k;
139
140 gradient_gamma[i][j] = grad_gamma;
141 gradient_gamma[j][i] = grad_gamma;
142
143 gradient_length_scale[i][j] = grad_ls;
144 gradient_length_scale[j][i] = grad_ls;
145 }
146 }
147
148 Ok(RbfGradientResult {
149 kernel_matrix,
150 gradient_gamma,
151 gradient_length_scale,
152 })
153}
154
/// Result of [`compute_polynomial_gradient_matrix`]: the polynomial kernel
/// matrix plus its gradients with respect to the constant term and degree.
#[derive(Debug, Clone)]
pub struct PolynomialGradientResult {
    /// Symmetric n×n polynomial kernel matrix.
    pub kernel_matrix: Vec<Vec<f64>>,
    /// dK/dc for the additive constant, symmetric n×n.
    pub gradient_constant: Vec<Vec<f64>>,
    /// dK/d(degree), symmetric n×n; NaN entries are replaced by 0.0.
    pub gradient_degree: Vec<Vec<f64>>,
}
165
166pub fn compute_polynomial_gradient_matrix(
168 kernel: &PolynomialKernel,
169 data: &[Vec<f64>],
170) -> Result<PolynomialGradientResult> {
171 let n = data.len();
172 if n == 0 {
173 return Err(KernelError::ComputationError("Empty data".to_string()));
174 }
175
176 let mut kernel_matrix = vec![vec![0.0; n]; n];
177 let mut gradient_constant = vec![vec![0.0; n]; n];
178 let mut gradient_degree = vec![vec![0.0; n]; n];
179
180 for i in 0..n {
181 for j in i..n {
182 let (k, grad_c, grad_d) = kernel.compute_with_all_gradients(&data[i], &data[j])?;
183
184 kernel_matrix[i][j] = k;
185 kernel_matrix[j][i] = k;
186
187 gradient_constant[i][j] = grad_c;
188 gradient_constant[j][i] = grad_c;
189
190 let grad_d_safe = if grad_d.is_nan() { 0.0 } else { grad_d };
192 gradient_degree[i][j] = grad_d_safe;
193 gradient_degree[j][i] = grad_d_safe;
194 }
195 }
196
197 Ok(PolynomialGradientResult {
198 kernel_matrix,
199 gradient_constant,
200 gradient_degree,
201 })
202}
203
/// Result of [`compute_matern_gradient_matrix`]: the Matérn kernel matrix
/// plus its gradient with respect to the length scale.
#[derive(Debug, Clone)]
pub struct MaternGradientResult {
    /// Symmetric n×n Matérn kernel matrix.
    pub kernel_matrix: Vec<Vec<f64>>,
    /// dK/d(length_scale), symmetric n×n.
    pub gradient_length_scale: Vec<Vec<f64>>,
}
212
213pub fn compute_matern_gradient_matrix(
215 kernel: &MaternKernel,
216 data: &[Vec<f64>],
217) -> Result<MaternGradientResult> {
218 let n = data.len();
219 if n == 0 {
220 return Err(KernelError::ComputationError("Empty data".to_string()));
221 }
222
223 let mut kernel_matrix = vec![vec![0.0; n]; n];
224 let mut gradient_length_scale = vec![vec![0.0; n]; n];
225
226 for i in 0..n {
227 for j in i..n {
228 let (k, grad_l) = kernel.compute_with_length_scale_gradient(&data[i], &data[j])?;
229
230 kernel_matrix[i][j] = k;
231 kernel_matrix[j][i] = k;
232
233 gradient_length_scale[i][j] = grad_l;
234 gradient_length_scale[j][i] = grad_l;
235 }
236 }
237
238 Ok(MaternGradientResult {
239 kernel_matrix,
240 gradient_length_scale,
241 })
242}
243
/// Result of [`compute_laplacian_gradient_matrix`]: the Laplacian kernel
/// matrix plus its gradients with respect to gamma and sigma.
#[derive(Debug, Clone)]
pub struct LaplacianGradientResult {
    /// Symmetric n×n Laplacian kernel matrix.
    pub kernel_matrix: Vec<Vec<f64>>,
    /// dK/dgamma, symmetric n×n.
    pub gradient_gamma: Vec<Vec<f64>>,
    /// dK/dsigma, symmetric n×n.
    pub gradient_sigma: Vec<Vec<f64>>,
}
254
255pub fn compute_laplacian_gradient_matrix(
257 kernel: &LaplacianKernel,
258 data: &[Vec<f64>],
259) -> Result<LaplacianGradientResult> {
260 let n = data.len();
261 if n == 0 {
262 return Err(KernelError::ComputationError("Empty data".to_string()));
263 }
264
265 let mut kernel_matrix = vec![vec![0.0; n]; n];
266 let mut gradient_gamma = vec![vec![0.0; n]; n];
267 let mut gradient_sigma = vec![vec![0.0; n]; n];
268
269 for i in 0..n {
270 for j in i..n {
271 let (k, grad_g) = kernel.compute_with_gradient(&data[i], &data[j])?;
272 let (_, grad_s) = kernel.compute_with_sigma_gradient(&data[i], &data[j])?;
273
274 kernel_matrix[i][j] = k;
275 kernel_matrix[j][i] = k;
276
277 gradient_gamma[i][j] = grad_g;
278 gradient_gamma[j][i] = grad_g;
279
280 gradient_sigma[i][j] = grad_s;
281 gradient_sigma[j][i] = grad_s;
282 }
283 }
284
285 Ok(LaplacianGradientResult {
286 kernel_matrix,
287 gradient_gamma,
288 gradient_sigma,
289 })
290}
291
/// Result of [`compute_rational_quadratic_gradient_matrix`]: the rational
/// quadratic kernel matrix plus its gradients with respect to the length
/// scale and alpha.
#[derive(Debug, Clone)]
pub struct RationalQuadraticGradientResult {
    /// Symmetric n×n rational quadratic kernel matrix.
    pub kernel_matrix: Vec<Vec<f64>>,
    /// dK/d(length_scale), symmetric n×n.
    pub gradient_length_scale: Vec<Vec<f64>>,
    /// dK/dalpha, symmetric n×n.
    pub gradient_alpha: Vec<Vec<f64>>,
}
302
303pub fn compute_rational_quadratic_gradient_matrix(
305 kernel: &RationalQuadraticKernel,
306 data: &[Vec<f64>],
307) -> Result<RationalQuadraticGradientResult> {
308 let n = data.len();
309 if n == 0 {
310 return Err(KernelError::ComputationError("Empty data".to_string()));
311 }
312
313 let mut kernel_matrix = vec![vec![0.0; n]; n];
314 let mut gradient_length_scale = vec![vec![0.0; n]; n];
315 let mut gradient_alpha = vec![vec![0.0; n]; n];
316
317 for i in 0..n {
318 for j in i..n {
319 let (k, grad_l, grad_a) = kernel.compute_with_all_gradients(&data[i], &data[j])?;
320
321 kernel_matrix[i][j] = k;
322 kernel_matrix[j][i] = k;
323
324 gradient_length_scale[i][j] = grad_l;
325 gradient_length_scale[j][i] = grad_l;
326
327 gradient_alpha[i][j] = grad_a;
328 gradient_alpha[j][i] = grad_a;
329 }
330 }
331
332 Ok(RationalQuadraticGradientResult {
333 kernel_matrix,
334 gradient_length_scale,
335 gradient_alpha,
336 })
337}
338
339pub fn compute_generic_gradient_matrix<F>(
343 data: &[Vec<f64>],
344 kernel_fn: F,
345 gradient_names: Vec<String>,
346) -> Result<KernelGradientMatrix>
347where
348 F: Fn(&[f64], &[f64]) -> Result<(f64, Vec<f64>)>,
349{
350 let n = data.len();
351 if n == 0 {
352 return Err(KernelError::ComputationError("Empty data".to_string()));
353 }
354
355 let n_params = gradient_names.len();
356 let mut kernel_matrix = vec![vec![0.0; n]; n];
357 let mut gradient_matrices: Vec<Vec<Vec<f64>>> =
358 (0..n_params).map(|_| vec![vec![0.0; n]; n]).collect();
359
360 for i in 0..n {
361 for j in i..n {
362 let (k, grads) = kernel_fn(&data[i], &data[j])?;
363
364 if grads.len() != n_params {
365 return Err(KernelError::ComputationError(format!(
366 "Expected {} gradients, got {}",
367 n_params,
368 grads.len()
369 )));
370 }
371
372 kernel_matrix[i][j] = k;
373 kernel_matrix[j][i] = k;
374
375 for (p, grad) in grads.iter().enumerate() {
376 gradient_matrices[p][i][j] = *grad;
377 gradient_matrices[p][j][i] = *grad;
378 }
379 }
380 }
381
382 let gradients = gradient_names
383 .into_iter()
384 .zip(gradient_matrices)
385 .map(|(name, matrix)| GradientComponent { name, matrix })
386 .collect();
387
388 Ok(KernelGradientMatrix::new(kernel_matrix, gradients))
389}
390
391pub fn trace_product(a: &[Vec<f64>], b: &[Vec<f64>]) -> Result<f64> {
403 let n = a.len();
404 if n == 0 {
405 return Ok(0.0);
406 }
407 if b.len() != n {
408 return Err(KernelError::DimensionMismatch {
409 expected: vec![n],
410 got: vec![b.len()],
411 context: "trace_product matrix dimensions".to_string(),
412 });
413 }
414
415 let mut trace = 0.0;
417 for i in 0..n {
418 if a[i].len() != n || b[i].len() != n {
419 return Err(KernelError::DimensionMismatch {
420 expected: vec![n],
421 got: vec![a[i].len()],
422 context: "trace_product row dimension".to_string(),
423 });
424 }
425 for j in 0..n {
426 trace += a[i][j] * b[j][i];
427 }
428 }
429
430 Ok(trace)
431}
432
/// Returns the Frobenius norm of `matrix`: the square root of the sum of the
/// squares of all entries. An empty matrix (or all-empty rows) yields 0.0.
pub fn frobenius_norm(matrix: &[Vec<f64>]) -> f64 {
    let mut sum_of_squares = 0.0;
    for row in matrix {
        for &entry in row {
            sum_of_squares += entry * entry;
        }
    }
    sum_of_squares.sqrt()
}
444
445pub fn is_symmetric(matrix: &[Vec<f64>], tolerance: f64) -> bool {
447 let n = matrix.len();
448 for i in 0..n {
449 for j in i + 1..n {
450 if (matrix[i][j] - matrix[j][i]).abs() > tolerance {
451 return false;
452 }
453 }
454 }
455 true
456}
457
#[cfg(test)]
mod tests {
    // Unit tests: shape, symmetry, and diagonal properties of every
    // gradient-matrix builder, plus the small matrix utilities.
    use super::*;
    use crate::types::RbfKernelConfig;

    #[test]
    fn test_rbf_gradient_matrix() {
        let kernel = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        let result = compute_rbf_gradient_matrix(&kernel, &data).unwrap();

        assert_eq!(result.kernel_matrix.len(), 3);
        assert_eq!(result.gradient_gamma.len(), 3);
        assert_eq!(result.gradient_length_scale.len(), 3);

        assert!(is_symmetric(&result.kernel_matrix, 1e-10));
        assert!(is_symmetric(&result.gradient_gamma, 1e-10));
        assert!(is_symmetric(&result.gradient_length_scale, 1e-10));

        // RBF kernel of any point with itself is exactly 1.
        for i in 0..3 {
            assert!((result.kernel_matrix[i][i] - 1.0).abs() < 1e-10);
        }

        // The gamma gradient scales with the squared distance, which is 0
        // on the diagonal.
        for i in 0..3 {
            assert!(result.gradient_gamma[i][i].abs() < 1e-10);
        }
    }

    #[test]
    fn test_polynomial_gradient_matrix() {
        let kernel = PolynomialKernel::new(2, 1.0).unwrap();
        let data = vec![vec![1.0, 2.0], vec![2.0, 3.0]];

        let result = compute_polynomial_gradient_matrix(&kernel, &data).unwrap();

        assert_eq!(result.kernel_matrix.len(), 2);
        assert!(is_symmetric(&result.kernel_matrix, 1e-10));
        assert!(is_symmetric(&result.gradient_constant, 1e-10));
    }

    #[test]
    fn test_matern_gradient_matrix() {
        let kernel = MaternKernel::nu_3_2(1.0).unwrap();
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let result = compute_matern_gradient_matrix(&kernel, &data).unwrap();

        assert_eq!(result.kernel_matrix.len(), 2);
        assert!(is_symmetric(&result.kernel_matrix, 1e-10));
        assert!(is_symmetric(&result.gradient_length_scale, 1e-10));
    }

    #[test]
    fn test_laplacian_gradient_matrix() {
        let kernel = LaplacianKernel::new(0.5).unwrap();
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let result = compute_laplacian_gradient_matrix(&kernel, &data).unwrap();

        assert_eq!(result.kernel_matrix.len(), 2);
        assert!(is_symmetric(&result.kernel_matrix, 1e-10));
        assert!(is_symmetric(&result.gradient_gamma, 1e-10));
        assert!(is_symmetric(&result.gradient_sigma, 1e-10));
    }

    #[test]
    fn test_rational_quadratic_gradient_matrix() {
        let kernel = RationalQuadraticKernel::new(1.0, 2.0).unwrap();
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let result = compute_rational_quadratic_gradient_matrix(&kernel, &data).unwrap();

        assert_eq!(result.kernel_matrix.len(), 2);
        assert!(is_symmetric(&result.kernel_matrix, 1e-10));
        assert!(is_symmetric(&result.gradient_length_scale, 1e-10));
        assert!(is_symmetric(&result.gradient_alpha, 1e-10));
    }

    #[test]
    fn test_trace_product() {
        let a = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let b = vec![vec![5.0, 6.0], vec![7.0, 8.0]];

        let trace = trace_product(&a, &b).unwrap();

        // tr(A·B) = 1·5 + 2·7 + 3·6 + 4·8 = 69.
        assert!((trace - 69.0).abs() < 1e-10);
    }

    #[test]
    fn test_frobenius_norm() {
        let matrix = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let norm = frobenius_norm(&matrix);

        // sqrt(1 + 4 + 9 + 16) = sqrt(30).
        let expected = 30.0_f64.sqrt();
        assert!((norm - expected).abs() < 1e-10);
    }

    #[test]
    fn test_is_symmetric() {
        let symmetric = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        let asymmetric = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        assert!(is_symmetric(&symmetric, 1e-10));
        assert!(!is_symmetric(&asymmetric, 1e-10));
    }

    #[test]
    fn test_kernel_gradient_matrix_accessors() {
        let kernel = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let rbf_result = compute_rbf_gradient_matrix(&kernel, &data).unwrap();

        // Repackage the RBF result into the generic container to exercise
        // its name-based accessors.
        let gradients = vec![
            GradientComponent {
                name: "gamma".to_string(),
                matrix: rbf_result.gradient_gamma,
            },
            GradientComponent {
                name: "length_scale".to_string(),
                matrix: rbf_result.gradient_length_scale,
            },
        ];
        let result = KernelGradientMatrix::new(rbf_result.kernel_matrix, gradients);

        assert_eq!(result.n_samples(), 2);
        assert_eq!(result.gradient_names(), vec!["gamma", "length_scale"]);
        assert!(result.get_gradient("gamma").is_some());
        assert!(result.get_gradient("length_scale").is_some());
        assert!(result.get_gradient("nonexistent").is_none());
    }

    #[test]
    fn test_empty_data() {
        let kernel = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
        let data: Vec<Vec<f64>> = vec![];

        let result = compute_rbf_gradient_matrix(&kernel, &data);
        assert!(result.is_err());
    }

    #[test]
    fn test_single_point() {
        let kernel = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
        let data = vec![vec![1.0, 2.0]];

        let result = compute_rbf_gradient_matrix(&kernel, &data).unwrap();

        assert_eq!(result.kernel_matrix.len(), 1);
        assert!((result.kernel_matrix[0][0] - 1.0).abs() < 1e-10);
        assert!(result.gradient_gamma[0][0].abs() < 1e-10);
    }

    #[test]
    fn test_gradient_consistency_with_element_wise() {
        // The batched matrix must agree entry-by-entry with direct
        // per-pair kernel evaluations.
        let kernel = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let result = compute_rbf_gradient_matrix(&kernel, &data).unwrap();

        for i in 0..2 {
            for j in 0..2 {
                let (k, grad_g) = kernel.compute_with_gradient(&data[i], &data[j]).unwrap();
                assert!(
                    (result.kernel_matrix[i][j] - k).abs() < 1e-10,
                    "K[{},{}] mismatch",
                    i,
                    j
                );
                assert!(
                    (result.gradient_gamma[i][j] - grad_g).abs() < 1e-10,
                    "dK/dγ[{},{}] mismatch",
                    i,
                    j
                );
            }
        }
    }
}
649}