1use std::{
37 sync::{Arc, OnceLock},
38 time::Duration,
39};
40
41use serde::Deserialize;
42use tracing::{debug, info, warn};
43
44use crate::client::error::{ZaiError, ZaiResult, mask_sensitive_info};
45
/// Top-level JSON error envelope, e.g.
/// `{"error": {"code": 401, "message": "Unauthorized"}}`.
///
/// Deserialized by `parse_api_error_response`; bodies that do not match this
/// shape are passed through unparsed.
#[derive(Debug, Deserialize)]
struct ApiErrorEnvelope {
    /// The nested error object.
    error: ApiError,
}
51
/// The `error` object inside [`ApiErrorEnvelope`].
#[derive(Debug, Deserialize)]
struct ApiError {
    /// Provider error code; may arrive as a JSON string or a JSON number.
    code: ErrorCode,

    /// Error message text, forwarded to `ZaiError::from_api_response`.
    message: String,
}
59
/// API error code as it appears on the wire: either a string (e.g. `"1300"`)
/// or an integer (e.g. `401`).
///
/// Because the enum is `untagged`, serde tries `Str` before `Num`. A string code
/// therefore stays `Str` even when it contains only digits; `to_api_code`
/// performs the numeric conversion.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ErrorCode {
    /// Code sent as a JSON string.
    Str(String),

    /// Code sent as a JSON integer.
    Num(i64),
}
67
68impl std::fmt::Display for ErrorCode {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 match self {
71 ErrorCode::Str(s) => write!(f, "{}", s),
72
73 ErrorCode::Num(n) => write!(f, "{}", n),
74 }
75 }
76}
77
78fn to_api_code(code: &ErrorCode) -> u16 {
79 match code {
80 ErrorCode::Num(n) => (*n).try_into().unwrap_or(0),
81 ErrorCode::Str(s) => s.parse::<u16>().unwrap_or(0),
82 }
83}
84
85pub fn parse_api_error_response(status: u16, body: String) -> crate::client::error::ZaiError {
91 if let Ok(parsed) = serde_json::from_str::<ApiErrorEnvelope>(&body) {
92 let api_code = to_api_code(&parsed.error.code);
93 crate::client::error::ZaiError::from_api_response(status, api_code, parsed.error.message)
94 } else {
95 crate::client::error::ZaiError::from_api_response(status, 0, body)
96 }
97}
98
/// Strategy for how long to wait between retry attempts.
///
/// `send_with_retry` converts the strategy into a concrete wait with
/// `calculate_retry_delay` and then adds jitter with `add_jitter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryDelay {
    /// Wait the same duration before every retry.
    Fixed(Duration),

    /// Wait `base * 2^min(attempt, 10)` (attempt is 0-based), capped at `max`.
    Exponential { base: Duration, max: Duration },

    /// No wait between retries (`Duration::ZERO`).
    None,
}
111
112impl RetryDelay {
113 pub fn fixed(delay: Duration) -> Self {
115 Self::Fixed(delay)
116 }
117
118 pub fn exponential(base: Duration, max: Duration) -> Self {
120 Self::Exponential { base, max }
121 }
122
123 pub fn none() -> Self {
125 Self::None
126 }
127}
128
129impl Default for RetryDelay {
130 fn default() -> Self {
131 Self::Exponential {
132 base: Duration::from_millis(500),
133 max: Duration::from_secs(5),
134 }
135 }
136}
137
/// Settings for the shared HTTP client and the retry loop used by [`HttpClient`].
///
/// Build one with [`HttpClientConfig::builder`] or start from `Default`.
#[derive(Debug, Clone)]
pub struct HttpClientConfig {
    /// Per-request timeout applied with `reqwest::ClientBuilder::timeout`.
    pub timeout: Duration,

    /// Number of retries after the initial attempt (up to `max_retries + 1` attempts in total).
    pub max_retries: u32,

    /// Compression toggle.
    /// NOTE(review): `http_client_with_config` uses this only as part of the client
    /// cache key and does not apply it to the reqwest builder. Confirm this is intended.
    pub enable_compression: bool,

    /// Backoff strategy used between retries.
    pub retry_delay: RetryDelay,

    /// When true, `HttpClient::post` logs the outgoing request body.
    pub enable_logging: bool,

    /// When true (and logging is enabled), logged bodies are passed through `mask_sensitive_info`.
    pub mask_sensitive_data: bool,
}
172
173impl Default for HttpClientConfig {
174 fn default() -> Self {
175 Self {
176 timeout: Duration::from_secs(60),
177 max_retries: 3,
178 enable_compression: true,
179 retry_delay: RetryDelay::default(),
180 enable_logging: false,
181 mask_sensitive_data: true,
182 }
183 }
184}
185
186impl HttpClientConfig {
187 pub fn builder() -> HttpClientConfigBuilder {
189 HttpClientConfigBuilder::new()
190 }
191}
192
193pub struct HttpClientConfigBuilder {
209 config: HttpClientConfig,
210}
211
212impl HttpClientConfigBuilder {
213 pub fn new() -> Self {
215 Self {
216 config: HttpClientConfig::default(),
217 }
218 }
219
220 pub fn timeout(mut self, timeout: Duration) -> Self {
222 self.config.timeout = timeout;
223 self
224 }
225
226 pub fn max_retries(mut self, max_retries: u32) -> Self {
228 self.config.max_retries = max_retries;
229 self
230 }
231
232 pub fn compression(mut self, enable: bool) -> Self {
234 self.config.enable_compression = enable;
235 self
236 }
237
238 pub fn retry_delay(mut self, delay: RetryDelay) -> Self {
240 self.config.retry_delay = delay;
241 self
242 }
243
244 pub fn logging(mut self, enable: bool) -> Self {
246 self.config.enable_logging = enable;
247 self
248 }
249
250 pub fn mask_sensitive_data(mut self, enable: bool) -> Self {
252 self.config.mask_sensitive_data = enable;
253 self
254 }
255
256 pub fn build(self) -> HttpClientConfig {
258 self.config
259 }
260}
261
262impl Default for HttpClientConfigBuilder {
263 fn default() -> Self {
264 Self::new()
265 }
266}
267
/// Process-wide cache of `reqwest::Client`s, created on first use by
/// `http_client_with_config` and keyed by a string built from the config's
/// `timeout` and `enable_compression`.
static HTTP_CLIENTS: OnceLock<dashmap::DashMap<String, reqwest::Client>> = OnceLock::new();
271
272pub fn http_client_with_config(config: &HttpClientConfig) -> reqwest::Client {
276 let config_key = format!(
277 "timeout:{:?}|compression:{}",
278 config.timeout, config.enable_compression
279 );
280
281 let clients = HTTP_CLIENTS.get_or_init(dashmap::DashMap::new);
282
283 clients
284 .entry(config_key)
285 .or_insert_with(|| {
286 let builder = reqwest::Client::builder().timeout(config.timeout);
287
288 builder.build().expect("Failed to build reqwest Client")
294 })
295 .clone()
296}
297
298pub trait HttpClient {
300 type Body: serde::Serialize;
301 type ApiUrl: AsRef<str>;
302 type ApiKey: AsRef<str>;
303
304 fn api_url(&self) -> &Self::ApiUrl;
305 fn api_key(&self) -> &Self::ApiKey;
306 fn body(&self) -> &Self::Body;
307
308 fn http_config(&self) -> Arc<HttpClientConfig> {
313 static DEFAULT: std::sync::OnceLock<Arc<HttpClientConfig>> = std::sync::OnceLock::new();
314 DEFAULT
315 .get_or_init(|| Arc::new(HttpClientConfig::default()))
316 .clone()
317 }
318
319 fn post(&self) -> impl std::future::Future<Output = ZaiResult<reqwest::Response>> + Send {
324 let body_compact =
325 serde_json::to_string(self.body()).map_err(|e| ZaiError::JsonError(Arc::new(e)));
326
327 let config = self.http_config().clone();
328 let enable_logging = config.enable_logging;
329 let mask_sensitive = config.mask_sensitive_data;
330
331 let body_pretty_opt = if enable_logging {
332 match serde_json::to_string_pretty(self.body()) {
333 Ok(pretty) => Some(pretty),
334 Err(e) => {
335 warn!("Failed to pretty-print request body: {}", e);
336 None
337 },
338 }
339 } else {
340 None
341 };
342
343 let url = self.api_url().as_ref().to_owned();
344 let key = self.api_key().as_ref().to_owned();
345
346 async move {
347 let body = body_compact?;
348
349 if enable_logging {
350 let log_body = if mask_sensitive {
351 mask_sensitive_info(body.as_str())
352 } else {
353 body.clone()
354 };
355 if let Some(pretty) = body_pretty_opt {
356 let log_pretty = if mask_sensitive {
357 mask_sensitive_info(&pretty)
358 } else {
359 pretty
360 };
361 info!(request_body = %log_pretty, "Sending POST request");
362 } else {
363 debug!(request_body = %log_body, "Sending POST request");
364 }
365 }
366
367 let client = http_client_with_config(&config);
368 let request_builder = client
369 .post(&url)
370 .bearer_auth(&key)
371 .header("Content-Type", "application/json")
372 .body(body);
373
374 send_with_retry(request_builder, &config).await
375 }
376 }
377
378 fn get(&self) -> impl std::future::Future<Output = ZaiResult<reqwest::Response>> + Send {
383 let config = self.http_config().clone();
384 let url = self.api_url().as_ref().to_owned();
385 let key = self.api_key().as_ref().to_owned();
386
387 async move {
388 let client = http_client_with_config(&config);
389 let request_builder = client.get(&url).bearer_auth(&key);
390 send_with_retry(request_builder, &config).await
391 }
392 }
393}
394
395async fn send_with_retry(
400 request_builder: reqwest::RequestBuilder,
401 config: &HttpClientConfig,
402) -> ZaiResult<reqwest::Response> {
403 let mut last_error: Option<ZaiError> = None;
404
405 let req = request_builder.build()?;
407 let url = req.url().clone();
408 let method = req.method().clone();
409 let headers = req.headers().clone();
410 let body_bytes = req.body().and_then(|b| b.as_bytes().map(|b| b.to_vec()));
411 let client = http_client_with_config(config);
413
414 for attempt in 0..=config.max_retries {
415 let mut builder = client
416 .request(method.clone(), url.clone())
417 .headers(headers.clone());
418
419 if let Some(ref body) = body_bytes {
420 builder = builder.body(body.clone());
421 }
422
423 let resp = builder.send().await;
424
425 match resp {
426 Ok(resp) => {
427 let status = resp.status();
428
429 if status.is_success() {
430 debug!(http_status = %status, "Request succeeded");
431 return Ok(resp);
432 }
433
434 let text = resp.text().await.unwrap_or_default();
435 let error = parse_api_error_response(status.as_u16(), text);
436
437 if should_retry(&error, attempt, config.max_retries) {
438 last_error = Some(error.clone());
439 let delay = calculate_retry_delay(attempt, &config.retry_delay);
440 let delay_with_jitter = add_jitter(delay);
441 warn!(
442 attempt = attempt + 1,
443 max_attempts = config.max_retries + 1,
444 retry_delay = ?delay_with_jitter,
445 error = %error.compact(),
446 "Request failed, retrying"
447 );
448 tokio::time::sleep(delay_with_jitter).await;
449 } else {
450 return Err(error);
451 }
452 },
453 Err(e) => {
454 let error = ZaiError::from(e);
455
456 if should_retry(&error, attempt, config.max_retries) {
457 last_error = Some(error.clone());
458 let delay = calculate_retry_delay(attempt, &config.retry_delay);
459 let delay_with_jitter = add_jitter(delay);
460 warn!(
461 attempt = attempt + 1,
462 max_attempts = config.max_retries + 1,
463 retry_delay = ?delay_with_jitter,
464 error = %error.compact(),
465 "Request failed, retrying"
466 );
467 tokio::time::sleep(delay_with_jitter).await;
468 } else {
469 return Err(error);
470 }
471 },
472 }
473 }
474
475 Err(last_error.unwrap_or_else(|| ZaiError::HttpError {
476 status: 500,
477 message: "Unknown error after retries".to_string(),
478 }))
479}
480
481fn calculate_retry_delay(attempt: u32, strategy: &RetryDelay) -> Duration {
483 match strategy {
484 RetryDelay::Fixed(delay) => *delay,
485 RetryDelay::Exponential { base, max } => {
486 let delay = *base * 2u32.pow(attempt.min(10));
487 delay.min(*max)
488 },
489 RetryDelay::None => Duration::ZERO,
490 }
491}
492
493fn should_retry(error: &ZaiError, attempt: u32, max_retries: u32) -> bool {
495 if attempt >= max_retries {
496 return false;
497 }
498
499 match error {
500 ZaiError::HttpError { status, .. } => (500..600).contains(status),
502 ZaiError::RateLimitError { .. } => true,
504 ZaiError::NetworkError(_) => true,
506 _ => false,
508 }
509}
510
511fn add_jitter(delay: Duration) -> Duration {
513 let jitter_ms = fastrand::u64(0..=delay.as_millis() as u64 / 4);
514 delay + Duration::from_millis(jitter_ms)
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_code_display_num() {
        let code = ErrorCode::Num(123);
        assert_eq!(format!("{}", code), "123");
    }

    #[test]
    fn test_error_code_display_str() {
        let code = ErrorCode::Str("auth_error".to_string());
        assert_eq!(format!("{}", code), "auth_error");
    }

    #[test]
    fn test_to_api_code_num() {
        let code = ErrorCode::Num(401);
        assert_eq!(to_api_code(&code), 401);
    }

    #[test]
    fn test_to_api_code_str_valid() {
        let code = ErrorCode::Str("429".to_string());
        assert_eq!(to_api_code(&code), 429);
    }

    #[test]
    fn test_to_api_code_str_invalid() {
        let code = ErrorCode::Str("invalid".to_string());
        assert_eq!(to_api_code(&code), 0);
    }

    #[test]
    fn test_to_api_code_num_overflow() {
        let code = ErrorCode::Num(99999);
        assert_eq!(to_api_code(&code), 0);
    }

    #[test]
    fn test_api_error_envelope_deserialize() {
        let json = r#"{"error":{"code":401,"message":"Unauthorized"}}"#;
        let envelope: ApiErrorEnvelope = serde_json::from_str(json).unwrap();
        assert_eq!(envelope.error.message, "Unauthorized");
        assert_eq!(to_api_code(&envelope.error.code), 401);
    }

    #[test]
    fn test_api_error_envelope_deserialize_str_code() {
        let json = r#"{"error":{"code":"1300","message":"Rate limit exceeded"}}"#;
        let envelope: ApiErrorEnvelope = serde_json::from_str(json).unwrap();
        assert_eq!(envelope.error.message, "Rate limit exceeded");
        assert_eq!(to_api_code(&envelope.error.code), 1300);
    }

    #[test]
    fn test_calculate_retry_delay_fixed() {
        let delay = Duration::from_secs(2);
        let strategy = RetryDelay::Fixed(delay);
        assert_eq!(calculate_retry_delay(0, &strategy), delay);
        assert_eq!(calculate_retry_delay(1, &strategy), delay);
        assert_eq!(calculate_retry_delay(5, &strategy), delay);
    }

    #[test]
    fn test_calculate_retry_delay_exponential() {
        let base = Duration::from_millis(500);
        let max = Duration::from_secs(5);
        let strategy = RetryDelay::Exponential { base, max };

        assert_eq!(
            calculate_retry_delay(0, &strategy),
            Duration::from_millis(500)
        );
        assert_eq!(
            calculate_retry_delay(1, &strategy),
            Duration::from_millis(1000)
        );
        assert_eq!(
            calculate_retry_delay(2, &strategy),
            Duration::from_millis(2000)
        );
        assert_eq!(
            calculate_retry_delay(3, &strategy),
            Duration::from_millis(4000)
        );
        assert_eq!(calculate_retry_delay(4, &strategy), max);
        assert_eq!(calculate_retry_delay(10, &strategy), max);
        // The exponent is capped, so very large attempt numbers stay at `max`.
        assert_eq!(calculate_retry_delay(u32::MAX, &strategy), max);
    }

    #[test]
    fn test_calculate_retry_delay_none() {
        let strategy = RetryDelay::None;
        assert_eq!(calculate_retry_delay(0, &strategy), Duration::ZERO);
        assert_eq!(calculate_retry_delay(5, &strategy), Duration::ZERO);
    }

    #[test]
    fn test_add_jitter() {
        let delay = Duration::from_millis(1000);
        let with_jitter = add_jitter(delay);

        // Jitter is at most a quarter of the delay.
        assert!(with_jitter >= delay);
        assert!(with_jitter <= delay + Duration::from_millis(250));
    }

    #[test]
    fn test_add_jitter_zero() {
        assert_eq!(add_jitter(Duration::ZERO), Duration::ZERO);
    }

    #[test]
    fn test_should_retry_server_error() {
        let error = ZaiError::HttpError {
            status: 500,
            message: "Internal server error".to_string(),
        };
        assert!(should_retry(&error, 0, 3));
        assert!(should_retry(&error, 2, 3));
        assert!(!should_retry(&error, 3, 3));
    }

    #[test]
    fn test_should_retry_gateway_timeout() {
        let error = ZaiError::HttpError {
            status: 504,
            message: "Gateway timeout".to_string(),
        };
        assert!(should_retry(&error, 0, 3));
    }

    #[test]
    fn test_should_retry_rate_limit() {
        let error = ZaiError::RateLimitError {
            code: 1301,
            message: "Rate limit exceeded".to_string(),
        };
        assert!(should_retry(&error, 0, 3));
    }

    #[test]
    fn test_should_retry_service_unavailable() {
        // Exercises a retryable 503 `HttpError`; `ZaiError::NetworkError` itself is
        // not constructed in these tests.
        let error = ZaiError::HttpError {
            status: 503,
            message: "Network error".to_string(),
        };
        assert!(should_retry(&error, 0, 3));
    }

    #[test]
    fn test_should_not_retry_client_error() {
        let error = ZaiError::HttpError {
            status: 400,
            message: "Bad request".to_string(),
        };
        assert!(!should_retry(&error, 0, 3));
    }

    #[test]
    fn test_should_not_retry_unauthorized() {
        let error = ZaiError::AuthError {
            code: 1001,
            message: "Invalid API key".to_string(),
        };
        assert!(!should_retry(&error, 0, 3));
    }

    #[test]
    fn test_should_not_retry_account_error() {
        let error = ZaiError::AccountError {
            code: 1110,
            message: "Account not found".to_string(),
        };
        assert!(!should_retry(&error, 0, 3));
    }

    #[test]
    fn test_should_not_retry_not_found() {
        let error = ZaiError::HttpError {
            status: 404,
            message: "Resource not found".to_string(),
        };
        assert!(!should_retry(&error, 0, 3));
    }

    #[test]
    fn test_http_client_config_default() {
        let config = HttpClientConfig::default();
        assert_eq!(config.timeout, Duration::from_secs(60));
        assert_eq!(config.max_retries, 3);
        assert!(config.enable_compression);
        assert!(!config.enable_logging);
        assert!(config.mask_sensitive_data);
        assert_eq!(config.retry_delay, RetryDelay::default());
        // Previously `matches!` was evaluated without asserting its result.
        assert!(matches!(config.retry_delay, RetryDelay::Exponential { .. }));
    }

    #[test]
    fn test_retry_delay_default() {
        // Previously `matches!` was evaluated without asserting its result.
        assert_eq!(
            RetryDelay::default(),
            RetryDelay::Exponential {
                base: Duration::from_millis(500),
                max: Duration::from_secs(5),
            }
        );
    }

    #[test]
    fn test_builder_overrides_defaults() {
        let config = HttpClientConfig::builder()
            .timeout(Duration::from_secs(10))
            .max_retries(1)
            .compression(false)
            .retry_delay(RetryDelay::none())
            .logging(true)
            .mask_sensitive_data(false)
            .build();
        assert_eq!(config.timeout, Duration::from_secs(10));
        assert_eq!(config.max_retries, 1);
        assert!(!config.enable_compression);
        assert_eq!(config.retry_delay, RetryDelay::None);
        assert!(config.enable_logging);
        assert!(!config.mask_sensitive_data);
    }
}