Skip to main content

rust_genai/
error.rs

1//! Error definitions for the SDK.
2
3use std::collections::{HashMap, VecDeque};
4use std::hash::{Hash, Hasher};
5use std::sync::{LazyLock, Mutex};
6use std::time::{Duration, SystemTime};
7
8use http::StatusCode;
9use serde_json::Value;
10use thiserror::Error;
11
12#[cfg(feature = "mcp")]
13use rmcp::service::ServiceError;
14
15use crate::client::RetryMetadata;
16
/// One kibibyte; used to spell out the byte limits below.
const KIB: usize = 1024;

/// Maximum number of entries the API error metadata registry keeps at once.
const API_ERROR_METADATA_CAPACITY: usize = 4_096;
/// Byte budget summed over every registry entry (as measured by `retained_bytes`).
const API_ERROR_METADATA_MAX_TOTAL_BYTES: usize = 512 * KIB;
/// Longest response body, in bytes, kept for a single error.
const API_ERROR_METADATA_MAX_BODY_BYTES: usize = 8 * KIB;
/// JSON-encoded size above which `error.details` is replaced by a size marker.
const API_ERROR_METADATA_MAX_DETAILS_BYTES: usize = 8 * KIB;
/// Combined name + value bytes of response headers kept for a single error.
const API_ERROR_METADATA_MAX_HEADERS_BYTES: usize = 4 * KIB;
22
23#[derive(Clone, Debug, Default)]
24struct ApiErrorMetadata {
25    code: Option<String>,
26    details: Option<Value>,
27    headers: Option<HashMap<String, String>>,
28    body: Option<String>,
29    retry_after_secs: Option<u64>,
30    retryable: Option<bool>,
31    attempts: Option<u32>,
32}
33
34impl ApiErrorMetadata {
35    fn bounded(mut self) -> Self {
36        self.body = self
37            .body
38            .take()
39            .and_then(|body| truncate_string(body, API_ERROR_METADATA_MAX_BODY_BYTES));
40        self.details = self.details.take().and_then(bound_details);
41        self.headers = self.headers.take().and_then(bound_headers);
42        self
43    }
44
45    fn retained_bytes(&self) -> usize {
46        self.code.as_ref().map_or(0, String::len)
47            + self.body.as_ref().map_or(0, String::len)
48            + self.headers.as_ref().map_or(0, |headers| {
49                headers
50                    .iter()
51                    .map(|(name, value)| name.len() + value.len())
52                    .sum()
53            })
54            + self.details.as_ref().map_or(0, |details| {
55                serde_json::to_vec(details).map_or(0, |bytes| bytes.len())
56            })
57    }
58}
59
/// Registry key identifying one particular `Error::ApiError` value.
///
/// Besides the status and a hash of the message text, the key records the
/// address and length of the message buffer. Two errors with identical text
/// but different `String` allocations therefore get different keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct ApiErrorKey {
    status: u16,
    message_ptr: usize,
    message_len: usize,
    message_hash: u64,
}

impl ApiErrorKey {
    /// Builds the key for `status` and the exact `message` buffer passed in.
    fn new(status: u16, message: &str) -> Self {
        // `DefaultHasher::new()` uses fixed keys, so the same text always
        // hashes the same way within this process.
        let message_hash = {
            let mut state = std::collections::hash_map::DefaultHasher::new();
            message.hash(&mut state);
            state.finish()
        };

        Self {
            status,
            message_ptr: message.as_ptr() as usize,
            message_len: message.len(),
            message_hash,
        }
    }
}
80
81#[derive(Default)]
82struct ApiErrorMetadataRegistry {
83    entries: HashMap<ApiErrorKey, ApiErrorMetadata>,
84    order: VecDeque<ApiErrorKey>,
85    total_bytes: usize,
86}
87
88impl ApiErrorMetadataRegistry {
89    fn get(&self, key: &ApiErrorKey) -> Option<ApiErrorMetadata> {
90        self.entries.get(key).cloned()
91    }
92
93    fn insert(&mut self, key: ApiErrorKey, metadata: ApiErrorMetadata) {
94        let metadata = metadata.bounded();
95        let metadata_bytes = metadata.retained_bytes();
96
97        if let Some(previous) = self.entries.insert(key, metadata) {
98            self.total_bytes = self.total_bytes.saturating_sub(previous.retained_bytes());
99        } else {
100            self.order.push_back(key);
101        }
102        self.total_bytes += metadata_bytes;
103
104        while self.entries.len() > API_ERROR_METADATA_CAPACITY
105            || self.total_bytes > API_ERROR_METADATA_MAX_TOTAL_BYTES
106        {
107            let Some(oldest_key) = self.order.pop_front() else {
108                break;
109            };
110            if let Some(removed) = self.entries.remove(&oldest_key) {
111                self.total_bytes = self.total_bytes.saturating_sub(removed.retained_bytes());
112            }
113        }
114    }
115}
116
117static API_ERROR_METADATA_REGISTRY: LazyLock<Mutex<ApiErrorMetadataRegistry>> =
118    LazyLock::new(|| Mutex::new(ApiErrorMetadataRegistry::default()));
119
/// Errors produced by the SDK.
///
/// Display text comes from the `#[error(...)]` attributes (via `thiserror`).
/// Variants whose field is marked `#[from]` also get a `From` conversion from
/// that source error type.
#[derive(Debug, Error)]
pub enum Error {
    /// Failure reported by the `reqwest` HTTP client.
    #[error("HTTP client error: {source}")]
    HttpClient {
        #[from]
        source: reqwest::Error,
    },

    /// Error response from the API: the HTTP `status` code and a `message`.
    ///
    /// Additional response data (code, details, headers, body, retry hints) is
    /// not stored here. It is kept in a module-level registry and exposed
    /// through the accessor methods on `Error`.
    #[error("API error (status {status}): {message}")]
    ApiError { status: u16, message: String },

    /// Invalid SDK configuration.
    #[error("Invalid configuration: {message}")]
    InvalidConfig { message: String },

    /// Failure to parse data; `message` describes what went wrong.
    #[error("Parse error: {message}")]
    Parse { message: String },

    /// JSON (de)serialization failure from `serde_json`.
    #[error("Serialization error: {source}")]
    Serialization {
        #[from]
        source: serde_json::Error,
    },

    /// Standard I/O failure.
    #[error("IO error: {source}")]
    Io {
        #[from]
        source: std::io::Error,
    },

    /// An operation timed out; `message` describes which one.
    #[error("Timeout: {message}")]
    Timeout { message: String },

    /// A required thought signature was missing.
    // NOTE(review): the exact triggering condition is defined by callers
    // outside this file — confirm before relying on it.
    #[error("Missing thought signature: {message}")]
    MissingThoughtSignature { message: String },

    /// Authentication / credential error.
    #[error("Auth error: {message}")]
    Auth { message: String },

    /// A channel was closed (no further detail is carried).
    #[error("Channel closed")]
    ChannelClosed,

    /// WebSocket failure from `tokio_tungstenite`.
    #[error("WebSocket error: {source}")]
    WebSocket {
        #[from]
        source: tokio_tungstenite::tungstenite::Error,
    },

    /// MCP service failure; only present with the `mcp` feature enabled.
    #[cfg(feature = "mcp")]
    #[error("MCP error: {source}")]
    Mcp {
        #[from]
        source: ServiceError,
    },
}
174
175impl Error {
176    pub(crate) fn api_error_with_retryable(
177        status: u16,
178        message: impl Into<String>,
179        retryable: bool,
180    ) -> Self {
181        let message = message.into();
182        set_api_metadata(
183            status,
184            &message,
185            ApiErrorMetadata {
186                retryable: Some(retryable),
187                ..Default::default()
188            },
189        );
190        Self::ApiError { status, message }
191    }
192
193    pub(crate) async fn api_error_from_response(
194        response: reqwest::Response,
195        retryable_override: Option<bool>,
196    ) -> Self {
197        let status = response.status().as_u16();
198        let retry_metadata = response.extensions().get::<RetryMetadata>().copied();
199        let headers = header_map_to_hash_map(response.headers());
200        let retry_after_secs = retry_after_secs(response.headers());
201        let body = response.text().await.unwrap_or_default();
202        let (message, code, details) = parse_google_error(&body, status);
203        set_api_metadata(
204            status,
205            &message,
206            ApiErrorMetadata {
207                code,
208                details,
209                headers,
210                body: if body.is_empty() { None } else { Some(body) },
211                retry_after_secs,
212                retryable: retryable_override
213                    .or(retry_metadata.map(|meta| meta.retryable))
214                    .or(Some(default_retryable_status(status))),
215                attempts: retry_metadata.map(|meta| meta.attempts),
216            },
217        );
218
219        Self::ApiError { status, message }
220    }
221
222    fn api_metadata(&self) -> Option<ApiErrorMetadata> {
223        match self {
224            Self::ApiError { status, message } => api_metadata(*status, message),
225            _ => None,
226        }
227    }
228
229    #[must_use]
230    pub fn status(&self) -> Option<StatusCode> {
231        match self {
232            Self::ApiError { status, .. } => StatusCode::from_u16(*status).ok(),
233            _ => None,
234        }
235    }
236
237    #[must_use]
238    pub fn code(&self) -> Option<String> {
239        self.api_metadata().and_then(|metadata| metadata.code)
240    }
241
242    #[must_use]
243    pub fn details(&self) -> Option<Value> {
244        self.api_metadata().and_then(|metadata| metadata.details)
245    }
246
247    #[must_use]
248    pub fn headers(&self) -> Option<HashMap<String, String>> {
249        self.api_metadata().and_then(|metadata| metadata.headers)
250    }
251
252    #[must_use]
253    pub fn body(&self) -> Option<String> {
254        self.api_metadata().and_then(|metadata| metadata.body)
255    }
256
257    #[must_use]
258    pub fn attempts(&self) -> Option<u32> {
259        self.api_metadata().and_then(|metadata| metadata.attempts)
260    }
261
262    #[must_use]
263    pub fn retry_after(&self) -> Option<Duration> {
264        self.api_metadata()
265            .and_then(|metadata| metadata.retry_after_secs)
266            .map(Duration::from_secs)
267    }
268
269    #[must_use]
270    pub fn is_rate_limited(&self) -> bool {
271        matches!(self, Self::ApiError { status: 429, .. })
272    }
273
274    #[must_use]
275    pub fn is_retryable(&self) -> bool {
276        match self {
277            Self::ApiError { status, .. } => self
278                .api_metadata()
279                .and_then(|metadata| metadata.retryable)
280                .unwrap_or_else(|| default_retryable_status(*status)),
281            _ => false,
282        }
283    }
284}
285
/// Retryability assumed for `status` when no explicit verdict was recorded.
///
/// The transient statuses are: 408 (request timeout), 429 (too many
/// requests), and 500, 502, 503 and 504.
fn default_retryable_status(status: u16) -> bool {
    const TRANSIENT_STATUSES: [u16; 6] = [408, 429, 500, 502, 503, 504];
    TRANSIENT_STATUSES.contains(&status)
}
289
290fn api_metadata(status: u16, message: &str) -> Option<ApiErrorMetadata> {
291    api_error_metadata_registry().get(&ApiErrorKey::new(status, message))
292}
293
294fn set_api_metadata(status: u16, message: &str, metadata: ApiErrorMetadata) {
295    api_error_metadata_registry().insert(ApiErrorKey::new(status, message), metadata);
296}
297
298fn api_error_metadata_registry() -> std::sync::MutexGuard<'static, ApiErrorMetadataRegistry> {
299    API_ERROR_METADATA_REGISTRY
300        .lock()
301        .unwrap_or_else(|poisoned| poisoned.into_inner())
302}
303
/// Caps `value` at `max_bytes` bytes of UTF-8, marking truncation with `"..."`.
///
/// - Strings that already fit are returned unchanged.
/// - Truncation always lands on a `char` boundary.
/// - The result is never longer than `max_bytes`. The `"..."` marker is
///   omitted when the budget is smaller than the marker itself.
/// - A truncated result is shrunk to fit, so it does not keep the original
///   (possibly huge) allocation alive.
///
/// Returns `None` when `value` is empty, when `max_bytes` is zero, or when no
/// complete character fits in the budget.
fn truncate_string(mut value: String, max_bytes: usize) -> Option<String> {
    const ELLIPSIS: &str = "...";

    if value.is_empty() || max_bytes == 0 {
        return None;
    }
    if value.len() <= max_bytes {
        return Some(value);
    }

    // Reserve room for the marker only when it fits; otherwise appending it
    // would overshoot the very budget this function enforces.
    let marker = if max_bytes >= ELLIPSIS.len() { ELLIPSIS } else { "" };

    // Step back to the nearest char boundary (at most three steps for UTF-8)
    // instead of popping characters one at a time off a long tail.
    let mut cut = max_bytes - marker.len();
    while !value.is_char_boundary(cut) {
        cut -= 1;
    }
    value.truncate(cut);
    value.push_str(marker);
    // Callers account for `len()`; release the excess capacity so memory use
    // actually matches that accounting.
    value.shrink_to_fit();

    (!value.is_empty()).then_some(value)
}
319
320fn bound_details(details: Value) -> Option<Value> {
321    let bytes = serde_json::to_vec(&details).ok()?;
322    if bytes.len() <= API_ERROR_METADATA_MAX_DETAILS_BYTES {
323        return Some(details);
324    }
325    Some(Value::String(format!(
326        "[truncated error.details: {} bytes]",
327        bytes.len()
328    )))
329}
330
331fn bound_headers(headers: HashMap<String, String>) -> Option<HashMap<String, String>> {
332    if headers.is_empty() || API_ERROR_METADATA_MAX_HEADERS_BYTES == 0 {
333        return None;
334    }
335
336    let mut remaining = API_ERROR_METADATA_MAX_HEADERS_BYTES;
337    let mut bounded = HashMap::new();
338
339    for (name, value) in headers {
340        let required = name.len() + value.len();
341        if required > remaining {
342            continue;
343        }
344        remaining -= required;
345        bounded.insert(name, value);
346    }
347
348    (!bounded.is_empty()).then_some(bounded)
349}
350
351fn header_map_to_hash_map(headers: &reqwest::header::HeaderMap) -> Option<HashMap<String, String>> {
352    let mut map = HashMap::new();
353    for (name, value) in headers {
354        let Ok(value_str) = value.to_str() else {
355            continue;
356        };
357        map.entry(name.as_str().to_string())
358            .and_modify(|existing: &mut String| {
359                if !existing.is_empty() {
360                    existing.push_str(", ");
361                }
362                existing.push_str(value_str);
363            })
364            .or_insert_with(|| value_str.to_string());
365    }
366    (!map.is_empty()).then_some(map)
367}
368
369fn retry_after_secs(headers: &reqwest::header::HeaderMap) -> Option<u64> {
370    let retry_after = headers
371        .get(reqwest::header::RETRY_AFTER)
372        .and_then(|value| value.to_str().ok())?
373        .trim();
374
375    retry_after.parse::<u64>().ok().or_else(|| {
376        httpdate::parse_http_date(retry_after).ok().map(|deadline| {
377            deadline
378                .duration_since(SystemTime::now())
379                .unwrap_or_default()
380                .as_secs()
381        })
382    })
383}
384
385fn parse_google_error(body: &str, status: u16) -> (String, Option<String>, Option<Value>) {
386    let fallback = if body.trim().is_empty() {
387        StatusCode::from_u16(status)
388            .ok()
389            .and_then(|code| code.canonical_reason().map(str::to_string))
390            .unwrap_or_else(|| format!("HTTP {status}"))
391    } else {
392        body.to_string()
393    };
394
395    let Ok(value) = serde_json::from_str::<Value>(body) else {
396        return (fallback, None, None);
397    };
398    let Some(error) = value.get("error") else {
399        return (fallback, None, None);
400    };
401
402    let message = error
403        .get("message")
404        .and_then(Value::as_str)
405        .map(str::to_string)
406        .unwrap_or(fallback);
407    let code = error
408        .get("status")
409        .and_then(Value::as_str)
410        .map(str::to_string)
411        .or_else(|| {
412            error
413                .get("code")
414                .and_then(Value::as_i64)
415                .map(|value| value.to_string())
416        });
417    let details = error.get("details").cloned();
418
419    (message, code, details)
420}
421
422pub type Result<T> = std::result::Result<T, Error>;
423
// Unit tests for error parsing, metadata accessors, header helpers and the
// bounded metadata registry. Several tests touch the process-wide registry;
// each uses a distinct (status, message) pair so they do not interfere.
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::{HeaderMap, HeaderValue, RETRY_AFTER};
    use serde_json::json;
    use std::time::SystemTime;

    // A well-formed Google error body yields its message, `status` string and details.
    #[test]
    fn parse_google_error_extracts_metadata() {
        let body = json!({
            "error": {
                "message": "quota exceeded",
                "status": "RESOURCE_EXHAUSTED",
                "details": [{"kind": "quota"}]
            }
        })
        .to_string();
        let (message, code, details) = parse_google_error(&body, 429);

        assert_eq!(message, "quota exceeded");
        assert_eq!(code.as_deref(), Some("RESOURCE_EXHAUSTED"));
        assert_eq!(details, Some(json!([{"kind": "quota"}])));
    }

    // Non-JSON bodies become the message verbatim, with no code or details.
    #[test]
    fn parse_google_error_falls_back_to_body() {
        let body = "plain-text failure";
        let (message, code, details) = parse_google_error(body, 500);

        assert_eq!(message, body);
        assert!(code.is_none());
        assert!(details.is_none());
    }

    // `api_error_with_retryable` records only the retryable flag; every other
    // accessor reports `None`, and an explicit `false` overrides the 500 default.
    #[test]
    fn api_error_accessors_cover_defaults() {
        let err =
            Error::api_error_with_retryable(503, "unavailable", default_retryable_status(503));
        assert_eq!(err.status(), Some(StatusCode::SERVICE_UNAVAILABLE));
        assert_eq!(err.code(), None);
        assert_eq!(err.details(), None);
        assert_eq!(err.headers(), None);
        assert_eq!(err.body(), None);
        assert_eq!(err.attempts(), None);
        assert_eq!(err.retry_after(), None);
        assert!(err.is_retryable());
        assert!(!err.is_rate_limited());

        let bad_request =
            Error::api_error_with_retryable(400, "bad request", default_retryable_status(400));
        assert_eq!(bad_request.status(), Some(StatusCode::BAD_REQUEST));
        assert!(!bad_request.is_retryable());

        let terminal = Error::api_error_with_retryable(500, "terminal", false);
        assert_eq!(terminal.status(), Some(StatusCode::INTERNAL_SERVER_ERROR));
        assert!(!terminal.is_retryable());
    }

    // A directly constructed `ApiError` has no registry entry, so accessors
    // return `None` and retryability falls back to the status default
    // (418 is not in the retryable set).
    #[test]
    fn api_error_public_shape_stays_constructible() {
        let err = Error::ApiError {
            status: 418,
            message: "teapot".into(),
        };

        assert_eq!(err.status(), Some(StatusCode::IM_A_TEAPOT));
        assert_eq!(err.code(), None);
        assert_eq!(err.details(), None);
        assert_eq!(err.headers(), None);
        assert_eq!(err.body(), None);
        assert_eq!(err.attempts(), None);
        assert_eq!(err.retry_after(), None);
        assert!(!err.is_retryable());
    }

    // Non-API variants never expose status or metadata and are not retryable.
    #[test]
    fn accessors_are_empty_for_non_api_errors() {
        let err = Error::Parse {
            message: "boom".into(),
        };
        assert_eq!(err.status(), None);
        assert_eq!(err.code(), None);
        assert_eq!(err.details(), None);
        assert_eq!(err.headers(), None);
        assert_eq!(err.body(), None);
        assert_eq!(err.attempts(), None);
        assert_eq!(err.retry_after(), None);
        assert!(!err.is_retryable());
        assert!(!err.is_rate_limited());
    }

    // Repeated header values are comma-joined; delta-seconds Retry-After parses.
    #[test]
    fn header_helpers_collect_values_and_retry_after() {
        let mut headers = HeaderMap::new();
        headers.insert("x-test", HeaderValue::from_static("a"));
        headers.append("x-test", HeaderValue::from_static("b"));
        headers.insert(RETRY_AFTER, HeaderValue::from_static("7"));

        let flattened = header_map_to_hash_map(&headers).unwrap();
        assert_eq!(flattened.get("x-test").map(String::as_str), Some("a, b"));
        assert_eq!(retry_after_secs(&headers), Some(7));
    }

    // HTTP-date Retry-After converts to remaining seconds. The range absorbs
    // the sub-second truncation of the formatted date and test run time.
    #[test]
    fn retry_after_secs_parses_http_date() {
        let mut headers = HeaderMap::new();
        let deadline = SystemTime::now() + Duration::from_secs(60);
        let header = httpdate::fmt_http_date(deadline);
        headers.insert(RETRY_AFTER, HeaderValue::from_str(&header).unwrap());

        let retry_after = retry_after_secs(&headers).unwrap();
        assert!((58..=60).contains(&retry_after));
    }

    // `bounded()` caps the body and headers at their limits and replaces
    // oversized details with a string marker.
    #[test]
    fn api_error_metadata_bounds_large_payloads() {
        let headers = (0..64)
            .map(|idx| (format!("x-{idx}"), "v".repeat(128)))
            .collect::<HashMap<_, _>>();
        let metadata = ApiErrorMetadata {
            code: Some("RESOURCE_EXHAUSTED".into()),
            details: Some(json!({ "payload": "x".repeat(API_ERROR_METADATA_MAX_DETAILS_BYTES) })),
            headers: Some(headers),
            body: Some("b".repeat(API_ERROR_METADATA_MAX_BODY_BYTES + 32)),
            retry_after_secs: Some(7),
            retryable: Some(true),
            attempts: Some(2),
        }
        .bounded();

        assert!(metadata.body.unwrap().len() <= API_ERROR_METADATA_MAX_BODY_BYTES);
        assert!(
            metadata
                .headers
                .unwrap()
                .into_iter()
                .map(|(name, value)| name.len() + value.len())
                .sum::<usize>()
                <= API_ERROR_METADATA_MAX_HEADERS_BYTES
        );
        assert!(matches!(metadata.details, Some(Value::String(_))));
    }

    // An oversized header is skipped without preventing smaller ones from being kept.
    #[test]
    fn bound_headers_keeps_smaller_headers_after_large_entries() {
        let headers = HashMap::from([
            (
                "x-large".to_string(),
                "v".repeat(API_ERROR_METADATA_MAX_HEADERS_BYTES + 1),
            ),
            ("retry-after".to_string(), "7".to_string()),
            ("x-small".to_string(), "ok".to_string()),
        ]);

        let bounded = bound_headers(headers).unwrap();
        assert_eq!(bounded.get("retry-after").map(String::as_str), Some("7"));
        assert_eq!(bounded.get("x-small").map(String::as_str), Some("ok"));
        assert!(!bounded.contains_key("x-large"));
    }

    // 97 entries of MAX_BODY_BYTES each exceed MAX_TOTAL_BYTES, so the
    // oldest (first) entry must have been evicted and the budget respected.
    #[test]
    fn api_error_metadata_registry_evicts_by_total_bytes() {
        let mut registry = ApiErrorMetadataRegistry::default();
        let first_key = ApiErrorKey::new(500, "first");

        registry.insert(
            first_key,
            ApiErrorMetadata {
                body: Some("a".repeat(API_ERROR_METADATA_MAX_BODY_BYTES)),
                ..Default::default()
            },
        );

        for idx in 0..96 {
            registry.insert(
                ApiErrorKey::new(500, &format!("entry-{idx}")),
                ApiErrorMetadata {
                    body: Some("b".repeat(API_ERROR_METADATA_MAX_BODY_BYTES)),
                    ..Default::default()
                },
            );
        }

        assert!(registry.total_bytes <= API_ERROR_METADATA_MAX_TOTAL_BYTES);
        assert!(registry.get(&first_key).is_none());
    }
}