// xai_rust/error.rs

1//! Error types for the xAI SDK.
2
3/// Result type alias using the SDK's error type.
4pub type Result<T> = std::result::Result<T, Error>;
5
6/// Errors that can occur when using the xAI SDK.
7#[derive(Debug, thiserror::Error)]
8pub enum Error {
9    /// HTTP request failed.
10    #[error("HTTP request failed: {0}")]
11    Http(#[from] reqwest::Error),
12
13    /// JSON serialization/deserialization failed.
14    #[error("JSON error: {0}")]
15    Json(#[from] serde_json::Error),
16
17    /// URL parsing failed.
18    #[error("Invalid URL: {0}")]
19    Url(#[from] url::ParseError),
20
21    /// API returned an error response.
22    #[error("API error ({status}): {message}")]
23    Api {
24        /// HTTP status code.
25        status: u16,
26        /// Error message from the API.
27        message: String,
28        /// Error type/code if provided.
29        error_type: Option<String>,
30    },
31
32    /// Authentication failed.
33    #[error("Authentication failed: {0}")]
34    Authentication(String),
35
36    /// Rate limit exceeded.
37    #[error(
38        "{}",
39        format_rate_limited_error(retry_after.as_ref().copied(), message.as_deref())
40    )]
41    RateLimited {
42        /// Suggested retry delay in seconds.
43        retry_after: Option<u64>,
44        /// Optional API/fallback message for debugging.
45        message: Option<String>,
46    },
47
48    /// Request timeout.
49    #[error("Request timed out")]
50    Timeout,
51
52    /// Invalid request parameters.
53    #[error("Invalid request: {0}")]
54    InvalidRequest(String),
55
56    /// WebSocket error for realtime API.
57    #[cfg(feature = "realtime")]
58    #[error("WebSocket error: {0}")]
59    WebSocket(Box<tokio_tungstenite::tungstenite::Error>),
60
61    /// Stream parsing error.
62    #[error("Stream error: {0}")]
63    Stream(String),
64
65    /// Base64 decoding error.
66    #[error("Base64 decode error: {0}")]
67    Base64(#[from] base64::DecodeError),
68
69    /// Configuration error.
70    #[error("Configuration error: {0}")]
71    Config(String),
72
73    /// I/O error.
74    #[error("I/O error: {0}")]
75    Io(#[from] std::io::Error),
76}
77
/// Lets `?` turn tungstenite errors into [`Error::WebSocket`], boxing the
/// payload to match the variant's field type.
#[cfg(feature = "realtime")]
impl From<tokio_tungstenite::tungstenite::Error> for Error {
    fn from(source: tokio_tungstenite::tungstenite::Error) -> Self {
        Self::WebSocket(Box::from(source))
    }
}

85/// API error response structure.
86#[derive(Debug, serde::Deserialize)]
87pub(crate) struct ApiErrorResponse {
88    /// The error details.
89    pub error: ApiErrorDetail,
90}
91
92/// API error detail.
93#[derive(Debug, serde::Deserialize)]
94#[allow(dead_code)]
95pub(crate) struct ApiErrorDetail {
96    /// The error message.
97    pub message: String,
98    /// The error type/code.
99    #[serde(rename = "type")]
100    pub error_type: Option<String>,
101    /// Optional error code.
102    pub code: Option<String>,
103}
104
/// Maximum number of characters (Unicode scalar values, not bytes) of a raw
/// response body kept in an error message before it is marked as truncated.
const ERROR_BODY_SNIPPET_LIMIT: usize = 4 * 1024;

/// Renders the `Display` text for [`Error::RateLimited`].
///
/// Always starts with `Rate limit exceeded.`; appends the retry delay when
/// known, and the server message only if it contains non-whitespace text
/// (the message itself is appended verbatim).
fn format_rate_limited_error(retry_after: Option<u64>, message: Option<&str>) -> String {
    let headline = retry_after.map_or_else(
        || String::from("Rate limit exceeded."),
        |seconds| format!("Rate limit exceeded. Retry after {seconds} seconds."),
    );

    match message {
        Some(text) if !text.trim().is_empty() => format!("{headline} Server message: {text}"),
        _ => headline,
    }
}

121fn body_snippet(body: &str) -> Option<String> {
122    let trimmed = body.trim();
123    if trimmed.is_empty() {
124        return None;
125    }
126
127    let mut chars = trimmed.chars();
128    let snippet: String = chars.by_ref().take(ERROR_BODY_SNIPPET_LIMIT).collect();
129
130    if chars.next().is_some() {
131        Some(format!("{snippet}...[truncated]"))
132    } else {
133        Some(snippet)
134    }
135}
136
137impl Error {
138    /// Create an API error from an HTTP response.
139    pub async fn from_response(response: reqwest::Response) -> Self {
140        let status = response.status().as_u16();
141        let retry_after = response
142            .headers()
143            .get("retry-after")
144            .and_then(|v| v.to_str().ok())
145            .and_then(|v| v.parse().ok());
146
147        let body_text = match response.bytes().await {
148            Ok(bytes) => String::from_utf8_lossy(&bytes).into_owned(),
149            Err(_) => String::new(),
150        };
151
152        let parsed_error = serde_json::from_str::<ApiErrorResponse>(&body_text).ok();
153
154        // Check for rate limiting
155        if status == 429 {
156            let message = parsed_error
157                .as_ref()
158                .map(|api_error| api_error.error.message.clone())
159                .or_else(|| body_snippet(&body_text));
160
161            return Error::RateLimited {
162                retry_after,
163                message,
164            };
165        }
166
167        // Try to parse error body
168        match parsed_error {
169            Some(api_error) => Error::Api {
170                status,
171                message: api_error.error.message,
172                error_type: api_error.error.error_type,
173            },
174            None => Error::Api {
175                status,
176                message: body_snippet(&body_text)
177                    .map(|snippet| format!("HTTP {status}: {snippet}"))
178                    .unwrap_or_else(|| format!("HTTP {status}")),
179                error_type: None,
180            },
181        }
182    }
183
184    /// Check if this is a retryable error.
185    pub fn is_retryable(&self) -> bool {
186        match self {
187            Error::RateLimited { .. } => true,
188            Error::Timeout => true,
189            Error::Api { status, .. } => {
190                // Server errors are typically retryable
191                *status >= 500 && *status < 600
192            }
193            Error::Http(e) => e.is_timeout() || e.is_connect(),
194            _ => false,
195        }
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds an [`Error::Api`] with the given status and a placeholder message.
    fn api_error(status: u16) -> Error {
        Error::Api {
            status,
            message: format!("status {status}"),
            error_type: None,
        }
    }

    /// Serves `template` from a throwaway mock server, fetches it, and
    /// converts the response with [`Error::from_response`].
    async fn error_from(template: ResponseTemplate) -> Error {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/err"))
            .respond_with(template)
            .mount(&server)
            .await;

        let response = reqwest::Client::new()
            .get(format!("{}/err", server.uri()))
            .send()
            .await
            .expect("mock server should be reachable");
        Error::from_response(response).await
    }

    // ── is_retryable ──────────────────────────────────────────────────

    #[test]
    fn rate_limited_is_retryable_with_or_without_retry_after() {
        for retry_after in [Some(30), None] {
            let err = Error::RateLimited {
                retry_after,
                message: None,
            };
            assert!(err.is_retryable(), "retry_after = {retry_after:?}");
        }
    }

    #[test]
    fn timeout_is_retryable() {
        assert!(Error::Timeout.is_retryable());
    }

    #[test]
    fn api_5xx_is_retryable() {
        for status in [500, 502, 503, 599] {
            assert!(api_error(status).is_retryable(), "status {status}");
        }
    }

    #[test]
    fn api_non_5xx_is_not_retryable() {
        for status in [400, 401, 403, 404, 422, 499, 600] {
            assert!(!api_error(status).is_retryable(), "status {status}");
        }
    }

    #[test]
    fn api_error_type_does_not_affect_retryability() {
        let err = Error::Api {
            status: 422,
            message: "Unprocessable Entity".to_string(),
            error_type: Some("validation_error".to_string()),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn non_transport_errors_are_not_retryable() {
        let json_err = serde_json::from_str::<serde_json::Value>("bad json").unwrap_err();
        let errors = [
            Error::Authentication("Invalid token".to_string()),
            Error::InvalidRequest("Missing field".to_string()),
            Error::Json(json_err),
            Error::Stream("connection reset".to_string()),
            Error::Config("Missing API key".to_string()),
        ];
        for err in errors {
            assert!(!err.is_retryable(), "{err}");
        }
    }

    // ── Error Display ──────────────────────────────────────────────────

    #[test]
    fn error_display_api() {
        let err = Error::Api {
            status: 404,
            message: "Not Found".to_string(),
            error_type: None,
        };
        assert_eq!(format!("{err}"), "API error (404): Not Found");
    }

    #[test]
    fn error_display_rate_limited_with_retry_and_message() {
        let err = Error::RateLimited {
            retry_after: Some(60),
            message: Some("Too many requests".to_string()),
        };
        assert_eq!(
            format!("{err}"),
            "Rate limit exceeded. Retry after 60 seconds. Server message: Too many requests"
        );
    }

    #[test]
    fn error_display_rate_limited_with_retry_only() {
        let err = Error::RateLimited {
            retry_after: Some(5),
            message: None,
        };
        assert_eq!(
            format!("{err}"),
            "Rate limit exceeded. Retry after 5 seconds."
        );
    }

    #[test]
    fn error_display_rate_limited_without_retry_or_message() {
        let err = Error::RateLimited {
            retry_after: None,
            message: None,
        };
        assert_eq!(format!("{err}"), "Rate limit exceeded.");
    }

    #[test]
    fn error_display_rate_limited_ignores_blank_message() {
        let err = Error::RateLimited {
            retry_after: None,
            message: Some("   ".to_string()),
        };
        assert_eq!(format!("{err}"), "Rate limit exceeded.");
    }

    #[test]
    fn error_display_timeout() {
        assert_eq!(format!("{}", Error::Timeout), "Request timed out");
    }

    #[test]
    fn error_display_authentication() {
        let err = Error::Authentication("bad key".to_string());
        assert_eq!(format!("{err}"), "Authentication failed: bad key");
    }

    #[test]
    fn error_display_config() {
        let err = Error::Config("missing key".to_string());
        assert_eq!(format!("{err}"), "Configuration error: missing key");
    }

    #[test]
    fn error_display_stream() {
        let err = Error::Stream("parse failure".to_string());
        assert_eq!(format!("{err}"), "Stream error: parse failure");
    }

    // ── body_snippet ───────────────────────────────────────────────────

    #[test]
    fn body_snippet_blank_is_none() {
        assert_eq!(body_snippet(""), None);
        assert_eq!(body_snippet(" \n\t "), None);
    }

    #[test]
    fn body_snippet_trims_whitespace() {
        assert_eq!(body_snippet("  hello  ").as_deref(), Some("hello"));
    }

    #[test]
    fn body_snippet_at_limit_is_not_truncated() {
        let exact = "x".repeat(ERROR_BODY_SNIPPET_LIMIT);
        assert_eq!(body_snippet(&exact), Some(exact.clone()));
    }

    #[test]
    fn body_snippet_truncates_on_char_boundaries() {
        let long = "é".repeat(ERROR_BODY_SNIPPET_LIMIT + 3);
        assert_eq!(
            body_snippet(&long),
            Some(format!(
                "{}...[truncated]",
                "é".repeat(ERROR_BODY_SNIPPET_LIMIT)
            ))
        );
    }

    // ── from_response ──────────────────────────────────────────────────

    #[tokio::test]
    async fn from_response_parses_api_error_body() {
        let err = error_from(ResponseTemplate::new(400).set_body_json(serde_json::json!({
            "error": {
                "message": "Model not found",
                "type": "invalid_request_error"
            }
        })))
        .await;

        match err {
            Error::Api {
                status,
                message,
                error_type,
            } => {
                assert_eq!(status, 400);
                assert_eq!(message, "Model not found");
                assert_eq!(error_type.as_deref(), Some("invalid_request_error"));
            }
            other => panic!("expected Error::Api, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn from_response_non_json_includes_snippet() {
        let err = error_from(ResponseTemplate::new(500).set_body_string("upstream exploded")).await;

        match err {
            Error::Api {
                status,
                message,
                error_type,
            } => {
                assert_eq!(status, 500);
                assert!(message.contains("HTTP 500"));
                assert!(message.contains("upstream exploded"));
                assert!(error_type.is_none());
            }
            other => panic!("expected Error::Api, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn from_response_empty_body_reports_status_only() {
        let err = error_from(ResponseTemplate::new(404)).await;

        match err {
            Error::Api {
                status,
                message,
                error_type,
            } => {
                assert_eq!(status, 404);
                assert_eq!(message, "HTTP 404");
                assert!(error_type.is_none());
            }
            other => panic!("expected Error::Api, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn from_response_non_json_truncates_long_body() {
        let long_body = "a".repeat(ERROR_BODY_SNIPPET_LIMIT + 128);
        let err = error_from(ResponseTemplate::new(502).set_body_string(long_body)).await;

        match err {
            Error::Api {
                status, message, ..
            } => {
                assert_eq!(status, 502);
                assert!(message.contains("[truncated]"));
            }
            other => panic!("expected Error::Api, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn from_response_429_includes_retry_after_and_message() {
        let err = error_from(
            ResponseTemplate::new(429)
                .insert_header("retry-after", "7")
                .set_body_json(serde_json::json!({
                    "error": {
                        "message": "Too many requests"
                    }
                })),
        )
        .await;

        match err {
            Error::RateLimited {
                retry_after,
                message,
            } => {
                assert_eq!(retry_after, Some(7));
                assert_eq!(message.as_deref(), Some("Too many requests"));
            }
            other => panic!("expected Error::RateLimited, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn from_response_429_without_header_or_body() {
        let err = error_from(ResponseTemplate::new(429)).await;

        match err {
            Error::RateLimited {
                retry_after,
                message,
            } => {
                assert_eq!(retry_after, None);
                assert_eq!(message, None);
            }
            other => panic!("expected Error::RateLimited, got {other:?}"),
        }
    }

    // ── ApiErrorResponse deserialization ───────────────────────────────

    #[test]
    fn api_error_response_deserialize() {
        let json = serde_json::json!({
            "error": {
                "message": "Model not found",
                "type": "invalid_request_error",
                "code": "model_not_found"
            }
        });
        let resp: ApiErrorResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.error.message, "Model not found");
        assert_eq!(
            resp.error.error_type.as_deref(),
            Some("invalid_request_error")
        );
        assert_eq!(resp.error.code.as_deref(), Some("model_not_found"));
    }

    #[test]
    fn api_error_response_minimal() {
        let json = serde_json::json!({
            "error": {
                "message": "Something went wrong"
            }
        });
        let resp: ApiErrorResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.error.message, "Something went wrong");
        assert!(resp.error.error_type.is_none());
        assert!(resp.error.code.is_none());
    }
}