1use std::fmt;
4
/// Error type carried by this module's [`Result`] alias.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute. Variants whose payload is marked `#[from]` can be produced from
/// the wrapped error with `.into()` / `?`.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A provider failure, optionally tagged with a numeric status code.
    ///
    /// Only `message` appears in the `Display` output; `status_code` is
    /// consulted by [`Error::is_retryable`].
    #[error("Provider error: {message}")]
    Provider {
        /// Description shown after the `Provider error: ` prefix.
        message: String,
        /// Status code associated with the failure, when one is known.
        /// NOTE(review): presumably an HTTP status — confirm against callers.
        status_code: Option<u16>,
    },

    /// A failure while executing the tool named `tool_name`.
    #[error("Tool execution failed: {tool_name}: {message}")]
    ToolExecution {
        /// Name of the tool that failed.
        tool_name: String,
        /// Description of what went wrong.
        message: String,
    },

    /// A memory-related failure; the payload is the message to display.
    #[error("Memory error: {0}")]
    Memory(String),

    /// A configuration failure; the payload is the message to display.
    #[error("Configuration error: {0}")]
    Config(String),

    /// A runtime failure; the payload is the message to display.
    #[error("Runtime error: {0}")]
    Runtime(String),

    /// A wrapped `serde_json::Error`, convertible via `From`.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// A wrapped `std::io::Error`, convertible via `From`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
46
/// Shorthand for `std::result::Result` with [`Error`] fixed as the error type.
pub type Result<T> = std::result::Result<T, Error>;
49
/// Status codes for which [`Error::is_retryable`] treats a provider error as
/// retryable.
///
/// NOTE(review): the values look like HTTP statuses (429 plus common 5xx
/// codes) — confirm with the provider client before relying on that reading.
const RETRYABLE_STATUS_CODES: &[u16] = &[429, 500, 502, 503, 504];
52
53impl Error {
54 #[must_use]
56 pub fn provider(message: impl fmt::Display) -> Self {
57 Self::Provider {
58 message: message.to_string(),
59 status_code: None,
60 }
61 }
62
63 #[must_use]
65 pub fn provider_with_status(message: impl fmt::Display, status_code: u16) -> Self {
66 Self::Provider {
67 message: message.to_string(),
68 status_code: Some(status_code),
69 }
70 }
71
72 #[must_use]
74 pub fn tool_execution(tool_name: impl fmt::Display, message: impl fmt::Display) -> Self {
75 Self::ToolExecution {
76 tool_name: tool_name.to_string(),
77 message: message.to_string(),
78 }
79 }
80
81 #[must_use]
86 pub fn is_retryable(&self) -> bool {
87 match self {
88 Self::Provider { status_code, .. } => match status_code {
89 Some(code) => RETRYABLE_STATUS_CODES.contains(code),
90 None => true, },
92 _ => false,
93 }
94 }
95}
96
/// Unit tests for [`Error`]: display formatting, `From` conversions, the
/// [`Result`] alias, the constructors, and retry classification.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_error_display() {
        let err = Error::provider("model not found");
        assert_eq!(err.to_string(), "Provider error: model not found");
    }

    #[test]
    fn test_tool_execution_error_display() {
        let err = Error::tool_execution("web_search", "timeout");
        assert_eq!(
            err.to_string(),
            "Tool execution failed: web_search: timeout"
        );
    }

    #[test]
    fn test_config_error_display() {
        let err = Error::Config("no provider configured".into());
        assert_eq!(
            err.to_string(),
            "Configuration error: no provider configured"
        );
    }

    #[test]
    fn test_runtime_error_display() {
        let err = Error::Runtime("max iterations reached".into());
        assert_eq!(err.to_string(), "Runtime error: max iterations reached");
    }

    #[test]
    fn test_memory_error_display() {
        let err = Error::Memory("session not found".into());
        assert_eq!(err.to_string(), "Memory error: session not found");
    }

    #[test]
    fn test_from_serde_json_error() {
        // Parsing invalid JSON yields a real `serde_json::Error` to convert.
        let json_err = serde_json::from_str::<String>("not valid json").unwrap_err();
        let err: Error = json_err.into();
        assert!(matches!(err, Error::Serialization(_)));
        assert!(err.to_string().contains("Serialization error"));
    }

    #[test]
    fn test_from_io_error() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let err: Error = io_err.into();
        assert!(matches!(err, Error::Io(_)));
        assert!(err.to_string().contains("IO error"));
    }

    #[test]
    fn test_result_alias_with_question_mark() {
        fn may_fail(succeed: bool) -> Result<String> {
            if succeed {
                Ok("done".into())
            } else {
                Err(Error::Runtime("failed".into()))
            }
        }

        // `?` must propagate through the module's `Result` alias.
        fn chain() -> Result<String> {
            let val = may_fail(true)?;
            Ok(val)
        }

        assert!(chain().is_ok());
        assert!(may_fail(false).is_err());
    }

    #[test]
    fn test_error_is_std_error() {
        let err = Error::provider("test");
        // Compiles only if `Error` implements `std::error::Error`.
        let _: &dyn std::error::Error = &err;
    }

    #[test]
    fn test_provider_with_status_keeps_code() {
        let err = Error::provider_with_status("rate limited", 429);
        assert!(matches!(
            err,
            Error::Provider {
                status_code: Some(429),
                ..
            }
        ));
        // The status code is stored but is not part of the display text.
        assert_eq!(err.to_string(), "Provider error: rate limited");
    }

    #[test]
    fn test_provider_without_status_is_retryable() {
        let err = Error::provider("connection reset");
        assert!(matches!(
            err,
            Error::Provider {
                status_code: None,
                ..
            }
        ));
        assert!(err.is_retryable());
    }

    #[test]
    fn test_retryable_status_codes() {
        for &code in &[429_u16, 500, 502, 503, 504] {
            let err = Error::provider_with_status("upstream failure", code);
            assert!(err.is_retryable(), "status {code} should be retryable");
        }
    }

    #[test]
    fn test_non_retryable_status_codes() {
        for &code in &[400_u16, 401, 403, 404, 422, 501] {
            let err = Error::provider_with_status("request rejected", code);
            assert!(!err.is_retryable(), "status {code} should not be retryable");
        }
    }

    #[test]
    fn test_non_provider_errors_are_not_retryable() {
        let json_err = serde_json::from_str::<String>("not valid json").unwrap_err();
        let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, "timed out");
        let errors = [
            Error::tool_execution("web_search", "timeout"),
            Error::Memory("session not found".into()),
            Error::Config("no provider configured".into()),
            Error::Runtime("max iterations reached".into()),
            Error::from(json_err),
            Error::from(io_err),
        ];

        for err in &errors {
            assert!(!err.is_retryable(), "`{err}` should not be retryable");
        }
    }
}