1use std::collections::{HashMap, VecDeque};
4use std::hash::{Hash, Hasher};
5use std::sync::{LazyLock, Mutex};
6use std::time::{Duration, SystemTime};
7
8use http::StatusCode;
9use serde_json::Value;
10use thiserror::Error;
11
12#[cfg(feature = "mcp")]
13use rmcp::service::ServiceError;
14
15use crate::client::RetryMetadata;
16
/// One kibibyte; used to spell the byte budgets below.
const KIB: usize = 1024;

/// Maximum number of metadata records the registry holds before evicting the oldest.
const API_ERROR_METADATA_CAPACITY: usize = 4096;
/// Budget for the summed `retained_bytes()` of every record in the registry.
const API_ERROR_METADATA_MAX_TOTAL_BYTES: usize = 512 * KIB;
/// Response bodies longer than this many bytes are truncated before being stored.
const API_ERROR_METADATA_MAX_BODY_BYTES: usize = 8 * KIB;
/// `error.details` whose serialized JSON exceeds this size is replaced by a placeholder.
const API_ERROR_METADATA_MAX_DETAILS_BYTES: usize = 8 * KIB;
/// Combined name + value byte budget for the headers kept with one record.
const API_ERROR_METADATA_MAX_HEADERS_BYTES: usize = 4 * KIB;
22
/// Extra context for an `Error::ApiError`.
///
/// The public `ApiError` variant only carries `status` and `message`, so everything else
/// is kept here, in the process-wide registry, and read back through `Error`'s accessors.
/// Every field is optional; which ones are set depends on how the error was built.
#[derive(Clone, Debug, Default)]
struct ApiErrorMetadata {
    /// Error code parsed from the response body (`error.status`, or `error.code` as text).
    code: Option<String>,
    /// The `error.details` JSON value from the response body.
    details: Option<Value>,
    /// Response headers by name; repeated values are joined with `", "`.
    headers: Option<HashMap<String, String>>,
    /// Raw response body text; `None` when the body was empty.
    body: Option<String>,
    /// Delay parsed from the `Retry-After` header, in whole seconds.
    retry_after_secs: Option<u64>,
    /// Explicit retryability verdict; `None` means "use the status-based default".
    retryable: Option<bool>,
    /// Attempt count taken from the response's `RetryMetadata` extension.
    attempts: Option<u32>,
}
33
34impl ApiErrorMetadata {
35 fn bounded(mut self) -> Self {
36 self.body = self
37 .body
38 .take()
39 .and_then(|body| truncate_string(body, API_ERROR_METADATA_MAX_BODY_BYTES));
40 self.details = self.details.take().and_then(bound_details);
41 self.headers = self.headers.take().and_then(bound_headers);
42 self
43 }
44
45 fn retained_bytes(&self) -> usize {
46 self.code.as_ref().map_or(0, String::len)
47 + self.body.as_ref().map_or(0, String::len)
48 + self.headers.as_ref().map_or(0, |headers| {
49 headers
50 .iter()
51 .map(|(name, value)| name.len() + value.len())
52 .sum()
53 })
54 + self.details.as_ref().map_or(0, |details| {
55 serde_json::to_vec(details).map_or(0, |bytes| bytes.len())
56 })
57 }
58}
59
/// Registry key identifying one particular `ApiError` message allocation.
///
/// Besides the status and a hash of the text, it records the address and length of the
/// message bytes, so two errors with identical text in different `String` allocations
/// get different keys, and a freshly constructed `ApiError` normally finds no metadata.
///
/// NOTE(review): an address can be reused once its `String` is freed, and empty strings
/// that never allocated all share one dangling address, so collisions remain possible.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct ApiErrorKey {
    /// HTTP status of the error.
    status: u16,
    /// Address of the message's first byte (`str::as_ptr`).
    message_ptr: usize,
    /// Message length in bytes.
    message_len: usize,
    /// `DefaultHasher` digest of the message text.
    message_hash: u64,
}
67
68impl ApiErrorKey {
69 fn new(status: u16, message: &str) -> Self {
70 let mut hasher = std::collections::hash_map::DefaultHasher::new();
71 message.hash(&mut hasher);
72 Self {
73 status,
74 message_ptr: message.as_ptr() as usize,
75 message_len: message.len(),
76 message_hash: hasher.finish(),
77 }
78 }
79}
80
/// Size-bounded, first-in-first-out store of `ApiError` metadata.
#[derive(Default)]
struct ApiErrorMetadataRegistry {
    /// Stored metadata by key.
    entries: HashMap<ApiErrorKey, ApiErrorMetadata>,
    /// Keys in first-insertion order; the front entry is evicted first.
    order: VecDeque<ApiErrorKey>,
    /// Running sum of `retained_bytes()` over every entry in `entries`.
    total_bytes: usize,
}
87
88impl ApiErrorMetadataRegistry {
89 fn get(&self, key: &ApiErrorKey) -> Option<ApiErrorMetadata> {
90 self.entries.get(key).cloned()
91 }
92
93 fn insert(&mut self, key: ApiErrorKey, metadata: ApiErrorMetadata) {
94 let metadata = metadata.bounded();
95 let metadata_bytes = metadata.retained_bytes();
96
97 if let Some(previous) = self.entries.insert(key, metadata) {
98 self.total_bytes = self.total_bytes.saturating_sub(previous.retained_bytes());
99 } else {
100 self.order.push_back(key);
101 }
102 self.total_bytes += metadata_bytes;
103
104 while self.entries.len() > API_ERROR_METADATA_CAPACITY
105 || self.total_bytes > API_ERROR_METADATA_MAX_TOTAL_BYTES
106 {
107 let Some(oldest_key) = self.order.pop_front() else {
108 break;
109 };
110 if let Some(removed) = self.entries.remove(&oldest_key) {
111 self.total_bytes = self.total_bytes.saturating_sub(removed.retained_bytes());
112 }
113 }
114 }
115}
116
117static API_ERROR_METADATA_REGISTRY: LazyLock<Mutex<ApiErrorMetadataRegistry>> =
118 LazyLock::new(|| Mutex::new(ApiErrorMetadataRegistry::default()));
119
/// Errors produced by this crate.
///
/// `ApiError` values created by the crate's constructors may carry extra metadata
/// (error code, details, headers, body, retry hints) that is stored outside the variant;
/// read it through the accessor methods on `Error`.
#[derive(Debug, Error)]
pub enum Error {
    /// Failure reported by the `reqwest` HTTP client.
    #[error("HTTP client error: {source}")]
    HttpClient {
        #[from]
        source: reqwest::Error,
    },

    /// Error response from the remote API.
    ///
    /// `status` is the HTTP status code and `message` a human-readable description.
    #[error("API error (status {status}): {message}")]
    ApiError { status: u16, message: String },

    /// Invalid configuration.
    #[error("Invalid configuration: {message}")]
    InvalidConfig { message: String },

    /// Failure to parse data.
    #[error("Parse error: {message}")]
    Parse { message: String },

    /// JSON (de)serialization failure from `serde_json`.
    #[error("Serialization error: {source}")]
    Serialization {
        #[from]
        source: serde_json::Error,
    },

    /// I/O failure.
    #[error("IO error: {source}")]
    Io {
        #[from]
        source: std::io::Error,
    },

    /// An operation timed out.
    #[error("Timeout: {message}")]
    Timeout { message: String },

    /// A required thought signature was missing.
    #[error("Missing thought signature: {message}")]
    MissingThoughtSignature { message: String },

    /// Authentication failure.
    #[error("Auth error: {message}")]
    Auth { message: String },

    /// A channel was closed.
    #[error("Channel closed")]
    ChannelClosed,

    /// Failure reported by the `tungstenite` WebSocket layer.
    #[error("WebSocket error: {source}")]
    WebSocket {
        #[from]
        source: tokio_tungstenite::tungstenite::Error,
    },

    /// Failure reported by the `rmcp` service (only with the `mcp` feature).
    #[cfg(feature = "mcp")]
    #[error("MCP error: {source}")]
    Mcp {
        #[from]
        source: ServiceError,
    },
}
174
175impl Error {
176 pub(crate) fn api_error_with_retryable(
177 status: u16,
178 message: impl Into<String>,
179 retryable: bool,
180 ) -> Self {
181 let message = message.into();
182 set_api_metadata(
183 status,
184 &message,
185 ApiErrorMetadata {
186 retryable: Some(retryable),
187 ..Default::default()
188 },
189 );
190 Self::ApiError { status, message }
191 }
192
193 pub(crate) async fn api_error_from_response(
194 response: reqwest::Response,
195 retryable_override: Option<bool>,
196 ) -> Self {
197 let status = response.status().as_u16();
198 let retry_metadata = response.extensions().get::<RetryMetadata>().copied();
199 let headers = header_map_to_hash_map(response.headers());
200 let retry_after_secs = retry_after_secs(response.headers());
201 let body = response.text().await.unwrap_or_default();
202 let (message, code, details) = parse_google_error(&body, status);
203 set_api_metadata(
204 status,
205 &message,
206 ApiErrorMetadata {
207 code,
208 details,
209 headers,
210 body: if body.is_empty() { None } else { Some(body) },
211 retry_after_secs,
212 retryable: retryable_override
213 .or(retry_metadata.map(|meta| meta.retryable))
214 .or(Some(default_retryable_status(status))),
215 attempts: retry_metadata.map(|meta| meta.attempts),
216 },
217 );
218
219 Self::ApiError { status, message }
220 }
221
222 fn api_metadata(&self) -> Option<ApiErrorMetadata> {
223 match self {
224 Self::ApiError { status, message } => api_metadata(*status, message),
225 _ => None,
226 }
227 }
228
229 #[must_use]
230 pub fn status(&self) -> Option<StatusCode> {
231 match self {
232 Self::ApiError { status, .. } => StatusCode::from_u16(*status).ok(),
233 _ => None,
234 }
235 }
236
237 #[must_use]
238 pub fn code(&self) -> Option<String> {
239 self.api_metadata().and_then(|metadata| metadata.code)
240 }
241
242 #[must_use]
243 pub fn details(&self) -> Option<Value> {
244 self.api_metadata().and_then(|metadata| metadata.details)
245 }
246
247 #[must_use]
248 pub fn headers(&self) -> Option<HashMap<String, String>> {
249 self.api_metadata().and_then(|metadata| metadata.headers)
250 }
251
252 #[must_use]
253 pub fn body(&self) -> Option<String> {
254 self.api_metadata().and_then(|metadata| metadata.body)
255 }
256
257 #[must_use]
258 pub fn attempts(&self) -> Option<u32> {
259 self.api_metadata().and_then(|metadata| metadata.attempts)
260 }
261
262 #[must_use]
263 pub fn retry_after(&self) -> Option<Duration> {
264 self.api_metadata()
265 .and_then(|metadata| metadata.retry_after_secs)
266 .map(Duration::from_secs)
267 }
268
269 #[must_use]
270 pub fn is_rate_limited(&self) -> bool {
271 matches!(self, Self::ApiError { status: 429, .. })
272 }
273
274 #[must_use]
275 pub fn is_retryable(&self) -> bool {
276 match self {
277 Self::ApiError { status, .. } => self
278 .api_metadata()
279 .and_then(|metadata| metadata.retryable)
280 .unwrap_or_else(|| default_retryable_status(*status)),
281 _ => false,
282 }
283 }
284}
285
/// Status-based retry default: request timeout, too-many-requests and the transient
/// 5xx codes (500, 502, 503, 504) are retryable; everything else is not.
fn default_retryable_status(status: u16) -> bool {
    const RETRYABLE: [u16; 6] = [408, 429, 500, 502, 503, 504];
    RETRYABLE.contains(&status)
}
289
290fn api_metadata(status: u16, message: &str) -> Option<ApiErrorMetadata> {
291 api_error_metadata_registry().get(&ApiErrorKey::new(status, message))
292}
293
294fn set_api_metadata(status: u16, message: &str, metadata: ApiErrorMetadata) {
295 api_error_metadata_registry().insert(ApiErrorKey::new(status, message), metadata);
296}
297
298fn api_error_metadata_registry() -> std::sync::MutexGuard<'static, ApiErrorMetadataRegistry> {
299 API_ERROR_METADATA_REGISTRY
300 .lock()
301 .unwrap_or_else(|poisoned| poisoned.into_inner())
302}
303
/// Caps `value` at `max_bytes` bytes, or returns `None` when there is nothing to keep.
///
/// - Empty input or a zero budget yields `None`.
/// - Strings that already fit are returned unchanged.
/// - Longer strings are cut at the last UTF-8 character boundary that leaves room for a
///   trailing `"..."`, so the result is valid UTF-8 and never longer than `max_bytes`.
/// - When `max_bytes` is no larger than the marker itself, the result is the first
///   `max_bytes` bytes of `"..."`.
///
/// The returned string's allocation is shrunk, so a truncated multi-megabyte body does not
/// keep its original buffer alive.
fn truncate_string(mut value: String, max_bytes: usize) -> Option<String> {
    const ELLIPSIS: &str = "...";

    if value.is_empty() || max_bytes == 0 {
        return None;
    }
    if value.len() <= max_bytes {
        return Some(value);
    }
    if max_bytes <= ELLIPSIS.len() {
        return Some(ELLIPSIS[..max_bytes].to_string());
    }

    // Step back to a char boundary so `truncate` cannot split a multi-byte character.
    // Index 0 is always a boundary, so the loop terminates.
    let mut cut = max_bytes - ELLIPSIS.len();
    while !value.is_char_boundary(cut) {
        cut -= 1;
    }
    value.truncate(cut);
    value.push_str(ELLIPSIS);
    // `truncate` keeps the old capacity; release it so retained memory matches `len()`.
    value.shrink_to_fit();
    Some(value)
}
319
320fn bound_details(details: Value) -> Option<Value> {
321 let bytes = serde_json::to_vec(&details).ok()?;
322 if bytes.len() <= API_ERROR_METADATA_MAX_DETAILS_BYTES {
323 return Some(details);
324 }
325 Some(Value::String(format!(
326 "[truncated error.details: {} bytes]",
327 bytes.len()
328 )))
329}
330
331fn bound_headers(headers: HashMap<String, String>) -> Option<HashMap<String, String>> {
332 if headers.is_empty() || API_ERROR_METADATA_MAX_HEADERS_BYTES == 0 {
333 return None;
334 }
335
336 let mut remaining = API_ERROR_METADATA_MAX_HEADERS_BYTES;
337 let mut bounded = HashMap::new();
338
339 for (name, value) in headers {
340 let required = name.len() + value.len();
341 if required > remaining {
342 continue;
343 }
344 remaining -= required;
345 bounded.insert(name, value);
346 }
347
348 (!bounded.is_empty()).then_some(bounded)
349}
350
351fn header_map_to_hash_map(headers: &reqwest::header::HeaderMap) -> Option<HashMap<String, String>> {
352 let mut map = HashMap::new();
353 for (name, value) in headers {
354 let Ok(value_str) = value.to_str() else {
355 continue;
356 };
357 map.entry(name.as_str().to_string())
358 .and_modify(|existing: &mut String| {
359 if !existing.is_empty() {
360 existing.push_str(", ");
361 }
362 existing.push_str(value_str);
363 })
364 .or_insert_with(|| value_str.to_string());
365 }
366 (!map.is_empty()).then_some(map)
367}
368
369fn retry_after_secs(headers: &reqwest::header::HeaderMap) -> Option<u64> {
370 let retry_after = headers
371 .get(reqwest::header::RETRY_AFTER)
372 .and_then(|value| value.to_str().ok())?
373 .trim();
374
375 retry_after.parse::<u64>().ok().or_else(|| {
376 httpdate::parse_http_date(retry_after).ok().map(|deadline| {
377 deadline
378 .duration_since(SystemTime::now())
379 .unwrap_or_default()
380 .as_secs()
381 })
382 })
383}
384
385fn parse_google_error(body: &str, status: u16) -> (String, Option<String>, Option<Value>) {
386 let fallback = if body.trim().is_empty() {
387 StatusCode::from_u16(status)
388 .ok()
389 .and_then(|code| code.canonical_reason().map(str::to_string))
390 .unwrap_or_else(|| format!("HTTP {status}"))
391 } else {
392 body.to_string()
393 };
394
395 let Ok(value) = serde_json::from_str::<Value>(body) else {
396 return (fallback, None, None);
397 };
398 let Some(error) = value.get("error") else {
399 return (fallback, None, None);
400 };
401
402 let message = error
403 .get("message")
404 .and_then(Value::as_str)
405 .map(str::to_string)
406 .unwrap_or(fallback);
407 let code = error
408 .get("status")
409 .and_then(Value::as_str)
410 .map(str::to_string)
411 .or_else(|| {
412 error
413 .get("code")
414 .and_then(Value::as_i64)
415 .map(|value| value.to_string())
416 });
417 let details = error.get("details").cloned();
418
419 (message, code, details)
420}
421
422pub type Result<T> = std::result::Result<T, Error>;
423
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::{HeaderMap, HeaderValue, RETRY_AFTER};
    use serde_json::json;
    use std::time::SystemTime;

    // A well-formed Google error envelope yields its message, string status code and details.
    #[test]
    fn parse_google_error_extracts_metadata() {
        let body = json!({
            "error": {
                "message": "quota exceeded",
                "status": "RESOURCE_EXHAUSTED",
                "details": [{"kind": "quota"}]
            }
        })
        .to_string();
        let (message, code, details) = parse_google_error(&body, 429);

        assert_eq!(message, "quota exceeded");
        assert_eq!(code.as_deref(), Some("RESOURCE_EXHAUSTED"));
        assert_eq!(details, Some(json!([{"kind": "quota"}])));
    }

    // A non-JSON body becomes the message verbatim, with no code or details.
    #[test]
    fn parse_google_error_falls_back_to_body() {
        let body = "plain-text failure";
        let (message, code, details) = parse_google_error(body, 500);

        assert_eq!(message, body);
        assert!(code.is_none());
        assert!(details.is_none());
    }

    // Errors from `api_error_with_retryable` record only the retryable flag: the other
    // accessors stay empty, and an explicit `false` overrides a retryable status (500).
    #[test]
    fn api_error_accessors_cover_defaults() {
        let err =
            Error::api_error_with_retryable(503, "unavailable", default_retryable_status(503));
        assert_eq!(err.status(), Some(StatusCode::SERVICE_UNAVAILABLE));
        assert_eq!(err.code(), None);
        assert_eq!(err.details(), None);
        assert_eq!(err.headers(), None);
        assert_eq!(err.body(), None);
        assert_eq!(err.attempts(), None);
        assert_eq!(err.retry_after(), None);
        assert!(err.is_retryable());
        assert!(!err.is_rate_limited());

        let bad_request =
            Error::api_error_with_retryable(400, "bad request", default_retryable_status(400));
        assert_eq!(bad_request.status(), Some(StatusCode::BAD_REQUEST));
        assert!(!bad_request.is_retryable());

        let terminal = Error::api_error_with_retryable(500, "terminal", false);
        assert_eq!(terminal.status(), Some(StatusCode::INTERNAL_SERVER_ERROR));
        assert!(!terminal.is_retryable());
    }

    // Constructing `ApiError` directly must keep compiling. Its fresh message allocation
    // normally has no registry entry, so metadata accessors return `None` and
    // retryability falls back to the status default (418 is not retryable).
    #[test]
    fn api_error_public_shape_stays_constructible() {
        let err = Error::ApiError {
            status: 418,
            message: "teapot".into(),
        };

        assert_eq!(err.status(), Some(StatusCode::IM_A_TEAPOT));
        assert_eq!(err.code(), None);
        assert_eq!(err.details(), None);
        assert_eq!(err.headers(), None);
        assert_eq!(err.body(), None);
        assert_eq!(err.attempts(), None);
        assert_eq!(err.retry_after(), None);
        assert!(!err.is_retryable());
    }

    // Every accessor reports nothing for variants other than `ApiError`.
    #[test]
    fn accessors_are_empty_for_non_api_errors() {
        let err = Error::Parse {
            message: "boom".into(),
        };
        assert_eq!(err.status(), None);
        assert_eq!(err.code(), None);
        assert_eq!(err.details(), None);
        assert_eq!(err.headers(), None);
        assert_eq!(err.body(), None);
        assert_eq!(err.attempts(), None);
        assert_eq!(err.retry_after(), None);
        assert!(!err.is_retryable());
        assert!(!err.is_rate_limited());
    }

    // Repeated header values are joined with ", ", and a delta-seconds `Retry-After` is parsed.
    #[test]
    fn header_helpers_collect_values_and_retry_after() {
        let mut headers = HeaderMap::new();
        headers.insert("x-test", HeaderValue::from_static("a"));
        headers.append("x-test", HeaderValue::from_static("b"));
        headers.insert(RETRY_AFTER, HeaderValue::from_static("7"));

        let flattened = header_map_to_hash_map(&headers).unwrap();
        assert_eq!(flattened.get("x-test").map(String::as_str), Some("a, b"));
        assert_eq!(retry_after_secs(&headers), Some(7));
    }

    // An HTTP-date `Retry-After` becomes the remaining whole seconds. The range allows for
    // the formatted date dropping sub-second precision and for time elapsed in the test.
    #[test]
    fn retry_after_secs_parses_http_date() {
        let mut headers = HeaderMap::new();
        let deadline = SystemTime::now() + Duration::from_secs(60);
        let header = httpdate::fmt_http_date(deadline);
        headers.insert(RETRY_AFTER, HeaderValue::from_str(&header).unwrap());

        let retry_after = retry_after_secs(&headers).unwrap();
        assert!((58..=60).contains(&retry_after));
    }

    // `bounded()` caps the body and the total header size, and replaces oversized details
    // with a placeholder string.
    #[test]
    fn api_error_metadata_bounds_large_payloads() {
        let headers = (0..64)
            .map(|idx| (format!("x-{idx}"), "v".repeat(128)))
            .collect::<HashMap<_, _>>();
        let metadata = ApiErrorMetadata {
            code: Some("RESOURCE_EXHAUSTED".into()),
            details: Some(json!({ "payload": "x".repeat(API_ERROR_METADATA_MAX_DETAILS_BYTES) })),
            headers: Some(headers),
            body: Some("b".repeat(API_ERROR_METADATA_MAX_BODY_BYTES + 32)),
            retry_after_secs: Some(7),
            retryable: Some(true),
            attempts: Some(2),
        }
        .bounded();

        assert!(metadata.body.unwrap().len() <= API_ERROR_METADATA_MAX_BODY_BYTES);
        assert!(
            metadata
                .headers
                .unwrap()
                .into_iter()
                .map(|(name, value)| name.len() + value.len())
                .sum::<usize>()
                <= API_ERROR_METADATA_MAX_HEADERS_BYTES
        );
        assert!(matches!(metadata.details, Some(Value::String(_))));
    }

    // A header too large for the budget is dropped without preventing smaller headers
    // from being kept.
    #[test]
    fn bound_headers_keeps_smaller_headers_after_large_entries() {
        let headers = HashMap::from([
            (
                "x-large".to_string(),
                "v".repeat(API_ERROR_METADATA_MAX_HEADERS_BYTES + 1),
            ),
            ("retry-after".to_string(), "7".to_string()),
            ("x-small".to_string(), "ok".to_string()),
        ]);

        let bounded = bound_headers(headers).unwrap();
        assert_eq!(bounded.get("retry-after").map(String::as_str), Some("7"));
        assert_eq!(bounded.get("x-small").map(String::as_str), Some("ok"));
        assert!(!bounded.contains_key("x-large"));
    }

    // 97 near-maximum bodies exceed the total byte budget, so the registry evicts from
    // the front and the first-inserted entry is gone.
    #[test]
    fn api_error_metadata_registry_evicts_by_total_bytes() {
        let mut registry = ApiErrorMetadataRegistry::default();
        let first_key = ApiErrorKey::new(500, "first");

        registry.insert(
            first_key,
            ApiErrorMetadata {
                body: Some("a".repeat(API_ERROR_METADATA_MAX_BODY_BYTES)),
                ..Default::default()
            },
        );

        for idx in 0..96 {
            registry.insert(
                ApiErrorKey::new(500, &format!("entry-{idx}")),
                ApiErrorMetadata {
                    body: Some("b".repeat(API_ERROR_METADATA_MAX_BODY_BYTES)),
                    ..Default::default()
                },
            );
        }

        assert!(registry.total_bytes <= API_ERROR_METADATA_MAX_TOTAL_BYTES);
        assert!(registry.get(&first_key).is_none());
    }
}