tact_client/
http.rs

//! HTTP client for the TACT protocol (V1 over HTTP on port 1119, V2 over an HTTPS REST API).
2
3use crate::{CdnEntry, Error, Region, Result, VersionEntry, response_types};
4use reqwest::{Client, Response};
5use std::time::Duration;
6use tokio::time::sleep;
7use tracing::{debug, trace, warn};
8
// Retry-policy defaults. Retrying is disabled out of the box (`0`) so existing
// callers keep single-attempt behaviour unless they opt in explicitly.

/// Default maximum number of retries; `0` disables retrying entirely.
const DEFAULT_MAX_RETRIES: u32 = 0;

/// Default delay before the first retry, in milliseconds.
const DEFAULT_INITIAL_BACKOFF_MS: u64 = 100;

/// Default upper bound for a single computed retry delay, in milliseconds (10 s).
const DEFAULT_MAX_BACKOFF_MS: u64 = 10 * 1_000;

/// Default factor the delay is multiplied by for each successive retry.
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;

/// Default random jitter, as a fraction of the delay in `0.0..=1.0` (0.1 = ±10%).
const DEFAULT_JITTER_FACTOR: f64 = 0.1;
23
/// TACT protocol version used to select the base URL and the allowed endpoints.
///
/// Derives `Hash` in addition to equality so the version can be used as a key in
/// hash-based collections (e.g. caching one client per version).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProtocolVersion {
    /// Version 1: plain HTTP protocol on port 1119.
    V1,
    /// Version 2: HTTPS-based REST API.
    V2,
}
32
33/// HTTP client for TACT protocol
34#[derive(Debug, Clone)]
35pub struct HttpClient {
36    client: Client,
37    region: Region,
38    version: ProtocolVersion,
39    max_retries: u32,
40    initial_backoff_ms: u64,
41    max_backoff_ms: u64,
42    backoff_multiplier: f64,
43    jitter_factor: f64,
44    user_agent: Option<String>,
45}
46
47impl HttpClient {
48    /// Create a new HTTP client for the specified region and protocol version
49    pub fn new(region: Region, version: ProtocolVersion) -> Result<Self> {
50        let client = Client::builder().timeout(Duration::from_secs(30)).build()?;
51
52        Ok(Self {
53            client,
54            region,
55            version,
56            max_retries: DEFAULT_MAX_RETRIES,
57            initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
58            max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
59            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
60            jitter_factor: DEFAULT_JITTER_FACTOR,
61            user_agent: None,
62        })
63    }
64
65    /// Create a new HTTP client with custom reqwest client
66    pub fn with_client(client: Client, region: Region, version: ProtocolVersion) -> Self {
67        Self {
68            client,
69            region,
70            version,
71            max_retries: DEFAULT_MAX_RETRIES,
72            initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
73            max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
74            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
75            jitter_factor: DEFAULT_JITTER_FACTOR,
76            user_agent: None,
77        }
78    }
79
80    /// Set the maximum number of retries for failed requests
81    ///
82    /// Default is 0 (no retries) to maintain backward compatibility.
83    /// Only network and connection errors are retried, not parsing errors.
84    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
85        self.max_retries = max_retries;
86        self
87    }
88
89    /// Set the initial backoff duration in milliseconds
90    ///
91    /// Default is 100ms. This is the base delay before the first retry.
92    pub fn with_initial_backoff_ms(mut self, initial_backoff_ms: u64) -> Self {
93        self.initial_backoff_ms = initial_backoff_ms;
94        self
95    }
96
97    /// Set the maximum backoff duration in milliseconds
98    ///
99    /// Default is 10,000ms (10 seconds). Backoff will not exceed this value.
100    pub fn with_max_backoff_ms(mut self, max_backoff_ms: u64) -> Self {
101        self.max_backoff_ms = max_backoff_ms;
102        self
103    }
104
105    /// Set the backoff multiplier
106    ///
107    /// Default is 2.0. The backoff duration is multiplied by this value after each retry.
108    pub fn with_backoff_multiplier(mut self, backoff_multiplier: f64) -> Self {
109        self.backoff_multiplier = backoff_multiplier;
110        self
111    }
112
113    /// Set the jitter factor (0.0 to 1.0)
114    ///
115    /// Default is 0.1 (10% jitter). Adds randomness to prevent thundering herd.
116    pub fn with_jitter_factor(mut self, jitter_factor: f64) -> Self {
117        self.jitter_factor = jitter_factor.clamp(0.0, 1.0);
118        self
119    }
120
121    /// Set a custom user agent string
122    ///
123    /// If not set, reqwest's default user agent will be used.
124    pub fn with_user_agent(mut self, user_agent: impl Into<String>) -> Self {
125        self.user_agent = Some(user_agent.into());
126        self
127    }
128
129    /// Get the base URL for the current configuration
130    pub fn base_url(&self) -> String {
131        match self.version {
132            ProtocolVersion::V1 => {
133                format!("http://{}.patch.battle.net:1119", self.region)
134            }
135            ProtocolVersion::V2 => {
136                format!("https://{}.version.battle.net/v2/products", self.region)
137            }
138        }
139    }
140
141    /// Get the current region
142    pub fn region(&self) -> Region {
143        self.region
144    }
145
146    /// Get the current protocol version
147    pub fn version(&self) -> ProtocolVersion {
148        self.version
149    }
150
151    /// Set the region
152    pub fn set_region(&mut self, region: Region) {
153        self.region = region;
154    }
155
156    /// Calculate backoff duration with exponential backoff and jitter
157    #[allow(
158        clippy::cast_precision_loss,
159        clippy::cast_possible_wrap,
160        clippy::cast_possible_truncation,
161        clippy::cast_sign_loss
162    )]
163    fn calculate_backoff(&self, attempt: u32) -> Duration {
164        let base_backoff =
165            self.initial_backoff_ms as f64 * self.backoff_multiplier.powi(attempt as i32);
166        let capped_backoff = base_backoff.min(self.max_backoff_ms as f64);
167
168        // Add jitter
169        let jitter_range = capped_backoff * self.jitter_factor;
170        let jitter = rand::random::<f64>() * 2.0 * jitter_range - jitter_range;
171        let final_backoff = (capped_backoff + jitter).max(0.0) as u64;
172
173        Duration::from_millis(final_backoff)
174    }
175
176    /// Execute an HTTP request with retry logic
177    async fn execute_with_retry(&self, url: &str) -> Result<Response> {
178        let mut last_error = None;
179
180        for attempt in 0..=self.max_retries {
181            if attempt > 0 {
182                let backoff = self.calculate_backoff(attempt - 1);
183                debug!("Retry attempt {} after {:?} backoff", attempt, backoff);
184                sleep(backoff).await;
185            }
186
187            debug!("HTTP request to {} (attempt {})", url, attempt + 1);
188
189            let mut request = self.client.get(url);
190            if let Some(ref user_agent) = self.user_agent {
191                request = request.header("User-Agent", user_agent);
192            }
193
194            match request.send().await {
195                Ok(response) => {
196                    trace!("Response status: {}", response.status());
197
198                    // Check if we should retry based on status code
199                    let status = response.status();
200                    if (status.is_server_error()
201                        || status == reqwest::StatusCode::TOO_MANY_REQUESTS)
202                        && attempt < self.max_retries
203                    {
204                        warn!(
205                            "Request returned {} (attempt {}): will retry",
206                            status,
207                            attempt + 1
208                        );
209                        last_error = Some(Error::InvalidResponse);
210                        continue;
211                    }
212
213                    return Ok(response);
214                }
215                Err(e) => {
216                    // Check if error is retryable
217                    let is_retryable = e.is_connect() || e.is_timeout() || e.is_request();
218
219                    if is_retryable && attempt < self.max_retries {
220                        warn!(
221                            "Request failed (attempt {}): {}, will retry",
222                            attempt + 1,
223                            e
224                        );
225                        last_error = Some(Error::Http(e));
226                    } else {
227                        // Non-retryable error or final attempt
228                        debug!(
229                            "Request failed (attempt {}): {}, not retrying",
230                            attempt + 1,
231                            e
232                        );
233                        return Err(Error::Http(e));
234                    }
235                }
236            }
237        }
238
239        // This should only be reached if all retries failed
240        Err(last_error.unwrap_or(Error::InvalidResponse))
241    }
242
243    /// Execute an HTTP request with additional headers and retry logic
244    async fn execute_with_retry_and_headers(
245        &self,
246        url: &str,
247        headers: &[(&str, &str)],
248    ) -> Result<Response> {
249        let mut last_error = None;
250
251        for attempt in 0..=self.max_retries {
252            if attempt > 0 {
253                let backoff = self.calculate_backoff(attempt - 1);
254                debug!("Retry attempt {} after {:?} backoff", attempt, backoff);
255                sleep(backoff).await;
256            }
257
258            debug!("HTTP request to {} (attempt {})", url, attempt + 1);
259
260            let mut request = self.client.get(url);
261            if let Some(ref user_agent) = self.user_agent {
262                request = request.header("User-Agent", user_agent);
263            }
264
265            // Add custom headers
266            for &(key, value) in headers {
267                request = request.header(key, value);
268            }
269
270            match request.send().await {
271                Ok(response) => {
272                    trace!("Response status: {}", response.status());
273
274                    // Check if we should retry based on status code
275                    let status = response.status();
276                    if (status.is_server_error()
277                        || status == reqwest::StatusCode::TOO_MANY_REQUESTS)
278                        && attempt < self.max_retries
279                    {
280                        warn!(
281                            "Request returned {} (attempt {}): will retry",
282                            status,
283                            attempt + 1
284                        );
285                        last_error = Some(Error::InvalidResponse);
286                        continue;
287                    }
288
289                    return Ok(response);
290                }
291                Err(e) => {
292                    // Check if error is retryable
293                    let is_retryable = e.is_connect() || e.is_timeout() || e.is_request();
294
295                    if is_retryable && attempt < self.max_retries {
296                        warn!(
297                            "Request failed (attempt {}): {}, will retry",
298                            attempt + 1,
299                            e
300                        );
301                        last_error = Some(Error::Http(e));
302                    } else {
303                        // Non-retryable error or final attempt
304                        debug!(
305                            "Request failed (attempt {}): {}, not retrying",
306                            attempt + 1,
307                            e
308                        );
309                        return Err(Error::Http(e));
310                    }
311                }
312            }
313        }
314
315        // This should only be reached if all retries failed
316        Err(last_error.unwrap_or(Error::InvalidResponse))
317    }
318
319    /// Get versions manifest for a product (V1 protocol)
320    pub async fn get_versions(&self, product: &str) -> Result<Response> {
321        if self.version != ProtocolVersion::V1 {
322            return Err(Error::InvalidProtocolVersion);
323        }
324
325        let url = format!("{}/{}/versions", self.base_url(), product);
326        self.execute_with_retry(&url).await
327    }
328
329    /// Get CDN configuration for a product (V1 protocol)
330    pub async fn get_cdns(&self, product: &str) -> Result<Response> {
331        if self.version != ProtocolVersion::V1 {
332            return Err(Error::InvalidProtocolVersion);
333        }
334
335        let url = format!("{}/{}/cdns", self.base_url(), product);
336        self.execute_with_retry(&url).await
337    }
338
339    /// Get BGDL manifest for a product (V1 protocol)
340    pub async fn get_bgdl(&self, product: &str) -> Result<Response> {
341        if self.version != ProtocolVersion::V1 {
342            return Err(Error::InvalidProtocolVersion);
343        }
344
345        let url = format!("{}/{}/bgdl", self.base_url(), product);
346        self.execute_with_retry(&url).await
347    }
348
349    /// Get product summary (V2 protocol)
350    pub async fn get_summary(&self) -> Result<Response> {
351        if self.version != ProtocolVersion::V2 {
352            return Err(Error::InvalidProtocolVersion);
353        }
354
355        let url = self.base_url();
356        self.execute_with_retry(&url).await
357    }
358
359    /// Get product details (V2 protocol)
360    pub async fn get_product(&self, product: &str) -> Result<Response> {
361        if self.version != ProtocolVersion::V2 {
362            return Err(Error::InvalidProtocolVersion);
363        }
364
365        let url = format!("{}/{}", self.base_url(), product);
366        self.execute_with_retry(&url).await
367    }
368
369    /// Make a raw GET request to a path
370    pub async fn get(&self, path: &str) -> Result<Response> {
371        let url = if path.starts_with('/') {
372            format!("{}{}", self.base_url(), path)
373        } else {
374            format!("{}/{}", self.base_url(), path)
375        };
376
377        self.execute_with_retry(&url).await
378    }
379
380    /// Download a file from CDN
381    pub async fn download_file(&self, cdn_host: &str, path: &str, hash: &str) -> Result<Response> {
382        let url = format!(
383            "http://{}/{}/{}/{}/{}",
384            cdn_host,
385            path,
386            &hash[0..2],
387            &hash[2..4],
388            hash
389        );
390
391        // Use execute_with_retry for CDN downloads as well
392        let response = self.execute_with_retry(&url).await?;
393
394        if response.status() == reqwest::StatusCode::NOT_FOUND {
395            return Err(Error::file_not_found(hash));
396        }
397
398        Ok(response)
399    }
400
401    /// Download a file from CDN with HTTP range request for partial content
402    ///
403    /// # Arguments
404    /// * `cdn_host` - CDN hostname
405    /// * `path` - Path prefix for the CDN
406    /// * `hash` - File hash
407    /// * `range` - Byte range to download (e.g., (0, Some(1023)) for first 1024 bytes)
408    ///
409    /// # Returns
410    /// Returns a response with the requested byte range. The response will have status 206
411    /// (Partial Content) if the range is supported, or status 200 (OK) with full content
412    /// if range requests are not supported.
413    pub async fn download_file_range(
414        &self,
415        cdn_host: &str,
416        path: &str,
417        hash: &str,
418        range: (u64, Option<u64>),
419    ) -> Result<Response> {
420        let url = format!(
421            "http://{}/{}/{}/{}/{}",
422            cdn_host,
423            path,
424            &hash[0..2],
425            &hash[2..4],
426            hash
427        );
428
429        // Build Range header value
430        let range_header = match range {
431            (start, Some(end)) => format!("bytes={}-{}", start, end),
432            (start, None) => format!("bytes={}-", start),
433        };
434
435        debug!("Range request: {} Range: {}", url, range_header);
436
437        let response = self
438            .execute_with_retry_and_headers(&url, &[("Range", &range_header)])
439            .await?;
440
441        if response.status() == reqwest::StatusCode::NOT_FOUND {
442            return Err(Error::file_not_found(hash));
443        }
444
445        // Check if server supports range requests
446        match response.status() {
447            reqwest::StatusCode::PARTIAL_CONTENT => {
448                trace!("Server returned partial content (206)");
449            }
450            reqwest::StatusCode::OK => {
451                warn!("Server returned full content (200) - range requests not supported");
452            }
453            status => {
454                warn!(
455                    "Unexpected status code for range request: {} (expected 206 or 200)",
456                    status
457                );
458                // Still return the response - let the caller handle unexpected status codes
459            }
460        }
461
462        Ok(response)
463    }
464
465    /// Download multiple ranges from a file in a single request
466    ///
467    /// # Arguments
468    /// * `cdn_host` - CDN hostname
469    /// * `path` - Path prefix for the CDN
470    /// * `hash` - File hash
471    /// * `ranges` - Multiple byte ranges to download
472    ///
473    /// # Note
474    /// Multi-range requests return multipart/byteranges content type that needs
475    /// special parsing. Use with caution - not all CDN servers support this.
476    pub async fn download_file_multirange(
477        &self,
478        cdn_host: &str,
479        path: &str,
480        hash: &str,
481        ranges: &[(u64, Option<u64>)],
482    ) -> Result<Response> {
483        let url = format!(
484            "http://{}/{}/{}/{}/{}",
485            cdn_host,
486            path,
487            &hash[0..2],
488            &hash[2..4],
489            hash
490        );
491
492        // Build multi-range header value
493        let mut range_specs = Vec::new();
494        for &(start, end) in ranges {
495            match end {
496                Some(end) => range_specs.push(format!("{}-{}", start, end)),
497                None => range_specs.push(format!("{}-", start)),
498            }
499        }
500        let range_header = format!("bytes={}", range_specs.join(", "));
501
502        debug!("Multi-range request: {} Range: {}", url, range_header);
503
504        let response = self
505            .execute_with_retry_and_headers(&url, &[("Range", &range_header)])
506            .await?;
507
508        if response.status() == reqwest::StatusCode::NOT_FOUND {
509            return Err(Error::file_not_found(hash));
510        }
511
512        Ok(response)
513    }
514
515    /// Get parsed versions manifest for a product
516    pub async fn get_versions_parsed(&self, product: &str) -> Result<Vec<VersionEntry>> {
517        let response = self.get_versions(product).await?;
518        let text = response.text().await?;
519        response_types::parse_versions(&text)
520    }
521
522    /// Get parsed CDN manifest for a product
523    pub async fn get_cdns_parsed(&self, product: &str) -> Result<Vec<CdnEntry>> {
524        let response = self.get_cdns(product).await?;
525        let text = response.text().await?;
526        response_types::parse_cdns(&text)
527    }
528
529    /// Get parsed BGDL manifest for a product
530    pub async fn get_bgdl_parsed(&self, product: &str) -> Result<Vec<response_types::BgdlEntry>> {
531        let response = self.get_bgdl(product).await?;
532        let text = response.text().await?;
533        response_types::parse_bgdl(&text)
534    }
535}
536
537impl Default for HttpClient {
538    fn default() -> Self {
539        Self::new(Region::US, ProtocolVersion::V2).expect("Failed to create default HTTP client")
540    }
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline client shared by most tests: US region, V1 protocol.
    fn us_v1() -> HttpClient {
        HttpClient::new(Region::US, ProtocolVersion::V1).unwrap()
    }

    #[test]
    fn test_base_url_v1() {
        let expectations = [
            (Region::US, "http://us.patch.battle.net:1119"),
            (Region::EU, "http://eu.patch.battle.net:1119"),
        ];
        for &(region, expected) in &expectations {
            let client = HttpClient::new(region, ProtocolVersion::V1).unwrap();
            assert_eq!(client.base_url(), expected);
        }
    }

    #[test]
    fn test_base_url_v2() {
        let expectations = [
            (Region::US, "https://us.version.battle.net/v2/products"),
            (Region::EU, "https://eu.version.battle.net/v2/products"),
        ];
        for &(region, expected) in &expectations {
            let client = HttpClient::new(region, ProtocolVersion::V2).unwrap();
            assert_eq!(client.base_url(), expected);
        }
    }

    #[test]
    fn test_region_setting() {
        let mut client = us_v1();
        assert_eq!(client.region(), Region::US);

        client.set_region(Region::EU);
        assert_eq!(client.region(), Region::EU);
        assert_eq!(client.base_url(), "http://eu.patch.battle.net:1119");
    }

    #[test]
    fn test_retry_configuration() {
        let configured = us_v1()
            .with_max_retries(3)
            .with_initial_backoff_ms(200)
            .with_max_backoff_ms(5000)
            .with_backoff_multiplier(1.5)
            .with_jitter_factor(0.2);

        assert_eq!(configured.max_retries, 3);
        assert_eq!(configured.initial_backoff_ms, 200);
        assert_eq!(configured.max_backoff_ms, 5000);
        assert_eq!(configured.backoff_multiplier, 1.5);
        assert_eq!(configured.jitter_factor, 0.2);
    }

    #[test]
    fn test_jitter_factor_clamping() {
        // (requested, stored): out-of-range values are clamped into 0.0..=1.0.
        for &(requested, stored) in &[(1.5, 1.0), (-0.5, 0.0)] {
            let client = us_v1().with_jitter_factor(requested);
            assert_eq!(client.jitter_factor, stored);
        }
    }

    #[test]
    fn test_backoff_calculation() {
        let client = us_v1()
            .with_initial_backoff_ms(100)
            .with_max_backoff_ms(1000)
            .with_backoff_multiplier(2.0)
            .with_jitter_factor(0.0); // zero jitter keeps the delays deterministic

        // (attempt, expected ms): 100ms * 2^attempt, capped at 1000ms
        // (attempt 5 would be 3200ms without the cap).
        let expectations = [(0, 100), (1, 200), (2, 400), (5, 1000)];
        for &(attempt, expected_ms) in &expectations {
            assert_eq!(client.calculate_backoff(attempt).as_millis(), expected_ms);
        }
    }

    #[test]
    fn test_default_retry_configuration() {
        // Retries must stay off by default for backward compatibility.
        assert_eq!(us_v1().max_retries, 0);
    }

    #[test]
    fn test_user_agent_configuration() {
        let client = us_v1().with_user_agent("MyCustomAgent/1.0");
        assert_eq!(client.user_agent, Some("MyCustomAgent/1.0".to_string()));
    }

    #[test]
    fn test_user_agent_default_none() {
        assert!(us_v1().user_agent.is_none());
    }

    // Range request tests
    #[test]
    fn test_range_request_header_formatting() {
        fn render(range: (u64, Option<u64>)) -> String {
            let (start, end) = range;
            if let Some(end) = end {
                format!("bytes={}-{}", start, end)
            } else {
                format!("bytes={}-", start)
            }
        }

        assert_eq!(render((0, Some(1023))), "bytes=0-1023");
        assert_eq!(render((1024, None)), "bytes=1024-");
    }

    #[test]
    fn test_multirange_header_building() {
        let ranges = [(0, Some(31)), (64, Some(95)), (128, None)];
        let joined = ranges
            .iter()
            .map(|&(start, end)| match end {
                Some(end) => format!("{}-{}", start, end),
                None => format!("{}-", start),
            })
            .collect::<Vec<_>>()
            .join(", ");

        let range_header = format!("bytes={}", joined);
        assert_eq!(range_header, "bytes=0-31, 64-95, 128-");
    }
}