Skip to main content

zai_rs/toolkits/
error.rs

1//! Enhanced error handling with better Rust idioms
2
3use std::borrow::Cow;
4
5use thiserror::Error;
6
7/// Result type for tool operations
8pub type ToolResult<T> = Result<T, ToolError>;
9
/// Error severity levels used to choose an error-handling strategy.
///
/// The type is `Copy` and `Hash`, so it can be passed by value and used as a
/// key in hashed collections (for example, to count errors per severity).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorSeverity {
    /// User error, no retry needed.
    User,
    /// Normal error, may retry.
    Normal,
    /// Transient error, should retry.
    Transient,
    /// Critical error, log and alert.
    Critical,
}
18
19/// Enhanced error type with better context and error chaining
20#[derive(Error, Debug)]
21pub enum ToolError {
22    #[error("Tool '{name}' not found")]
23    ToolNotFound { name: Cow<'static, str> },
24
25    #[error("Invalid parameters for tool '{tool}': {message}")]
26    InvalidParameters {
27        tool: Cow<'static, str>,
28        message: Cow<'static, str>,
29    },
30
31    #[error("Tool '{tool}' execution failed: {message}")]
32    ExecutionFailed {
33        tool: Cow<'static, str>,
34        message: Cow<'static, str>,
35    },
36
37    #[error("Schema validation failed for tool '{tool}': {message}")]
38    SchemaValidation {
39        tool: Cow<'static, str>,
40        message: Cow<'static, str>,
41    },
42
43    #[error("Tool registration failed: {message}")]
44    RegistrationError { message: Cow<'static, str> },
45
46    #[error("Serialization error for tool '{tool}': {source}")]
47    SerializationError {
48        tool: Cow<'static, str>,
49        #[source]
50        source: serde_json::Error,
51    },
52
53    #[error("Timeout error for tool '{tool}': execution exceeded {timeout:?}")]
54    TimeoutError {
55        tool: Cow<'static, str>,
56        timeout: std::time::Duration,
57    },
58
59    #[error("Retry limit exceeded for tool '{tool}': failed after {attempts} attempts")]
60    RetryLimitExceeded {
61        tool: Cow<'static, str>,
62        attempts: u32,
63    },
64
65    #[error("Validation error for field '{field}': {message}")]
66    ValidationError {
67        field: Cow<'static, str>,
68        message: Cow<'static, str>,
69    },
70
71    #[error("Concurrent access error: {message}")]
72    ConcurrentAccessError { message: Cow<'static, str> },
73
74    #[error("Internal error: {0}")]
75    Internal(String),
76}
77
78impl ToolError {
79    /// Determine if the error is retryable
80    pub fn is_retryable(&self) -> bool {
81        matches!(
82            self,
83            ToolError::TimeoutError { .. }
84                | ToolError::ConcurrentAccessError { .. }
85                | ToolError::ExecutionFailed { .. }
86        )
87    }
88
89    /// Get the severity level of the error
90    pub fn severity(&self) -> ErrorSeverity {
91        match self {
92            ToolError::ToolNotFound { .. } => ErrorSeverity::User,
93            ToolError::InvalidParameters { .. } => ErrorSeverity::User,
94            ToolError::ValidationError { .. } => ErrorSeverity::User,
95            ToolError::TimeoutError { .. } => ErrorSeverity::Transient,
96            ToolError::ConcurrentAccessError { .. } => ErrorSeverity::Transient,
97            ToolError::Internal(_) => ErrorSeverity::Critical,
98            _ => ErrorSeverity::Normal,
99        }
100    }
101}
102
/// Builder that collects context (tool name and operation) before an error
/// value is created.
///
/// Both fields are optional and start out as `None`. The type is `Debug` so it
/// can appear in logs and assertion output, and `Clone` so a partially filled
/// context can be reused for several errors.
#[derive(Debug, Clone)]
pub struct ErrorContext {
    /// Name of the tool the error relates to, if one was recorded.
    tool_name: Option<String>,
    /// Operation that was in progress, if one was recorded.
    operation: Option<String>,
}
108
109impl ErrorContext {
110    pub fn new() -> Self {
111        Self {
112            tool_name: None,
113            operation: None,
114        }
115    }
116
117    pub fn with_tool(mut self, tool_name: impl Into<String>) -> Self {
118        self.tool_name = Some(tool_name.into());
119        self
120    }
121
122    pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
123        self.operation = Some(operation.into());
124        self
125    }
126
127    fn get_tool_name(&self) -> String {
128        self.tool_name
129            .clone()
130            .unwrap_or_else(|| "unknown".to_string())
131    }
132
133    pub fn tool_not_found(self) -> ToolError {
134        ToolError::ToolNotFound {
135            name: Cow::Owned(self.get_tool_name()),
136        }
137    }
138
139    pub fn invalid_parameters(self, message: impl Into<String>) -> ToolError {
140        let mut msg = message.into();
141        if let Some(ref op) = self.operation {
142            msg = format!("[{}] {}", op, msg);
143        }
144        ToolError::InvalidParameters {
145            tool: Cow::Owned(self.get_tool_name()),
146            message: Cow::Owned(msg),
147        }
148    }
149
150    pub fn execution_failed(self, message: impl Into<String>) -> ToolError {
151        let mut msg = message.into();
152        if let Some(ref op) = self.operation {
153            msg = format!("[{}] {}", op, msg);
154        }
155        ToolError::ExecutionFailed {
156            tool: Cow::Owned(self.get_tool_name()),
157            message: Cow::Owned(msg),
158        }
159    }
160
161    pub fn schema_validation(self, message: impl Into<String>) -> ToolError {
162        let mut msg = message.into();
163        if let Some(ref op) = self.operation {
164            msg = format!("[{}] {}", op, msg);
165        }
166        ToolError::SchemaValidation {
167            tool: Cow::Owned(self.get_tool_name()),
168            message: Cow::Owned(msg),
169        }
170    }
171
172    pub fn serialization_error(self, source: serde_json::Error) -> ToolError {
173        let mut tool_name = self.get_tool_name();
174        if let Some(ref op) = self.operation {
175            tool_name = format!("{} [{}]", tool_name, op);
176        }
177        ToolError::SerializationError {
178            tool: Cow::Owned(tool_name),
179            source,
180        }
181    }
182
183    pub fn timeout_error(self, timeout: std::time::Duration) -> ToolError {
184        ToolError::TimeoutError {
185            tool: Cow::Owned(self.get_tool_name()),
186            timeout,
187        }
188    }
189
190    pub fn retry_limit_exceeded(self, attempts: u32) -> ToolError {
191        ToolError::RetryLimitExceeded {
192            tool: Cow::Owned(self.get_tool_name()),
193            attempts,
194        }
195    }
196
197    pub fn validation_error(
198        self,
199        field: impl Into<String>,
200        message: impl Into<String>,
201    ) -> ToolError {
202        ToolError::ValidationError {
203            field: Cow::Owned(field.into()),
204            message: Cow::Owned(message.into()),
205        }
206    }
207
208    pub fn concurrent_access_error(self, message: impl Into<String>) -> ToolError {
209        ToolError::ConcurrentAccessError {
210            message: Cow::Owned(message.into()),
211        }
212    }
213}
214
215impl Default for ErrorContext {
216    fn default() -> Self {
217        Self::new()
218    }
219}
220
221/// Convenience function to create error context
222pub fn error_context() -> ErrorContext {
223    ErrorContext::new()
224}