Skip to main content

traitclaw_core/
error.rs

1//! Error types for `TraitClaw` Core.
2
3use std::fmt;
4
5/// Errors that can occur during agent operation.
6#[derive(Debug, thiserror::Error)]
7pub enum Error {
8    /// Error from the LLM provider.
9    #[error("Provider error: {message}")]
10    Provider {
11        /// Error message from the provider.
12        message: String,
13        /// Optional HTTP status code (for retry classification).
14        status_code: Option<u16>,
15    },
16
17    /// Error during tool execution.
18    #[error("Tool execution failed: {tool_name}: {message}")]
19    ToolExecution {
20        /// Name of the tool that failed.
21        tool_name: String,
22        /// Error message describing the failure.
23        message: String,
24    },
25
26    /// Error from the memory system.
27    #[error("Memory error: {0}")]
28    Memory(String),
29
30    /// Configuration error (e.g., missing required fields).
31    #[error("Configuration error: {0}")]
32    Config(String),
33
34    /// Runtime error during agent loop execution.
35    #[error("Runtime error: {0}")]
36    Runtime(String),
37
38    /// Serialization/deserialization error.
39    #[error("Serialization error: {0}")]
40    Serialization(#[from] serde_json::Error),
41
42    /// IO error.
43    #[error("IO error: {0}")]
44    Io(#[from] std::io::Error),
45}
46
47/// Convenience type alias for `Result<T, Error>`.
48pub type Result<T> = std::result::Result<T, Error>;
49
/// HTTP status codes treated as transient, i.e. worth retrying by default.
const RETRYABLE_STATUS_CODES: &[u16] = &[
    429, // Too Many Requests
    500, // Internal Server Error
    502, // Bad Gateway
    503, // Service Unavailable
    504, // Gateway Timeout
];
52
53impl Error {
54    /// Create a provider error.
55    #[must_use]
56    pub fn provider(message: impl fmt::Display) -> Self {
57        Self::Provider {
58            message: message.to_string(),
59            status_code: None,
60        }
61    }
62
63    /// Create a provider error with an HTTP status code.
64    #[must_use]
65    pub fn provider_with_status(message: impl fmt::Display, status_code: u16) -> Self {
66        Self::Provider {
67            message: message.to_string(),
68            status_code: Some(status_code),
69        }
70    }
71
72    /// Create a tool execution error.
73    #[must_use]
74    pub fn tool_execution(tool_name: impl fmt::Display, message: impl fmt::Display) -> Self {
75        Self::ToolExecution {
76            tool_name: tool_name.to_string(),
77            message: message.to_string(),
78        }
79    }
80
81    /// Check whether this error is safe to retry.
82    ///
83    /// Returns `true` for transient provider errors (429, 500, 502, 503, 504)
84    /// and provider errors without a status code (e.g., timeouts).
85    #[must_use]
86    pub fn is_retryable(&self) -> bool {
87        match self {
88            Self::Provider { status_code, .. } => match status_code {
89                Some(code) => RETRYABLE_STATUS_CODES.contains(code),
90                None => true, // timeout / network errors are typically retryable
91            },
92            _ => false,
93        }
94    }
95}
96
/// Unit tests for the error type: `Display` output, `From` conversions,
/// the `Result` alias, and retry classification.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_error_display() {
        let err = Error::provider("model not found");
        assert_eq!(err.to_string(), "Provider error: model not found");
    }

    #[test]
    fn test_provider_with_status_keeps_code_and_message() {
        // The status code is stored but is not part of the Display text.
        let err = Error::provider_with_status("rate limited", 429);
        assert_eq!(err.to_string(), "Provider error: rate limited");
        assert!(matches!(
            err,
            Error::Provider {
                status_code: Some(429),
                ..
            }
        ));
    }

    #[test]
    fn test_provider_without_status_has_no_code() {
        let err = Error::provider("timeout");
        assert!(matches!(
            err,
            Error::Provider {
                status_code: None,
                ..
            }
        ));
    }

    #[test]
    fn test_tool_execution_error_display() {
        let err = Error::tool_execution("web_search", "timeout");
        assert_eq!(
            err.to_string(),
            "Tool execution failed: web_search: timeout"
        );
    }

    #[test]
    fn test_config_error_display() {
        let err = Error::Config("no provider configured".into());
        assert_eq!(
            err.to_string(),
            "Configuration error: no provider configured"
        );
    }

    #[test]
    fn test_runtime_error_display() {
        let err = Error::Runtime("max iterations reached".into());
        assert_eq!(err.to_string(), "Runtime error: max iterations reached");
    }

    #[test]
    fn test_memory_error_display() {
        let err = Error::Memory("session not found".into());
        assert_eq!(err.to_string(), "Memory error: session not found");
    }

    #[test]
    fn test_from_serde_json_error() {
        // AC-4: #[from] conversion for serde_json::Error
        let json_err = serde_json::from_str::<String>("not valid json").unwrap_err();
        let err: Error = json_err.into();
        assert!(matches!(err, Error::Serialization(_)));
        assert!(err.to_string().contains("Serialization error"));
    }

    #[test]
    fn test_from_io_error() {
        // AC-4: #[from] conversion for std::io::Error
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let err: Error = io_err.into();
        assert!(matches!(err, Error::Io(_)));
        assert!(err.to_string().contains("IO error"));
    }

    #[test]
    fn test_result_alias_with_question_mark() {
        // AC-3: Result<T> alias works with ? operator
        fn may_fail(succeed: bool) -> Result<String> {
            if succeed {
                Ok("done".into())
            } else {
                Err(Error::Runtime("failed".into()))
            }
        }

        fn chain() -> Result<String> {
            let val = may_fail(true)?;
            Ok(val)
        }

        assert!(chain().is_ok());
        assert!(may_fail(false).is_err());
    }

    #[test]
    fn test_error_is_std_error() {
        // AC-4: errors implement std::error::Error
        let err = Error::provider("test");
        let _: &dyn std::error::Error = &err;
    }

    #[test]
    fn test_is_retryable_for_transient_status_codes() {
        // Codes are spelled out rather than read from the constant so an
        // accidental edit to the constant is caught.
        for &code in &[429_u16, 500, 502, 503, 504] {
            let err = Error::provider_with_status("transient", code);
            assert!(err.is_retryable(), "status {code} should be retryable");
        }
    }

    #[test]
    fn test_is_not_retryable_for_other_status_codes() {
        for &code in &[400_u16, 401, 403, 404, 422, 501] {
            let err = Error::provider_with_status("permanent", code);
            assert!(!err.is_retryable(), "status {code} should not be retryable");
        }
    }

    #[test]
    fn test_is_retryable_without_status_code() {
        // Provider errors with no status code are treated as retryable.
        assert!(Error::provider("connection reset").is_retryable());
    }

    #[test]
    fn test_non_provider_errors_are_not_retryable() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let json_err = serde_json::from_str::<String>("not valid json").unwrap_err();

        assert!(!Error::tool_execution("web_search", "timeout").is_retryable());
        assert!(!Error::Memory("session not found".into()).is_retryable());
        assert!(!Error::Config("no provider configured".into()).is_retryable());
        assert!(!Error::Runtime("max iterations reached".into()).is_retryable());
        assert!(!Error::from(json_err).is_retryable());
        assert!(!Error::from(io_err).is_retryable());
    }
}