Skip to main content

sh_layer1/
error_handler.rs

1//! 错误处理模块
2//!
3//! 统一的错误类型和结果类型定义。
4
5use thiserror::Error;
6
7/// Continuum 统一错误类型
8#[derive(Debug, Error)]
9pub enum ShError {
10    #[error("Layer 0 error: {0}")]
11    Layer0(String),
12
13    #[error("Configuration error: {0}")]
14    Config(String),
15
16    #[error("IO error: {0}")]
17    Io(#[from] std::io::Error),
18
19    #[error("Serialization error: {0}")]
20    Serde(#[from] serde_json::Error),
21
22    #[error("LLM API error: {0}")]
23    LlmApi(String),
24
25    #[error("Session error: {0}")]
26    Session(String),
27
28    #[error("Timeout error after {seconds} seconds")]
29    Timeout { seconds: u64 },
30
31    #[error("Not found: {resource}")]
32    NotFound { resource: String },
33
34    #[error("Rate limited")]
35    RateLimited,
36
37    #[error("Internal error: {0}")]
38    Internal(String),
39}
40
41/// 从 anyhow::Error 转换
42impl From<anyhow::Error> for ShError {
43    fn from(e: anyhow::Error) -> Self {
44        ShError::Internal(e.to_string())
45    }
46}
47
48/// Continuum 统一结果类型
49pub type ShResult<T> = std::result::Result<T, ShError>;
50
/// Central error handler: optionally logs errors and turns them into
/// messages suitable for display.
///
/// Derives `Debug` (public types should be debuggable) and `Clone` (the
/// handler is a single flag, so copies are cheap and independent).
#[derive(Debug, Clone)]
pub struct ErrorHandler {
    /// Whether [`ErrorHandler::handle`] emits a `tracing::error!` event.
    log_errors: bool,
}
56
57impl ErrorHandler {
58    pub fn new() -> Self {
59        Self { log_errors: true }
60    }
61
62    /// 处理错误,返回用户友好的消息
63    pub fn handle(&self, error: &ShError) -> String {
64        if self.log_errors {
65            tracing::error!("Error: {:?}", error);
66        }
67        error.to_string()
68    }
69
70    /// 将错误转换为用户消息
71    pub fn to_user_message(&self, error: &ShError) -> String {
72        match error {
73            ShError::Timeout { seconds } => format!("操作超时,请重试(已等待 {} 秒)", seconds),
74            ShError::RateLimited => "请求过于频繁,请稍后再试".to_string(),
75            ShError::NotFound { resource } => format!("找不到资源: {}", resource),
76            _ => format!("发生错误: {}", error),
77        }
78    }
79}
80
81impl Default for ErrorHandler {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        let err = ShError::Config("invalid config".to_string());
        assert!(err.to_string().contains("Configuration error"));
        assert!(err.to_string().contains("invalid config"));
    }

    #[test]
    fn test_layer0_error() {
        let err = ShError::Layer0("security violation".to_string());
        assert!(err.to_string().contains("Layer 0 error"));
        assert!(err.to_string().contains("security violation"));
    }

    #[test]
    fn test_timeout_error() {
        let err = ShError::Timeout { seconds: 30 };
        assert!(err.to_string().contains("30"));
        assert!(err.to_string().contains("Timeout"));
    }

    #[test]
    fn test_not_found_error() {
        let err = ShError::NotFound {
            resource: "session".to_string(),
        };
        assert!(err.to_string().contains("Not found"));
        assert!(err.to_string().contains("session"));
    }

    #[test]
    fn test_rate_limited_error() {
        let err = ShError::RateLimited;
        assert!(err.to_string().contains("Rate limited"));
    }

    #[test]
    fn test_internal_error() {
        let err = ShError::Internal("unexpected error".to_string());
        assert!(err.to_string().contains("Internal error"));
        assert!(err.to_string().contains("unexpected error"));
    }

    #[test]
    fn test_llm_api_error() {
        let err = ShError::LlmApi("API failed".to_string());
        assert!(err.to_string().contains("LLM API error"));
        assert!(err.to_string().contains("API failed"));
    }

    #[test]
    fn test_session_error() {
        let err = ShError::Session("session expired".to_string());
        assert!(err.to_string().contains("Session error"));
        assert!(err.to_string().contains("session expired"));
    }

    #[test]
    fn test_error_handler_handle() {
        let handler = ErrorHandler::new();
        let err = ShError::Config("test".to_string());
        let msg = handler.handle(&err);
        assert!(msg.contains("Configuration error"));
    }

    #[test]
    fn test_error_handler_to_user_message_timeout() {
        let handler = ErrorHandler::new();
        let err = ShError::Timeout { seconds: 60 };
        let msg = handler.to_user_message(&err);
        assert!(msg.contains("60"));
        assert!(msg.contains("超时"));
    }

    #[test]
    fn test_error_handler_to_user_message_rate_limited() {
        let handler = ErrorHandler::new();
        let err = ShError::RateLimited;
        let msg = handler.to_user_message(&err);
        assert!(msg.contains("频繁"));
    }

    #[test]
    fn test_error_handler_to_user_message_not_found() {
        let handler = ErrorHandler::new();
        let err = ShError::NotFound {
            resource: "file.txt".to_string(),
        };
        let msg = handler.to_user_message(&err);
        assert!(msg.contains("file.txt"));
    }

    #[test]
    fn test_error_handler_without_logging() {
        let mut handler = ErrorHandler::new();
        handler.log_errors = false;
        let err = ShError::Internal("test".to_string());
        let msg = handler.handle(&err);
        assert!(!msg.is_empty());
    }

    #[test]
    fn test_from_io_error() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let sh_err: ShError = io_err.into();
        assert!(matches!(sh_err, ShError::Io(_)));
    }

    #[test]
    fn test_from_serde_json_error() {
        let json_err = serde_json::from_str::<i32>("not a number").unwrap_err();
        let sh_err: ShError = json_err.into();
        assert!(matches!(sh_err, ShError::Serde(_)));
    }

    #[test]
    fn test_from_anyhow_error() {
        let anyhow_err = anyhow::anyhow!("anyhow error");
        let sh_err: ShError = anyhow_err.into();
        assert!(matches!(sh_err, ShError::Internal(_)));
    }

    #[test]
    // The value is built from a literal on purpose: the test checks that the
    // `ShResult` alias behaves like a plain `Result`.
    #[allow(clippy::unnecessary_literal_unwrap)]
    fn test_sh_result_ok() {
        let result: ShResult<i32> = Ok(42);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    #[test]
    fn test_sh_result_err() {
        let result: ShResult<i32> = Err(ShError::NotFound {
            resource: "test".to_string(),
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_error_handler_default() {
        let handler = ErrorHandler::default();
        let err = ShError::RateLimited;
        let msg = handler.handle(&err);
        assert!(!msg.is_empty());
    }

    // ========== Error scenario tests ==========

    #[test]
    fn test_layer0_error_variants() {
        // Layer 0 security errors
        let err = ShError::Layer0("PII detected in input".to_string());
        assert!(err.to_string().contains("Layer 0 error"));

        let err = ShError::Layer0("Injection attempt blocked".to_string());
        assert!(err.to_string().contains("Injection"));
    }

    #[test]
    fn test_config_error_variants() {
        let err = ShError::Config("Missing required field".to_string());
        assert!(err.to_string().contains("Configuration error"));

        let err = ShError::Config("Invalid API key format".to_string());
        assert!(err.to_string().contains("Invalid API key"));
    }

    #[test]
    fn test_timeout_boundary_values() {
        // Smallest timeout
        let err = ShError::Timeout { seconds: 0 };
        assert!(err.to_string().contains("0"));

        // Largest timeout
        let err = ShError::Timeout { seconds: u64::MAX };
        assert!(err.to_string().contains(&u64::MAX.to_string()));

        // Typical timeout
        let err = ShError::Timeout { seconds: 30 };
        assert!(err.to_string().contains("30"));
    }

    #[test]
    fn test_not_found_variants() {
        // Different kinds of resources
        let err = ShError::NotFound {
            resource: "session".to_string(),
        };
        assert!(err.to_string().contains("session"));

        let err = ShError::NotFound {
            resource: "configuration file".to_string(),
        };
        assert!(err.to_string().contains("configuration file"));

        // An empty resource name still yields the variant prefix.
        let err = ShError::NotFound {
            resource: "".to_string(),
        };
        assert!(err.to_string().contains("Not found"));
    }

    #[test]
    fn test_llm_api_error_variants() {
        let err = ShError::LlmApi("Rate limit exceeded".to_string());
        assert!(err.to_string().contains("LLM API error"));

        let err = ShError::LlmApi("Model not available".to_string());
        assert!(err.to_string().contains("Model"));

        let err = ShError::LlmApi("Invalid request: context too long".to_string());
        assert!(err.to_string().contains("context"));
    }

    #[test]
    fn test_session_error_variants() {
        let err = ShError::Session("Session expired".to_string());
        assert!(err.to_string().contains("Session error"));

        let err = ShError::Session("Invalid session ID".to_string());
        assert!(err.to_string().contains("Invalid session"));

        let err = ShError::Session("Session not found".to_string());
        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn test_internal_error_variants() {
        let err = ShError::Internal("Unexpected state".to_string());
        assert!(err.to_string().contains("Internal error"));

        let err = ShError::Internal("Stack overflow".to_string());
        assert!(err.to_string().contains("Stack overflow"));
    }

    #[test]
    fn test_io_error_various_kinds() {
        let kinds = [
            (std::io::ErrorKind::NotFound, "file not found"),
            (std::io::ErrorKind::PermissionDenied, "access denied"),
            (std::io::ErrorKind::ConnectionReset, "connection reset"),
            (std::io::ErrorKind::TimedOut, "timeout"),
        ];
        // Every io::ErrorKind must land in the Io variant, regardless of kind.
        for (kind, text) in kinds {
            let sh_err: ShError = std::io::Error::new(kind, text).into();
            assert!(matches!(sh_err, ShError::Io(_)), "kind {:?}", kind);
        }
    }

    #[test]
    fn test_serde_error_various_cases() {
        // Malformed JSON
        let json_err = serde_json::from_str::<serde_json::Value>("not json").unwrap_err();
        let sh_err: ShError = json_err.into();
        assert!(matches!(sh_err, ShError::Serde(_)));

        // Type mismatch
        let json_err = serde_json::from_str::<i32>("\"string not number\"").unwrap_err();
        let sh_err: ShError = json_err.into();
        assert!(matches!(sh_err, ShError::Serde(_)));

        // Unexpected end of input
        let json_err = serde_json::from_str::<serde_json::Value>("").unwrap_err();
        let sh_err: ShError = json_err.into();
        assert!(matches!(sh_err, ShError::Serde(_)));
    }

    #[test]
    fn test_error_chain_from_anyhow() {
        // Plain anyhow error
        let anyhow_err = anyhow::anyhow!("Something went wrong");
        let sh_err: ShError = anyhow_err.into();
        assert!(matches!(sh_err, ShError::Internal(_)));

        // anyhow error carrying extra context
        let anyhow_err = anyhow::anyhow!("Base error").context("Additional context");
        let sh_err: ShError = anyhow_err.into();
        assert!(matches!(sh_err, ShError::Internal(_)));
        assert!(sh_err.to_string().contains("Additional context"));
    }

    #[test]
    fn test_error_handler_to_user_message_all_variants() {
        let handler = ErrorHandler::new();

        // Timeout
        let msg = handler.to_user_message(&ShError::Timeout { seconds: 120 });
        assert!(msg.contains("120"));
        assert!(msg.contains("超时"));

        // RateLimited
        let msg = handler.to_user_message(&ShError::RateLimited);
        assert!(msg.contains("频繁"));

        // NotFound
        let msg = handler.to_user_message(&ShError::NotFound {
            resource: "配置文件".to_string(),
        });
        assert!(msg.contains("配置文件"));

        // Every remaining variant uses the generic fallback, which embeds the
        // error's Display text.
        let msg = handler.to_user_message(&ShError::Config("test".to_string()));
        assert!(msg.contains("Configuration error"));

        let msg = handler.to_user_message(&ShError::LlmApi("test".to_string()));
        assert!(msg.contains("LLM API error"));

        let msg = handler.to_user_message(&ShError::Session("test".to_string()));
        assert!(msg.contains("Session error"));

        let msg = handler.to_user_message(&ShError::Internal("test".to_string()));
        assert!(msg.contains("Internal error"));
    }

    #[test]
    fn test_error_handler_with_logging_disabled() {
        let mut handler = ErrorHandler::new();
        handler.log_errors = false;

        let err = ShError::Internal("Test error".to_string());
        let msg = handler.handle(&err);
        // Turning logging off must not change the returned message.
        assert_eq!(msg, err.to_string());
    }

    #[test]
    // Values are built from literals on purpose: these tests exercise the
    // `Result` combinators through the `ShResult` alias.
    #[allow(clippy::unnecessary_literal_unwrap)]
    fn test_sh_result_operations() {
        // map
        let result: ShResult<i32> = Ok(10);
        let mapped = result.map(|x| x * 2);
        assert_eq!(mapped.unwrap(), 20);

        // and_then chains a fallible step...
        let add_five = |x: i32| -> ShResult<i32> {
            x.checked_add(5)
                .ok_or_else(|| ShError::Internal("overflow".to_string()))
        };
        let result: ShResult<i32> = Ok(10);
        assert_eq!(result.and_then(add_five).unwrap(), 15);

        // ...and short-circuits when the input is already an error.
        let result: ShResult<i32> = Err(ShError::RateLimited);
        assert!(matches!(result.and_then(add_five), Err(ShError::RateLimited)));

        // or_else recovers from a specific error and forwards the rest.
        let result: ShResult<i32> = Err(ShError::NotFound {
            resource: "test".to_string(),
        });
        let recovered: ShResult<i32> = result.or_else(|err| match err {
            ShError::NotFound { .. } => Ok(0),
            other => Err(other),
        });
        assert_eq!(recovered.unwrap(), 0);
    }

    #[test]
    // Literal `Err` on purpose: verifies the fallback path of `unwrap_or`.
    #[allow(clippy::unnecessary_literal_unwrap)]
    fn test_sh_result_unwrap_or() {
        let result: ShResult<i32> = Err(ShError::RateLimited);
        let value = result.unwrap_or(42);
        assert_eq!(value, 42);
    }

    #[test]
    // Literal `Err` on purpose: verifies the closure path of `unwrap_or_else`.
    #[allow(clippy::unnecessary_literal_unwrap)]
    fn test_sh_result_unwrap_or_else() {
        let result: ShResult<i32> = Err(ShError::Timeout { seconds: 30 });
        // The closure receives the error, so the fallback can depend on it.
        let value = result.unwrap_or_else(|err| {
            if matches!(err, ShError::Timeout { .. }) {
                100
            } else {
                -1
            }
        });
        assert_eq!(value, 100);
    }

    #[test]
    fn test_sh_result_is_ok_is_err() {
        let ok: ShResult<i32> = Ok(1);
        assert!(ok.is_ok());

        let err: ShResult<i32> = Err(ShError::RateLimited);
        assert!(err.is_err());
    }

    #[test]
    // Literal `Ok` on purpose: verifies `expect` through the alias.
    #[allow(clippy::unnecessary_literal_unwrap)]
    fn test_sh_result_expect() {
        let result: ShResult<i32> = Ok(42);
        let value = result.expect("Should have a value");
        assert_eq!(value, 42);
    }

    #[test]
    fn test_error_equality() {
        let err1 = ShError::RateLimited;
        let err2 = ShError::RateLimited;
        // ShError does not implement PartialEq, so compare the Display output.
        assert_eq!(err1.to_string(), err2.to_string());
    }

    #[test]
    fn test_error_debug_output() {
        let err = ShError::Timeout { seconds: 30 };
        let debug_output = format!("{:?}", err);
        assert!(debug_output.contains("Timeout"));
        assert!(debug_output.contains("30"));
    }
}