1use thiserror::Error;
7
8#[derive(Debug, Clone, Error)]
10pub enum LlmError {
11 #[error("LLM API error: {0}")]
13 Api(String),
14
15 #[error("Request error: {0}")]
17 Request(String),
18
19 #[error("Configuration error: {0}")]
21 Config(String),
22
23 #[error("Failed to parse response: {0}")]
25 Parse(String),
26
27 #[error("Rate limit exceeded: {0}")]
29 RateLimit(String),
30
31 #[error("Request timeout: {0}")]
33 Timeout(String),
34
35 #[error("LLM returned no content")]
37 NoContent,
38
39 #[error("Retry exhausted after {attempts} attempts: {last_error}")]
41 RetryExhausted {
42 attempts: usize,
44 last_error: String,
46 },
47}
48
49impl LlmError {
50 pub fn is_retryable(&self) -> bool {
52 match self {
53 LlmError::Api(msg) => {
54 let msg_lower = msg.to_lowercase();
56 msg_lower.contains("rate limit")
57 || msg_lower.contains("429")
58 || msg_lower.contains("503")
59 || msg_lower.contains("502")
60 || msg_lower.contains("timeout")
61 || msg_lower.contains("overloaded")
62 }
63 LlmError::Timeout(_) => true,
64 LlmError::RateLimit(_) => true,
65 _ => false,
66 }
67 }
68
69 pub fn from_api_message(msg: &str) -> Self {
71 let msg_lower = msg.to_lowercase();
72
73 if msg_lower.contains("rate limit") || msg_lower.contains("429") {
74 LlmError::RateLimit(msg.to_string())
75 } else if msg_lower.contains("timeout") {
76 LlmError::Timeout(msg.to_string())
77 } else {
78 LlmError::Api(msg.to_string())
79 }
80 }
81}
82
83impl From<async_openai::error::OpenAIError> for LlmError {
84 fn from(e: async_openai::error::OpenAIError) -> Self {
85 let msg = e.to_string();
86 LlmError::from_api_message(&msg)
87 }
88}
89
90impl From<serde_json::Error> for LlmError {
91 fn from(e: serde_json::Error) -> Self {
92 LlmError::Parse(e.to_string())
93 }
94}
95
96impl From<LlmError> for crate::Error {
97 fn from(e: LlmError) -> Self {
98 crate::Error::Llm(e.to_string())
99 }
100}
101
102impl From<LlmError> for String {
103 fn from(e: LlmError) -> Self {
104 e.to_string()
105 }
106}
107
108pub type LlmResult<T> = std::result::Result<T, LlmError>;
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Rate-limit, timeout and rate-limit-flavoured API errors are retryable;
    /// configuration and parse errors are not.
    #[test]
    fn test_is_retryable() {
        let retryable = [
            LlmError::RateLimit("test".to_string()),
            LlmError::Timeout("test".to_string()),
            LlmError::Api("rate limit exceeded".to_string()),
        ];
        for err in &retryable {
            assert!(err.is_retryable());
        }

        let permanent = [
            LlmError::Config("test".to_string()),
            LlmError::Parse("test".to_string()),
        ];
        for err in &permanent {
            assert!(!err.is_retryable());
        }
    }

    /// Raw API messages are mapped to the most specific variant.
    #[test]
    fn test_from_api_message() {
        let rate_limited = LlmError::from_api_message("Rate limit exceeded");
        assert!(matches!(rate_limited, LlmError::RateLimit(_)));

        let timed_out = LlmError::from_api_message("Request timeout");
        assert!(matches!(timed_out, LlmError::Timeout(_)));

        let generic = LlmError::from_api_message("Internal server error");
        assert!(matches!(generic, LlmError::Api(_)));
    }
}