Skip to main content

tsetlin_rs/
error.rs

//! Error types for the Tsetlin Machine.
2
3use core::fmt;
4
/// # Overview
///
/// Errors that can occur when building or using a Tsetlin Machine.
///
/// Each variant's `Display` text names the parameter or condition that was
/// violated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The `n_clauses` setting was not provided.
    MissingClauses,
    /// The `n_features` setting was not provided.
    MissingFeatures,
    /// `n_clauses` was odd; it must be even.
    OddClauses,
    /// The specificity `s` was not strictly greater than `1.0`.
    InvalidSpecificity,
    /// The threshold was not strictly greater than `0`.
    InvalidThreshold,
    /// The supplied dataset was empty.
    EmptyDataset,
    /// A supplied dimension differed from the one that was expected.
    DimensionMismatch {
        /// The dimension that was expected.
        expected: usize,
        /// The dimension that was actually supplied.
        got:      usize
    }
}
18
19impl fmt::Display for Error {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        match self {
22            Self::MissingClauses => write!(f, "n_clauses is required"),
23            Self::MissingFeatures => write!(f, "n_features is required"),
24            Self::OddClauses => write!(f, "n_clauses must be even"),
25            Self::InvalidSpecificity => write!(f, "s must be > 1.0"),
26            Self::InvalidThreshold => write!(f, "threshold must be > 0"),
27            Self::EmptyDataset => write!(f, "dataset cannot be empty"),
28            Self::DimensionMismatch {
29                expected,
30                got
31            } => {
32                write!(f, "dimension mismatch: expected {expected}, got {got}")
33            }
34        }
35    }
36}
37
/// Lets [`Error`] be used wherever a standard-library error is expected
/// (for example boxed as `Box<dyn std::error::Error>`). Only compiled when
/// the `std` feature is enabled.
#[cfg(feature = "std")]
impl ::std::error::Error for Error {}
40
41/// # Overview
42///
43/// Result type for Tsetlin Machine operations.
44pub type Result<T> = core::result::Result<T, Error>;
45
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant renders the exact user-facing message.
    #[test]
    fn error_display_all_variants() {
        let cases = [
            (Error::MissingClauses, "n_clauses is required"),
            (Error::MissingFeatures, "n_features is required"),
            (Error::OddClauses, "n_clauses must be even"),
            (Error::InvalidSpecificity, "s must be > 1.0"),
            (Error::InvalidThreshold, "threshold must be > 0"),
            (Error::EmptyDataset, "dataset cannot be empty"),
            (
                Error::DimensionMismatch {
                    expected: 10,
                    got:      5
                },
                "dimension mismatch: expected 10, got 5"
            )
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected, "wrong Display for {err:?}");
        }
    }

    /// The derived `Debug` output names the variant and its field values.
    #[test]
    fn error_debug() {
        let err = Error::DimensionMismatch {
            expected: 100,
            got:      50
        };
        let debug_str = format!("{err:?}");
        assert!(debug_str.contains("DimensionMismatch"));
        assert!(debug_str.contains("100"));
        assert!(debug_str.contains("50"));
    }

    /// Equality distinguishes both variants and field values.
    #[test]
    fn error_eq() {
        assert_eq!(Error::MissingClauses, Error::MissingClauses);
        assert_ne!(Error::MissingClauses, Error::MissingFeatures);
        assert_eq!(
            Error::DimensionMismatch {
                expected: 5,
                got:      3
            },
            Error::DimensionMismatch {
                expected: 5,
                got:      3
            }
        );
        assert_ne!(
            Error::DimensionMismatch {
                expected: 5,
                got:      3
            },
            Error::DimensionMismatch {
                expected: 5,
                got:      4
            }
        );
    }

    /// A clone compares equal to its source.
    #[test]
    fn error_clone() {
        let err = Error::DimensionMismatch {
            expected: 10,
            got:      5
        };
        let cloned = err.clone();
        assert_eq!(err, cloned);
    }

    /// `Result<T>` uses this module's `Error` as its error type.
    #[test]
    fn result_alias_uses_error() {
        let res: Result<()> = Err(Error::EmptyDataset);
        assert_eq!(res, Err(Error::EmptyDataset));
    }
}