tensorlogic_sklears_kernels/
error.rs1use std::fmt;
4
/// Errors produced by kernel construction and evaluation in this crate.
///
/// Every field is owned (`String` / `Vec<usize>`), so values are cheap to
/// compare and can be cloned freely. `Eq` is derived alongside `PartialEq`
/// because all field types have total equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KernelError {
    /// Two shapes that were required to agree did not.
    DimensionMismatch {
        /// The dimensions that were required.
        expected: Vec<usize>,
        /// The dimensions that were actually supplied.
        got: Vec<usize>,
        /// Free-form description of where the mismatch was detected
        /// (rendered as "Dimension mismatch in <context>: ...").
        context: String,
    },
    /// A parameter was given a value it does not accept.
    InvalidParameter {
        /// Name of the offending parameter.
        parameter: String,
        /// The rejected value, already rendered as text.
        value: String,
        /// Why the value was rejected.
        reason: String,
    },
    /// A kernel computation failed; the payload is a description of the failure.
    ComputationError(String),
    /// An expression could not be used to build a kernel.
    ///
    /// Conversion from `tensorlogic_ir::IrError` also produces this variant.
    InvalidExpression(String),
    /// Two kernels cannot be combined.
    IncompatibleKernels {
        /// Name/description of the first kernel.
        kernel_a: String,
        /// Name/description of the second kernel.
        kernel_b: String,
        /// Why the two kernels cannot be combined.
        reason: String,
    },
}
31
32impl fmt::Display for KernelError {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 match self {
35 Self::DimensionMismatch {
36 expected,
37 got,
38 context,
39 } => write!(
40 f,
41 "Dimension mismatch in {}: expected {:?}, got {:?}",
42 context, expected, got
43 ),
44 Self::InvalidParameter {
45 parameter,
46 value,
47 reason,
48 } => write!(
49 f,
50 "Invalid parameter '{}' = '{}': {}",
51 parameter, value, reason
52 ),
53 Self::ComputationError(msg) => write!(f, "Kernel computation error: {}", msg),
54 Self::InvalidExpression(msg) => write!(f, "Invalid expression for kernel: {}", msg),
55 Self::IncompatibleKernels {
56 kernel_a,
57 kernel_b,
58 reason,
59 } => write!(
60 f,
61 "Incompatible kernels '{}' and '{}': {}",
62 kernel_a, kernel_b, reason
63 ),
64 }
65 }
66}
67
68impl std::error::Error for KernelError {}
69
70impl From<tensorlogic_ir::IrError> for KernelError {
72 fn from(err: tensorlogic_ir::IrError) -> Self {
73 KernelError::InvalidExpression(err.to_string())
74 }
75}
76
77pub type Result<T> = std::result::Result<T, KernelError>;
79
#[cfg(test)]
mod tests {
    use super::*;

    /// The message must include the context and both dimension lists.
    #[test]
    fn test_dimension_mismatch_display() {
        let err = KernelError::DimensionMismatch {
            expected: vec![10, 20],
            got: vec![10, 30],
            context: "kernel matrix".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.contains("10, 20"));
        assert!(msg.contains("10, 30"));
        assert!(msg.contains("kernel matrix"));
        assert_eq!(
            msg,
            "Dimension mismatch in kernel matrix: expected [10, 20], got [10, 30]"
        );
    }

    #[test]
    fn test_invalid_parameter_display() {
        let err = KernelError::InvalidParameter {
            parameter: "gamma".to_string(),
            value: "-1.0".to_string(),
            reason: "must be positive".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.contains("gamma"));
        assert!(msg.contains("-1.0"));
        assert!(msg.contains("positive"));
    }

    #[test]
    fn test_computation_error_display() {
        let err = KernelError::ComputationError("matrix is not PSD".to_string());
        assert_eq!(err.to_string(), "Kernel computation error: matrix is not PSD");
    }

    #[test]
    fn test_invalid_expression_display() {
        let err = KernelError::InvalidExpression("unknown predicate".to_string());
        assert_eq!(
            err.to_string(),
            "Invalid expression for kernel: unknown predicate"
        );
    }

    #[test]
    fn test_incompatible_kernels_display() {
        let err = KernelError::IncompatibleKernels {
            kernel_a: "rbf".to_string(),
            kernel_b: "linear".to_string(),
            reason: "feature spaces differ".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Incompatible kernels 'rbf' and 'linear': feature spaces differ"
        );
    }

    /// `KernelError` works as a boxed `std::error::Error`, and no variant
    /// reports an underlying source error.
    #[test]
    fn test_is_std_error_without_source() {
        use std::error::Error;
        let boxed: Box<dyn Error> =
            Box::new(KernelError::ComputationError("boom".to_string()));
        assert!(boxed.source().is_none());
        assert_eq!(boxed.to_string(), "Kernel computation error: boom");
    }

    /// The crate `Result` alias carries `KernelError` as its error type.
    #[test]
    fn test_result_alias_carries_kernel_error() {
        fn fails() -> Result<()> {
            Err(KernelError::ComputationError("boom".to_string()))
        }
        assert_eq!(
            fails(),
            Err(KernelError::ComputationError("boom".to_string()))
        );
    }
}