tt_provider_compat/
errors.rs1use chrono::{DateTime, Utc};
9use serde::Deserialize;
10use tt_shared::ProviderError;
11
/// Top-level JSON envelope of an error response body: an object whose
/// `error` member carries the actual error details.
///
/// Only deserialized (never constructed by hand) inside `map_response_error`.
12#[derive(Debug, Deserialize)]
14struct OpenAiErrorBody {
/// The error details nested under the `"error"` key.
15 error: OpenAiError,
16}
17
/// Error details found under the `"error"` key of an error response body.
///
/// `message` is required for deserialization to succeed; every other field
/// is optional.
18#[derive(Debug, Deserialize)]
20struct OpenAiError {
/// Error text; used as the message of the mapped `ProviderError`.
21 message: String,
/// Error category, read from the JSON key `"type"` (`type` is a Rust keyword,
/// hence the rename).
22 #[serde(rename = "type")]
23 error_type: Option<String>,
/// Deserialized from the `"code"` key but not read anywhere in this file.
24 #[allow(dead_code)]
25 code: Option<String>,
/// Deserialized from the `"param"` key as an arbitrary JSON value; not read
/// anywhere in this file.
26 #[allow(dead_code)]
27 param: Option<serde_json::Value>,
28}
29
30pub fn map_response_error(
35 status: u16,
36 body: &str,
37 retry_after_header: Option<&str>,
38) -> ProviderError {
39 let parsed: Option<OpenAiErrorBody> = serde_json::from_str(body).ok();
40 let message = parsed
41 .as_ref()
42 .map(|p| p.error.message.clone())
43 .unwrap_or_else(|| body.to_string());
44 let error_type = parsed
45 .as_ref()
46 .and_then(|p| p.error.error_type.clone())
47 .unwrap_or_default();
48
49 match status {
50 401 => ProviderError::Unauthorized(message),
51 429 => {
52 let retry_after_ms = parse_retry_after(retry_after_header);
53 ProviderError::RateLimited { retry_after_ms }
54 }
55 400 if error_type == "invalid_request_error" => ProviderError::InvalidRequest(message),
56 400 => ProviderError::InvalidRequest(message),
57 404 if message.to_lowercase().contains("model") => {
58 let model = extract_model_name(&message);
59 ProviderError::ModelNotFound { model }
60 }
61 404 => ProviderError::InvalidRequest(message),
62 408 => ProviderError::Timeout { ms: 0 },
63 500..=599 => ProviderError::ProviderUpstream { status, message },
64 _ => ProviderError::ProviderUpstream { status, message },
65 }
66}
67
68pub fn map_reqwest_error(err: reqwest::Error) -> ProviderError {
70 if err.is_timeout() {
71 ProviderError::Timeout { ms: 0 }
72 } else {
73 ProviderError::Network(err)
74 }
75}
76
77fn parse_retry_after(header: Option<&str>) -> u64 {
84 let Some(value) = header else {
85 return 1000;
86 };
87
88 if let Ok(secs) = value.trim().parse::<u64>() {
90 return secs * 1000;
91 }
92
93 if let Ok(date) = DateTime::parse_from_rfc2822(value.trim()) {
95 let delta = date.with_timezone(&Utc) - Utc::now();
96 let ms = delta.num_milliseconds().max(0) as u64;
97 return ms;
98 }
99
100 1000
101}
102
/// Pulls a quoted model name out of an error message.
///
/// The first pair of matching delimiters wins, trying single quotes, then
/// double quotes, then backticks. A delimiter that appears only once (for
/// example a lone apostrophe) is skipped and the next delimiter is tried.
/// When no delimiter pair is found, the whole message is returned.
fn extract_model_name(message: &str) -> String {
    // Priority order. Backticks come last so that results for messages
    // containing a single- or double-quote pair are unchanged.
    const DELIMITERS: [char; 3] = ['\'', '"', '`'];

    DELIMITERS
        .iter()
        .find_map(|&delim| {
            let (_, rest) = message.split_once(delim)?;
            rest.split_once(delim).map(|(name, _)| name)
        })
        .unwrap_or(message)
        .to_string()
}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Canned 429 response body shared by the rate-limit tests.
    const RATE_LIMIT_BODY: &str = r#"{"error":{"message":"Rate limit exceeded","type":"requests","code":null,"param":null}}"#;

    /// Maps a 429 response and returns the delay it reports, failing the test
    /// if any other variant comes back.
    fn rate_limited_delay(retry_after: Option<&str>) -> u64 {
        match map_response_error(429, RATE_LIMIT_BODY, retry_after) {
            ProviderError::RateLimited { retry_after_ms } => retry_after_ms,
            _ => panic!("expected ProviderError::RateLimited"),
        }
    }

    #[test]
    fn parse_retry_after_integer() {
        let delay_ms = parse_retry_after(Some("5"));
        assert_eq!(delay_ms, 5000);
    }

    #[test]
    fn parse_retry_after_missing() {
        let delay_ms = parse_retry_after(None);
        assert_eq!(delay_ms, 1000);
    }

    #[test]
    fn parse_retry_after_garbage() {
        let delay_ms = parse_retry_after(Some("garbage"));
        assert_eq!(delay_ms, 1000);
    }

    #[test]
    fn extract_model_single_quoted() {
        let extracted = extract_model_name("The model 'gpt-99' does not exist");
        assert_eq!(extracted, "gpt-99");
    }

    #[test]
    fn map_401_to_unauthorized() {
        let body = r#"{"error":{"message":"Invalid API key","type":"invalid_api_key","code":"invalid_api_key","param":null}}"#;
        match map_response_error(401, body, None) {
            ProviderError::Unauthorized(_) => {}
            _ => panic!("expected ProviderError::Unauthorized"),
        }
    }

    #[test]
    fn map_429_with_retry_after() {
        assert_eq!(rate_limited_delay(Some("5")), 5000);
    }

    #[test]
    fn map_429_without_retry_after() {
        assert_eq!(rate_limited_delay(None), 1000);
    }
}