1use std::io;
4use thiserror::Error;
5
6#[derive(Error, Debug)]
8pub enum DatasetsError {
9 #[error("Invalid format: {0}")]
11 InvalidFormat(String),
12
13 #[error("Loading error: {0}")]
15 LoadingError(String),
16
17 #[error("Format error: {0}")]
19 FormatError(String),
20
21 #[error("Not found: {0}")]
23 NotFound(String),
24
25 #[error("Authentication error: {0}")]
27 AuthenticationError(String),
28
29 #[error("Download error: {0}")]
31 DownloadError(String),
32
33 #[error("Cache error: {0}")]
35 CacheError(String),
36
37 #[error("IO error: {0}")]
39 IoError(#[from] io::Error),
40
41 #[error("Serialization error: {0}")]
43 SerdeError(String),
44
45 #[error("GPU error: {0}")]
47 GpuError(String),
48
49 #[error("Computation error: {0}")]
51 ComputationError(String),
52
53 #[error("Validation error: {0}")]
55 ValidationError(String),
56
57 #[error("Error: {0}")]
59 Other(String),
60}
61
62impl PartialEq for DatasetsError {
63 fn eq(&self, other: &Self) -> bool {
64 match (self, other) {
65 (DatasetsError::InvalidFormat(a), DatasetsError::InvalidFormat(b)) => a == b,
66 (DatasetsError::LoadingError(a), DatasetsError::LoadingError(b)) => a == b,
67 (DatasetsError::FormatError(a), DatasetsError::FormatError(b)) => a == b,
68 (DatasetsError::NotFound(a), DatasetsError::NotFound(b)) => a == b,
69 (DatasetsError::AuthenticationError(a), DatasetsError::AuthenticationError(b)) => {
70 a == b
71 }
72 (DatasetsError::DownloadError(a), DatasetsError::DownloadError(b)) => a == b,
73 (DatasetsError::CacheError(a), DatasetsError::CacheError(b)) => a == b,
74 (DatasetsError::IoError(a), DatasetsError::IoError(b)) => {
75 a.kind() == b.kind() && a.to_string() == b.to_string()
77 }
78 (DatasetsError::SerdeError(a), DatasetsError::SerdeError(b)) => a == b,
79 (DatasetsError::GpuError(a), DatasetsError::GpuError(b)) => a == b,
80 (DatasetsError::ComputationError(a), DatasetsError::ComputationError(b)) => a == b,
81 (DatasetsError::ValidationError(a), DatasetsError::ValidationError(b)) => a == b,
82 (DatasetsError::Other(a), DatasetsError::Other(b)) => a == b,
83 _ => false,
84 }
85 }
86}
87
88pub type Result<T> = std::result::Result<T, DatasetsError>;
90
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    /// Builds an `IoError` variant from a kind and message.
    fn io_error(kind: io::ErrorKind, message: &str) -> DatasetsError {
        DatasetsError::IoError(io::Error::new(kind, message))
    }

    #[test]
    fn test_invalid_format_error() {
        let error = DatasetsError::InvalidFormat("test format".to_string());
        assert_eq!(error.to_string(), "Invalid format: test format");
    }

    #[test]
    fn test_loading_error() {
        let error = DatasetsError::LoadingError("test loading".to_string());
        assert_eq!(error.to_string(), "Loading error: test loading");
    }

    #[test]
    fn test_download_error() {
        let error = DatasetsError::DownloadError("test download".to_string());
        assert_eq!(error.to_string(), "Download error: test download");
    }

    #[test]
    fn test_cache_error() {
        let error = DatasetsError::CacheError("test cache".to_string());
        assert_eq!(error.to_string(), "Cache error: test cache");
    }

    #[test]
    fn test_io_error_conversion() {
        let io_error = io::Error::new(io::ErrorKind::NotFound, "file not found");
        let datasets_error: DatasetsError = io_error.into();

        match datasets_error {
            DatasetsError::IoError(_) => {
                assert!(datasets_error.to_string().contains("file not found"));
            }
            _ => panic!("Expected IoError variant"),
        }
    }

    #[test]
    fn test_serde_error() {
        let error = DatasetsError::SerdeError("serialization failed".to_string());
        assert_eq!(
            error.to_string(),
            "Serialization error: serialization failed"
        );
    }

    #[test]
    fn test_gpu_error() {
        let error = DatasetsError::GpuError("CUDA initialization failed".to_string());
        assert_eq!(error.to_string(), "GPU error: CUDA initialization failed");
    }

    #[test]
    fn test_other_error() {
        let error = DatasetsError::Other("generic error".to_string());
        assert_eq!(error.to_string(), "Error: generic error");
    }

    /// Display text for the variants not covered by the individual tests above.
    #[test]
    fn test_remaining_variant_messages() {
        let cases = [
            (
                DatasetsError::FormatError("bad header".to_string()),
                "Format error: bad header",
            ),
            (
                DatasetsError::NotFound("iris".to_string()),
                "Not found: iris",
            ),
            (
                DatasetsError::AuthenticationError("token expired".to_string()),
                "Authentication error: token expired",
            ),
            (
                DatasetsError::ComputationError("overflow".to_string()),
                "Computation error: overflow",
            ),
            (
                DatasetsError::ValidationError("empty dataset".to_string()),
                "Validation error: empty dataset",
            ),
        ];
        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }

    #[test]
    fn test_error_debug_format() {
        let error = DatasetsError::InvalidFormat("debug test".to_string());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("InvalidFormat"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_result_type() {
        // `Result<T>` is shorthand for `std::result::Result<T, DatasetsError>`.
        let ok_result: Result<i32> = Ok(42);
        assert_eq!(ok_result, Ok(42));

        let err_result: Result<i32> = Err(DatasetsError::Other("test".to_string()));
        assert!(err_result.is_err());
    }

    #[test]
    fn test_error_from_io_error() {
        let io_err = io::Error::new(io::ErrorKind::PermissionDenied, "access denied");
        let datasets_err = DatasetsError::from(io_err);

        if let DatasetsError::IoError(ref inner) = datasets_err {
            assert_eq!(inner.kind(), io::ErrorKind::PermissionDenied);
        } else {
            panic!("Expected IoError variant");
        }
    }

    #[test]
    fn test_error_chain() {
        // An error propagated through `Result` keeps its display text.
        let error = DatasetsError::LoadingError("failed to parse CSV".to_string());
        let result: Result<()> = Err(error);

        match result {
            Ok(_) => panic!("Expected error"),
            Err(e) => {
                assert_eq!(e.to_string(), "Loading error: failed to parse CSV");
            }
        }
    }

    #[test]
    fn test_equality_same_variant() {
        let error = DatasetsError::NotFound("iris".to_string());
        assert_eq!(error, DatasetsError::NotFound("iris".to_string()));
        assert_ne!(error, DatasetsError::NotFound("wine".to_string()));
    }

    #[test]
    fn test_equality_different_variants() {
        // Identical payloads never make different variants equal.
        assert_ne!(
            DatasetsError::GpuError("failed".to_string()),
            DatasetsError::ComputationError("failed".to_string())
        );
        assert_ne!(
            DatasetsError::Other("x".to_string()),
            io_error(io::ErrorKind::Other, "x")
        );
    }

    #[test]
    fn test_io_error_equality() {
        // io::Error has no PartialEq, so IoError values compare by kind and message.
        assert_eq!(
            io_error(io::ErrorKind::NotFound, "gone"),
            io_error(io::ErrorKind::NotFound, "gone")
        );
        assert_ne!(
            io_error(io::ErrorKind::NotFound, "gone"),
            io_error(io::ErrorKind::PermissionDenied, "gone")
        );
        assert_ne!(
            io_error(io::ErrorKind::NotFound, "gone"),
            io_error(io::ErrorKind::NotFound, "missing")
        );
    }

    #[test]
    fn test_io_error_is_exposed_as_source() {
        use std::error::Error as _;

        let error = io_error(io::ErrorKind::NotFound, "missing");
        let source = error
            .source()
            .expect("IoError should expose the io::Error as its source");
        let inner = source
            .downcast_ref::<io::Error>()
            .expect("source should be an io::Error");
        assert_eq!(inner.kind(), io::ErrorKind::NotFound);

        // Message-only variants have no underlying cause.
        assert!(DatasetsError::Other("no cause".to_string())
            .source()
            .is_none());
    }

    #[test]
    fn test_question_mark_converts_io_error() {
        fn fails() -> Result<()> {
            let attempt: io::Result<()> = Err(io::Error::new(io::ErrorKind::TimedOut, "slow"));
            attempt?;
            Ok(())
        }

        match fails() {
            Err(DatasetsError::IoError(inner)) => {
                assert_eq!(inner.kind(), io::ErrorKind::TimedOut);
            }
            other => panic!("Expected IoError variant, got {other:?}"),
        }
    }
}