Skip to main content

zai_rs/client/
http.rs

1//! # HTTP Client Implementation
2//!
3//! Provides a robust HTTP client for communicating with the Zhipu AI API.
4//! This module implements connection pooling, error handling, and
5//! request/response processing.
6//!
7//! ## Features
8//!
9//! - Connection Pooling - Reuses HTTP connections for better performance
10//! - Error Handling - Comprehensive error parsing and reporting
11//! - Authentication - Bearer token authentication support
12//! - Retry with Jitter - Automatic retry with exponential backoff and random
13//!   jitter
14//! - Sensitive Data Masking - Automatic masking of API keys in logs
15//! - Structured Logging - Uses tracing for detailed request/response logging
16//!
17//! ## Usage
18//!
19//! The `HttpClient` trait provides a standardized interface for making HTTP
20//! requests to the Zhipu AI API endpoints.
21//!
//! ## Retry Configuration
23//!
24//! The HTTP client supports configurable retry behavior:
25//!
26//! ```ignore
//! use std::time::Duration;
//!
//! use zai_rs::client::http::{HttpClientConfig, RetryDelay};
28//!
29//! let config = HttpClientConfig::builder()
30//!     .max_retries(5)
31//!     .timeout(Duration::from_secs(120))
32//!     .retry_delay(RetryDelay::exponential(Duration::from_millis(100), Duration::from_secs(10)))
33//!     .build();
34//! ```
35
36use std::{
37    sync::{Arc, OnceLock},
38    time::Duration,
39};
40
41use serde::Deserialize;
42use tracing::{debug, info, warn};
43
44use crate::client::error::{ZaiError, ZaiResult, mask_sensitive_info};
45
/// Top-level shape of an API error body: `{"error": { ... }}`.
///
/// Only used for deserialization inside `parse_api_error_response`.
#[derive(Debug, Deserialize)]

struct ApiErrorEnvelope {
    /// The nested error object carrying the code and message.
    error: ApiError,
}
51
/// The `error` object inside an [`ApiErrorEnvelope`].
#[derive(Debug, Deserialize)]

struct ApiError {
    /// Error code as sent by the server; may arrive as a string or a number.
    code: ErrorCode,

    /// Human-readable error description.
    message: String,
}
59
/// An error code that is accepted either as a JSON string or a JSON integer.
///
/// `#[serde(untagged)]` makes serde try the variants in declaration order, so
/// a JSON string becomes `Str` and a JSON integer becomes `Num`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ErrorCode {
    /// Code transmitted as a string, e.g. `"1300"`.
    Str(String),

    /// Code transmitted as an integer, e.g. `401`.
    Num(i64),
}
67
68impl std::fmt::Display for ErrorCode {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        match self {
71            ErrorCode::Str(s) => write!(f, "{}", s),
72
73            ErrorCode::Num(n) => write!(f, "{}", n),
74        }
75    }
76}
77
78fn to_api_code(code: &ErrorCode) -> u16 {
79    match code {
80        ErrorCode::Num(n) => (*n).try_into().unwrap_or(0),
81        ErrorCode::Str(s) => s.parse::<u16>().unwrap_or(0),
82    }
83}
84
85/// Parse an API error response body into a ZaiError.
86///
87/// Attempts to deserialize the body as `{"error":{"code":...,"message":...}}`
88/// and maps it to the appropriate ZaiError variant. Falls back to a generic
89/// HttpError if parsing fails.
90pub fn parse_api_error_response(status: u16, body: String) -> crate::client::error::ZaiError {
91    if let Ok(parsed) = serde_json::from_str::<ApiErrorEnvelope>(&body) {
92        let api_code = to_api_code(&parsed.error.code);
93        crate::client::error::ZaiError::from_api_response(status, api_code, parsed.error.message)
94    } else {
95        crate::client::error::ZaiError::from_api_response(status, 0, body)
96    }
97}
98
/// Retry delay strategy.
///
/// Selects how long to wait before the next attempt. Random jitter is added
/// on top of the computed delay by the retry loop, whatever the strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryDelay {
    /// Fixed delay between retries
    Fixed(Duration),

    /// Exponential backoff: `base` is doubled for each attempt and the result
    /// is capped at `max`
    Exponential { base: Duration, max: Duration },

    /// No delay (not recommended for production)
    None,
}
111
112impl RetryDelay {
113    /// Create a fixed delay strategy
114    pub fn fixed(delay: Duration) -> Self {
115        Self::Fixed(delay)
116    }
117
118    /// Create an exponential backoff strategy
119    pub fn exponential(base: Duration, max: Duration) -> Self {
120        Self::Exponential { base, max }
121    }
122
123    /// Create a no-delay strategy (not recommended)
124    pub fn none() -> Self {
125        Self::None
126    }
127}
128
129impl Default for RetryDelay {
130    fn default() -> Self {
131        Self::Exponential {
132            base: Duration::from_millis(500),
133            max: Duration::from_secs(5),
134        }
135    }
136}
137
/// Configuration for HTTP client behavior.
///
/// Use the builder pattern for fluent configuration:
///
/// ```ignore
/// use std::time::Duration;
///
/// use zai_rs::client::http::{HttpClientConfig, RetryDelay};
///
/// let config = HttpClientConfig::builder()
///     .max_retries(5)
///     .timeout(Duration::from_secs(120))
///     .retry_delay(RetryDelay::exponential(Duration::from_millis(100), Duration::from_secs(10)))
///     .logging(true)
///     .build();
/// ```
#[derive(Debug, Clone)]
pub struct HttpClientConfig {
    /// Request timeout duration (default: 60 seconds)
    pub timeout: Duration,

    /// Maximum number of retry attempts after the initial request (default: 3)
    pub max_retries: u32,

    /// Enable gzip compression (default: true)
    ///
    /// NOTE(review): within this module the flag only contributes to the
    /// client cache key; it is not passed to the `reqwest` builder.
    pub enable_compression: bool,

    /// Retry delay strategy (default: exponential, 500 ms base, 5 s cap)
    pub retry_delay: RetryDelay,

    /// Enable detailed logging (default: false)
    pub enable_logging: bool,

    /// Enable sensitive data masking in logs (default: true)
    pub mask_sensitive_data: bool,
}
172
173impl Default for HttpClientConfig {
174    fn default() -> Self {
175        Self {
176            timeout: Duration::from_secs(60),
177            max_retries: 3,
178            enable_compression: true,
179            retry_delay: RetryDelay::default(),
180            enable_logging: false,
181            mask_sensitive_data: true,
182        }
183    }
184}
185
186impl HttpClientConfig {
187    /// Create a new builder for fluent configuration
188    pub fn builder() -> HttpClientConfigBuilder {
189        HttpClientConfigBuilder::new()
190    }
191}
192
/// Builder for creating `HttpClientConfig` instances.
///
/// Provides a fluent API for configuring HTTP client behavior. Every setter
/// consumes the builder and returns it, so calls can be chained.
///
/// # Example
///
/// ```ignore
/// use std::time::Duration;
///
/// use zai_rs::client::http::{HttpClientConfig, RetryDelay};
///
/// let config = HttpClientConfig::builder()
///     .max_retries(5)
///     .timeout(Duration::from_secs(120))
///     .retry_delay(RetryDelay::exponential(Duration::from_millis(100), Duration::from_secs(10)))
///     .build();
/// ```
pub struct HttpClientConfigBuilder {
    /// Configuration being assembled; starts from `HttpClientConfig::default()`.
    config: HttpClientConfig,
}
211
212impl HttpClientConfigBuilder {
213    /// Create a new builder with default configuration
214    pub fn new() -> Self {
215        Self {
216            config: HttpClientConfig::default(),
217        }
218    }
219
220    /// Set the request timeout duration
221    pub fn timeout(mut self, timeout: Duration) -> Self {
222        self.config.timeout = timeout;
223        self
224    }
225
226    /// Set the maximum number of retry attempts
227    pub fn max_retries(mut self, max_retries: u32) -> Self {
228        self.config.max_retries = max_retries;
229        self
230    }
231
232    /// Enable or disable gzip compression
233    pub fn compression(mut self, enable: bool) -> Self {
234        self.config.enable_compression = enable;
235        self
236    }
237
238    /// Set the retry delay strategy
239    pub fn retry_delay(mut self, delay: RetryDelay) -> Self {
240        self.config.retry_delay = delay;
241        self
242    }
243
244    /// Enable or disable detailed logging
245    pub fn logging(mut self, enable: bool) -> Self {
246        self.config.enable_logging = enable;
247        self
248    }
249
250    /// Enable or disable sensitive data masking in logs
251    pub fn mask_sensitive_data(mut self, enable: bool) -> Self {
252        self.config.mask_sensitive_data = enable;
253        self
254    }
255
256    /// Build the configuration
257    pub fn build(self) -> HttpClientConfig {
258        self.config
259    }
260}
261
262impl Default for HttpClientConfigBuilder {
263    fn default() -> Self {
264        Self::new()
265    }
266}
267
/// A global HTTP client registry for connection pooling and configuration
/// caching.
///
/// Lazily initialized on first use. Maps a configuration-derived string key
/// (built in `http_client_with_config`) to the `reqwest::Client` created for
/// that configuration, so later requests with the same key reuse the client.
static HTTP_CLIENTS: OnceLock<dashmap::DashMap<String, reqwest::Client>> = OnceLock::new();
271
272/// Get or create an HTTP client with the specified configuration
273///
274/// Clients are cached by configuration to allow connection reuse.
275pub fn http_client_with_config(config: &HttpClientConfig) -> reqwest::Client {
276    let config_key = format!(
277        "timeout:{:?}|compression:{}",
278        config.timeout, config.enable_compression
279    );
280
281    let clients = HTTP_CLIENTS.get_or_init(dashmap::DashMap::new);
282
283    clients
284        .entry(config_key)
285        .or_insert_with(|| {
286            let builder = reqwest::Client::builder().timeout(config.timeout);
287
288            // Note: reqwest enables gzip compression by default
289            // if config.enable_compression {
290            //     builder = builder.gzip(true);
291            // }
292
293            builder.build().expect("Failed to build reqwest Client")
294        })
295        .clone()
296}
297
298/// Trait for HTTP clients that communicate with the Zhipu AI API.
299pub trait HttpClient {
300    type Body: serde::Serialize;
301    type ApiUrl: AsRef<str>;
302    type ApiKey: AsRef<str>;
303
304    fn api_url(&self) -> &Self::ApiUrl;
305    fn api_key(&self) -> &Self::ApiKey;
306    fn body(&self) -> &Self::Body;
307
308    /// Get HTTP client configuration for this request
309    ///
310    /// Override this method to provide custom configuration.
311    /// Default implementation returns default configuration.
312    fn http_config(&self) -> Arc<HttpClientConfig> {
313        static DEFAULT: std::sync::OnceLock<Arc<HttpClientConfig>> = std::sync::OnceLock::new();
314        DEFAULT
315            .get_or_init(|| Arc::new(HttpClientConfig::default()))
316            .clone()
317    }
318
319    /// Sends a POST request to the API endpoint.
320    ///
321    /// This method implements retry logic with exponential backoff and jitter.
322    /// It supports configuration through `http_config` method.
323    fn post(&self) -> impl std::future::Future<Output = ZaiResult<reqwest::Response>> + Send {
324        let body_compact =
325            serde_json::to_string(self.body()).map_err(|e| ZaiError::JsonError(Arc::new(e)));
326
327        let config = self.http_config().clone();
328        let enable_logging = config.enable_logging;
329        let mask_sensitive = config.mask_sensitive_data;
330
331        let body_pretty_opt = if enable_logging {
332            match serde_json::to_string_pretty(self.body()) {
333                Ok(pretty) => Some(pretty),
334                Err(e) => {
335                    warn!("Failed to pretty-print request body: {}", e);
336                    None
337                },
338            }
339        } else {
340            None
341        };
342
343        let url = self.api_url().as_ref().to_owned();
344        let key = self.api_key().as_ref().to_owned();
345
346        async move {
347            let body = body_compact?;
348
349            if enable_logging {
350                let log_body = if mask_sensitive {
351                    mask_sensitive_info(body.as_str())
352                } else {
353                    body.clone()
354                };
355                if let Some(pretty) = body_pretty_opt {
356                    let log_pretty = if mask_sensitive {
357                        mask_sensitive_info(&pretty)
358                    } else {
359                        pretty
360                    };
361                    info!(request_body = %log_pretty, "Sending POST request");
362                } else {
363                    debug!(request_body = %log_body, "Sending POST request");
364                }
365            }
366
367            let client = http_client_with_config(&config);
368            let request_builder = client
369                .post(&url)
370                .bearer_auth(&key)
371                .header("Content-Type", "application/json")
372                .body(body);
373
374            send_with_retry(request_builder, &config).await
375        }
376    }
377
378    /// Sends a GET request to the API endpoint.
379    ///
380    /// This method implements retry logic with exponential backoff and jitter.
381    /// It supports configuration through the `http_config` method.
382    fn get(&self) -> impl std::future::Future<Output = ZaiResult<reqwest::Response>> + Send {
383        let config = self.http_config().clone();
384        let url = self.api_url().as_ref().to_owned();
385        let key = self.api_key().as_ref().to_owned();
386
387        async move {
388            let client = http_client_with_config(&config);
389            let request_builder = client.get(&url).bearer_auth(&key);
390            send_with_retry(request_builder, &config).await
391        }
392    }
393}
394
395/// Internal helper: executes a request with retry logic.
396///
397/// This function encapsulates the common retry loop shared by both POST and
398/// GET methods, avoiding code duplication.
399async fn send_with_retry(
400    request_builder: reqwest::RequestBuilder,
401    config: &HttpClientConfig,
402) -> ZaiResult<reqwest::Response> {
403    let mut last_error: Option<ZaiError> = None;
404
405    // Extract request parts so we can rebuild for each retry attempt.
406    let req = request_builder.build()?;
407    let url = req.url().clone();
408    let method = req.method().clone();
409    let headers = req.headers().clone();
410    let body_bytes = req.body().and_then(|b| b.as_bytes().map(|b| b.to_vec()));
411    // Reuse a client built from the same config (preserves timeout, TLS, etc.)
412    let client = http_client_with_config(config);
413
414    for attempt in 0..=config.max_retries {
415        let mut builder = client
416            .request(method.clone(), url.clone())
417            .headers(headers.clone());
418
419        if let Some(ref body) = body_bytes {
420            builder = builder.body(body.clone());
421        }
422
423        let resp = builder.send().await;
424
425        match resp {
426            Ok(resp) => {
427                let status = resp.status();
428
429                if status.is_success() {
430                    debug!(http_status = %status, "Request succeeded");
431                    return Ok(resp);
432                }
433
434                let text = resp.text().await.unwrap_or_default();
435                let error = parse_api_error_response(status.as_u16(), text);
436
437                if should_retry(&error, attempt, config.max_retries) {
438                    last_error = Some(error.clone());
439                    let delay = calculate_retry_delay(attempt, &config.retry_delay);
440                    let delay_with_jitter = add_jitter(delay);
441                    warn!(
442                        attempt = attempt + 1,
443                        max_attempts = config.max_retries + 1,
444                        retry_delay = ?delay_with_jitter,
445                        error = %error.compact(),
446                        "Request failed, retrying"
447                    );
448                    tokio::time::sleep(delay_with_jitter).await;
449                } else {
450                    return Err(error);
451                }
452            },
453            Err(e) => {
454                let error = ZaiError::from(e);
455
456                if should_retry(&error, attempt, config.max_retries) {
457                    last_error = Some(error.clone());
458                    let delay = calculate_retry_delay(attempt, &config.retry_delay);
459                    let delay_with_jitter = add_jitter(delay);
460                    warn!(
461                        attempt = attempt + 1,
462                        max_attempts = config.max_retries + 1,
463                        retry_delay = ?delay_with_jitter,
464                        error = %error.compact(),
465                        "Request failed, retrying"
466                    );
467                    tokio::time::sleep(delay_with_jitter).await;
468                } else {
469                    return Err(error);
470                }
471            },
472        }
473    }
474
475    Err(last_error.unwrap_or_else(|| ZaiError::HttpError {
476        status: 500,
477        message: "Unknown error after retries".to_string(),
478    }))
479}
480
481/// Calculate delay for a retry attempt based on retry delay strategy.
482fn calculate_retry_delay(attempt: u32, strategy: &RetryDelay) -> Duration {
483    match strategy {
484        RetryDelay::Fixed(delay) => *delay,
485        RetryDelay::Exponential { base, max } => {
486            let delay = *base * 2u32.pow(attempt.min(10));
487            delay.min(*max)
488        },
489        RetryDelay::None => Duration::ZERO,
490    }
491}
492
493/// Determines if an error should trigger a retry.
494fn should_retry(error: &ZaiError, attempt: u32, max_retries: u32) -> bool {
495    if attempt >= max_retries {
496        return false;
497    }
498
499    match error {
500        // Retry on server errors (5xx)
501        ZaiError::HttpError { status, .. } => (500..600).contains(status),
502        // Retry on rate limit errors (API code 1301)
503        ZaiError::RateLimitError { .. } => true,
504        // Retry on network errors
505        ZaiError::NetworkError(_) => true,
506        // Don't retry on client errors (4xx), auth errors, account errors, etc.
507        _ => false,
508    }
509}
510
511/// Adds jitter to delay to avoid thundering herd.
512fn add_jitter(delay: Duration) -> Duration {
513    let jitter_ms = fastrand::u64(0..=delay.as_millis() as u64 / 4);
514    delay + Duration::from_millis(jitter_ms)
515}
516
/// Unit tests for error-code parsing, retry-delay calculation, retry
/// decisions and configuration defaults.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_code_display_num() {
        let code = ErrorCode::Num(123);
        assert_eq!(format!("{}", code), "123");
    }

    #[test]
    fn test_error_code_display_str() {
        let code = ErrorCode::Str("auth_error".to_string());
        assert_eq!(format!("{}", code), "auth_error");
    }

    #[test]
    fn test_to_api_code_num() {
        let code = ErrorCode::Num(401);
        assert_eq!(to_api_code(&code), 401);
    }

    #[test]
    fn test_to_api_code_str_valid() {
        let code = ErrorCode::Str("429".to_string());
        assert_eq!(to_api_code(&code), 429);
    }

    #[test]
    fn test_to_api_code_str_invalid() {
        let code = ErrorCode::Str("invalid".to_string());
        assert_eq!(to_api_code(&code), 0);
    }

    #[test]
    fn test_to_api_code_num_overflow() {
        let code = ErrorCode::Num(99999);
        assert_eq!(to_api_code(&code), 0);
    }

    #[test]
    fn test_api_error_envelope_deserialize() {
        let json = r#"{"error":{"code":401,"message":"Unauthorized"}}"#;
        let envelope: ApiErrorEnvelope = serde_json::from_str(json).unwrap();
        assert_eq!(envelope.error.message, "Unauthorized");
    }

    #[test]
    fn test_api_error_envelope_deserialize_str_code() {
        let json = r#"{"error":{"code":"1300","message":"Rate limit exceeded"}}"#;
        let envelope: ApiErrorEnvelope = serde_json::from_str(json).unwrap();
        assert_eq!(envelope.error.message, "Rate limit exceeded");
        assert_eq!(to_api_code(&envelope.error.code), 1300);
    }

    #[test]
    fn test_calculate_retry_delay_fixed() {
        let delay = Duration::from_secs(2);
        let strategy = RetryDelay::Fixed(delay);
        assert_eq!(calculate_retry_delay(0, &strategy), delay);
        assert_eq!(calculate_retry_delay(1, &strategy), delay);
        assert_eq!(calculate_retry_delay(5, &strategy), delay);
    }

    #[test]
    fn test_calculate_retry_delay_exponential() {
        let base = Duration::from_millis(500);
        let max = Duration::from_secs(5);
        let strategy = RetryDelay::Exponential { base, max };

        assert_eq!(
            calculate_retry_delay(0, &strategy),
            Duration::from_millis(500)
        );
        assert_eq!(
            calculate_retry_delay(1, &strategy),
            Duration::from_millis(1000)
        );
        assert_eq!(
            calculate_retry_delay(2, &strategy),
            Duration::from_millis(2000)
        );
        assert_eq!(
            calculate_retry_delay(3, &strategy),
            Duration::from_millis(4000)
        );
        assert_eq!(calculate_retry_delay(4, &strategy), max);
        assert_eq!(calculate_retry_delay(10, &strategy), max);
    }

    #[test]
    fn test_calculate_retry_delay_none() {
        let strategy = RetryDelay::None;
        assert_eq!(calculate_retry_delay(0, &strategy), Duration::ZERO);
        assert_eq!(calculate_retry_delay(5, &strategy), Duration::ZERO);
    }

    #[test]
    fn test_add_jitter() {
        let delay = Duration::from_millis(1000);
        let with_jitter = add_jitter(delay);

        // Jitter should be between 0 and 25% of the delay
        assert!(with_jitter >= delay);
        assert!(with_jitter <= delay + Duration::from_millis(250));
    }

    #[test]
    fn test_should_retry_server_error() {
        let error = ZaiError::HttpError {
            status: 500,
            message: "Internal server error".to_string(),
        };
        assert!(should_retry(&error, 0, 3));
        assert!(should_retry(&error, 2, 3));
        assert!(!should_retry(&error, 3, 3));
    }

    #[test]
    fn test_should_retry_gateway_timeout() {
        let error = ZaiError::HttpError {
            status: 504,
            message: "Gateway timeout".to_string(),
        };
        assert!(should_retry(&error, 0, 3));
    }

    #[test]
    fn test_should_retry_rate_limit() {
        let error = ZaiError::RateLimitError {
            code: 1301,
            message: "Rate limit exceeded".to_string(),
        };
        assert!(should_retry(&error, 0, 3));
    }

    #[test]
    fn test_should_retry_network_error() {
        // Since we can't construct reqwest::Error directly in tests,
        // simulate network error behavior with a 503 status.
        // NOTE: this exercises the `HttpError` 5xx branch, not the
        // `NetworkError` branch itself.
        let error = ZaiError::HttpError {
            status: 503,
            message: "Network error".to_string(),
        };
        assert!(should_retry(&error, 0, 3));
    }

    #[test]
    fn test_should_not_retry_client_error() {
        let error = ZaiError::HttpError {
            status: 400,
            message: "Bad request".to_string(),
        };
        assert!(!should_retry(&error, 0, 3));
    }

    #[test]
    fn test_should_not_retry_unauthorized() {
        let error = ZaiError::AuthError {
            code: 1001,
            message: "Invalid API key".to_string(),
        };
        assert!(!should_retry(&error, 0, 3));
    }

    #[test]
    fn test_should_not_retry_account_error() {
        let error = ZaiError::AccountError {
            code: 1110,
            message: "Account not found".to_string(),
        };
        assert!(!should_retry(&error, 0, 3));
    }

    #[test]
    fn test_should_not_retry_not_found() {
        let error = ZaiError::HttpError {
            status: 404,
            message: "Resource not found".to_string(),
        };
        assert!(!should_retry(&error, 0, 3));
    }

    #[test]
    fn test_http_client_config_default() {
        let config = HttpClientConfig::default();
        assert_eq!(config.timeout, Duration::from_secs(60));
        assert_eq!(config.max_retries, 3);
        assert!(config.enable_compression);
        assert!(!config.enable_logging);
        assert!(config.mask_sensitive_data);
        // `matches!` only yields a bool; it must be asserted to check anything.
        assert!(matches!(
            config.retry_delay,
            RetryDelay::Exponential { .. }
        ));
    }

    #[test]
    fn test_retry_delay_default() {
        let delay = RetryDelay::default();
        assert_eq!(
            delay,
            RetryDelay::Exponential {
                base: Duration::from_millis(500),
                max: Duration::from_secs(5),
            }
        );
    }
}