1use crate::{CdnEntry, Error, Region, Result, VersionEntry, response_types};
4use reqwest::{Client, Response};
5use std::time::Duration;
6use tokio::time::sleep;
7use tracing::{debug, trace, warn};
8
/// Retries performed after the first attempt unless configured otherwise (retrying is off).
const DEFAULT_MAX_RETRIES: u32 = 0;

/// Delay before the first retry, in milliseconds.
const DEFAULT_INITIAL_BACKOFF_MS: u64 = 100;

/// Cap on the exponential retry delay, in milliseconds (10 seconds).
const DEFAULT_MAX_BACKOFF_MS: u64 = 10 * 1_000;

/// Factor the retry delay grows by for each additional retry.
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;

/// Relative random jitter applied to each retry delay (±10%).
const DEFAULT_JITTER_FACTOR: f64 = 0.1;
23
/// Protocol variant an [`HttpClient`] speaks.
///
/// Selects the base URL used for requests and which endpoint methods are permitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProtocolVersion {
    /// Plain-HTTP endpoint: `http://{region}.patch.battle.net:1119`.
    V1,
    /// HTTPS endpoint: `https://{region}.version.battle.net`.
    V2,
}
32
33#[derive(Debug, Clone)]
35pub struct HttpClient {
36 client: Client,
37 region: Region,
38 version: ProtocolVersion,
39 max_retries: u32,
40 initial_backoff_ms: u64,
41 max_backoff_ms: u64,
42 backoff_multiplier: f64,
43 jitter_factor: f64,
44 user_agent: Option<String>,
45}
46
47impl HttpClient {
48 pub fn new(region: Region, version: ProtocolVersion) -> Result<Self> {
50 let client = Client::builder().timeout(Duration::from_secs(30)).build()?;
51
52 Ok(Self {
53 client,
54 region,
55 version,
56 max_retries: DEFAULT_MAX_RETRIES,
57 initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
58 max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
59 backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
60 jitter_factor: DEFAULT_JITTER_FACTOR,
61 user_agent: None,
62 })
63 }
64
65 pub fn with_shared_pool(region: Region, version: ProtocolVersion) -> Self {
70 let client = crate::pool::get_global_pool().clone();
71
72 Self {
73 client,
74 region,
75 version,
76 max_retries: DEFAULT_MAX_RETRIES,
77 initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
78 max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
79 backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
80 jitter_factor: DEFAULT_JITTER_FACTOR,
81 user_agent: None,
82 }
83 }
84
85 pub fn with_client(client: Client, region: Region, version: ProtocolVersion) -> Self {
87 Self {
88 client,
89 region,
90 version,
91 max_retries: DEFAULT_MAX_RETRIES,
92 initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
93 max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
94 backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
95 jitter_factor: DEFAULT_JITTER_FACTOR,
96 user_agent: None,
97 }
98 }
99
100 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
105 self.max_retries = max_retries;
106 self
107 }
108
109 pub fn with_initial_backoff_ms(mut self, initial_backoff_ms: u64) -> Self {
113 self.initial_backoff_ms = initial_backoff_ms;
114 self
115 }
116
117 pub fn with_max_backoff_ms(mut self, max_backoff_ms: u64) -> Self {
121 self.max_backoff_ms = max_backoff_ms;
122 self
123 }
124
125 pub fn with_backoff_multiplier(mut self, backoff_multiplier: f64) -> Self {
129 self.backoff_multiplier = backoff_multiplier;
130 self
131 }
132
133 pub fn with_jitter_factor(mut self, jitter_factor: f64) -> Self {
137 self.jitter_factor = jitter_factor.clamp(0.0, 1.0);
138 self
139 }
140
141 pub fn with_user_agent(mut self, user_agent: impl Into<String>) -> Self {
145 self.user_agent = Some(user_agent.into());
146 self
147 }
148
149 pub fn base_url(&self) -> String {
151 match self.version {
152 ProtocolVersion::V1 => {
153 format!("http://{}.patch.battle.net:1119", self.region)
154 }
155 ProtocolVersion::V2 => {
156 format!("https://{}.version.battle.net", self.region)
157 }
158 }
159 }
160
161 pub fn region(&self) -> Region {
163 self.region
164 }
165
166 pub fn version(&self) -> ProtocolVersion {
168 self.version
169 }
170
171 pub fn set_region(&mut self, region: Region) {
173 self.region = region;
174 }
175
176 #[allow(
178 clippy::cast_precision_loss,
179 clippy::cast_possible_wrap,
180 clippy::cast_possible_truncation,
181 clippy::cast_sign_loss
182 )]
183 fn calculate_backoff(&self, attempt: u32) -> Duration {
184 let base_backoff =
185 self.initial_backoff_ms as f64 * self.backoff_multiplier.powi(attempt as i32);
186 let capped_backoff = base_backoff.min(self.max_backoff_ms as f64);
187
188 let jitter_range = capped_backoff * self.jitter_factor;
190 let jitter = rand::random::<f64>() * 2.0 * jitter_range - jitter_range;
191 let final_backoff = (capped_backoff + jitter).max(0.0) as u64;
192
193 Duration::from_millis(final_backoff)
194 }
195
196 async fn execute_with_retry_internal(
198 &self,
199 url: &str,
200 headers: Option<&[(&str, &str)]>,
201 ) -> Result<Response> {
202 let mut last_error = None;
203
204 for attempt in 0..=self.max_retries {
205 if attempt > 0 {
206 let backoff = self.calculate_backoff(attempt - 1);
207 debug!("Retry attempt {} after {:?} backoff", attempt, backoff);
208 sleep(backoff).await;
209 }
210
211 debug!("HTTP request to {} (attempt {})", url, attempt + 1);
212
213 let mut request = self.client.get(url);
214 if let Some(ref user_agent) = self.user_agent {
215 request = request.header("User-Agent", user_agent);
216 }
217
218 if let Some(headers) = headers {
220 for &(key, value) in headers {
221 request = request.header(key, value);
222 }
223 }
224
225 match request.send().await {
226 Ok(response) => {
227 trace!("Response status: {}", response.status());
228
229 let status = response.status();
231 if (status.is_server_error()
232 || status == reqwest::StatusCode::TOO_MANY_REQUESTS)
233 && attempt < self.max_retries
234 {
235 warn!(
236 "Request returned {} (attempt {}): will retry",
237 status,
238 attempt + 1
239 );
240 last_error = Some(Error::InvalidResponse);
241 continue;
242 }
243
244 return Ok(response);
245 }
246 Err(e) => {
247 let is_retryable = e.is_connect() || e.is_timeout() || e.is_request();
249
250 if is_retryable && attempt < self.max_retries {
251 warn!(
252 "Request failed (attempt {}): {}, will retry",
253 attempt + 1,
254 e
255 );
256 last_error = Some(Error::Http(e));
257 } else {
258 debug!(
260 "Request failed (attempt {}): {}, not retrying",
261 attempt + 1,
262 e
263 );
264 return Err(Error::Http(e));
265 }
266 }
267 }
268 }
269
270 Err(last_error.unwrap_or(Error::InvalidResponse))
272 }
273
274 async fn execute_with_retry(&self, url: &str) -> Result<Response> {
276 self.execute_with_retry_internal(url, None).await
277 }
278
279 async fn execute_with_retry_and_headers(
281 &self,
282 url: &str,
283 headers: &[(&str, &str)],
284 ) -> Result<Response> {
285 self.execute_with_retry_internal(url, Some(headers)).await
286 }
287
288 pub async fn get_versions(&self, product: &str) -> Result<Response> {
290 if self.version != ProtocolVersion::V1 {
291 return Err(Error::InvalidProtocolVersion);
292 }
293
294 let url = format!("{}/{}/versions", self.base_url(), product);
295 self.execute_with_retry(&url).await
296 }
297
298 pub async fn get_cdns(&self, product: &str) -> Result<Response> {
300 if self.version != ProtocolVersion::V1 {
301 return Err(Error::InvalidProtocolVersion);
302 }
303
304 let url = format!("{}/{}/cdns", self.base_url(), product);
305 self.execute_with_retry(&url).await
306 }
307
308 pub async fn get_bgdl(&self, product: &str) -> Result<Response> {
310 if self.version != ProtocolVersion::V1 {
311 return Err(Error::InvalidProtocolVersion);
312 }
313
314 let url = format!("{}/{}/bgdl", self.base_url(), product);
315 self.execute_with_retry(&url).await
316 }
317
318 pub async fn get_summary(&self) -> Result<Response> {
320 if self.version != ProtocolVersion::V2 {
321 return Err(Error::InvalidProtocolVersion);
322 }
323
324 let url = self.base_url();
325 self.execute_with_retry(&url).await
326 }
327
328 pub async fn get_product(&self, product: &str) -> Result<Response> {
330 if self.version != ProtocolVersion::V2 {
331 return Err(Error::InvalidProtocolVersion);
332 }
333
334 let url = format!("{}/v2/products/{}", self.base_url(), product);
335 self.execute_with_retry(&url).await
336 }
337
338 pub async fn get_product_versions_http(&self, product: &str) -> Result<Response> {
341 if self.version != ProtocolVersion::V2 {
342 return Err(Error::InvalidProtocolVersion);
343 }
344
345 let url = format!("{}/{}/versions", self.base_url(), product);
346 debug!("Fetching product versions from HTTP endpoint: {}", url);
347 self.execute_with_retry(&url).await
348 }
349
350 pub async fn get_product_cdns_http(&self, product: &str) -> Result<Response> {
353 if self.version != ProtocolVersion::V2 {
354 return Err(Error::InvalidProtocolVersion);
355 }
356
357 let url = format!("{}/{}/cdns", self.base_url(), product);
358 debug!("Fetching CDN configuration from HTTP endpoint: {}", url);
359 self.execute_with_retry(&url).await
360 }
361
362 pub async fn get(&self, path: &str) -> Result<Response> {
364 let url = if path.starts_with('/') {
365 format!("{}{}", self.base_url(), path)
366 } else {
367 format!("{}/{}", self.base_url(), path)
368 };
369
370 self.execute_with_retry(&url).await
371 }
372
373 pub async fn download_file(&self, cdn_host: &str, path: &str, hash: &str) -> Result<Response> {
375 let url = format!(
376 "http://{}/{}/{}/{}/{}",
377 cdn_host,
378 path,
379 &hash[0..2],
380 &hash[2..4],
381 hash
382 );
383
384 let response = self.execute_with_retry(&url).await?;
386
387 if response.status() == reqwest::StatusCode::NOT_FOUND {
388 return Err(Error::file_not_found(hash));
389 }
390
391 Ok(response)
392 }
393
394 pub async fn download_file_range(
407 &self,
408 cdn_host: &str,
409 path: &str,
410 hash: &str,
411 range: (u64, Option<u64>),
412 ) -> Result<Response> {
413 let url = format!(
414 "http://{}/{}/{}/{}/{}",
415 cdn_host,
416 path,
417 &hash[0..2],
418 &hash[2..4],
419 hash
420 );
421
422 let range_header = match range {
424 (start, Some(end)) => format!("bytes={start}-{end}"),
425 (start, None) => format!("bytes={start}-"),
426 };
427
428 debug!("Range request: {} Range: {}", url, range_header);
429
430 let response = self
431 .execute_with_retry_and_headers(&url, &[("Range", &range_header)])
432 .await?;
433
434 if response.status() == reqwest::StatusCode::NOT_FOUND {
435 return Err(Error::file_not_found(hash));
436 }
437
438 match response.status() {
440 reqwest::StatusCode::PARTIAL_CONTENT => {
441 trace!("Server returned partial content (206)");
442 }
443 reqwest::StatusCode::OK => {
444 warn!("Server returned full content (200) - range requests not supported");
445 }
446 status => {
447 warn!(
448 "Unexpected status code for range request: {} (expected 206 or 200)",
449 status
450 );
451 }
453 }
454
455 Ok(response)
456 }
457
458 pub async fn download_file_multirange(
470 &self,
471 cdn_host: &str,
472 path: &str,
473 hash: &str,
474 ranges: &[(u64, Option<u64>)],
475 ) -> Result<Response> {
476 let url = format!(
477 "http://{}/{}/{}/{}/{}",
478 cdn_host,
479 path,
480 &hash[0..2],
481 &hash[2..4],
482 hash
483 );
484
485 let mut range_specs = Vec::new();
487 for &(start, end) in ranges {
488 match end {
489 Some(end) => range_specs.push(format!("{start}-{end}")),
490 None => range_specs.push(format!("{start}-")),
491 }
492 }
493 let range_header = format!("bytes={}", range_specs.join(", "));
494
495 debug!("Multi-range request: {} Range: {}", url, range_header);
496
497 let response = self
498 .execute_with_retry_and_headers(&url, &[("Range", &range_header)])
499 .await?;
500
501 if response.status() == reqwest::StatusCode::NOT_FOUND {
502 return Err(Error::file_not_found(hash));
503 }
504
505 Ok(response)
506 }
507
508 pub async fn get_versions_parsed(&self, product: &str) -> Result<Vec<VersionEntry>> {
510 let response = self.get_versions(product).await?;
511 let text = response.text().await?;
512 response_types::parse_versions(&text)
513 }
514
515 pub async fn get_cdns_parsed(&self, product: &str) -> Result<Vec<CdnEntry>> {
517 let response = self.get_cdns(product).await?;
518 let text = response.text().await?;
519 response_types::parse_cdns(&text)
520 }
521
522 pub async fn get_product_versions_http_parsed(
524 &self,
525 product: &str,
526 ) -> Result<Vec<VersionEntry>> {
527 let response = self.get_product_versions_http(product).await?;
528 let text = response.text().await?;
529 response_types::parse_versions(&text)
530 }
531
532 pub async fn get_product_cdns_http_parsed(&self, product: &str) -> Result<Vec<CdnEntry>> {
534 let response = self.get_product_cdns_http(product).await?;
535 let text = response.text().await?;
536 response_types::parse_cdns(&text)
537 }
538
539 pub async fn get_bgdl_parsed(&self, product: &str) -> Result<Vec<response_types::BgdlEntry>> {
541 let response = self.get_bgdl(product).await?;
542 let text = response.text().await?;
543 response_types::parse_bgdl(&text)
544 }
545}
546
547impl Default for HttpClient {
548 fn default() -> Self {
549 Self::new(Region::US, ProtocolVersion::V2).expect("Failed to create default HTTP client")
550 }
551}
552
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client with default settings, panicking if construction fails.
    fn make_client(region: Region, version: ProtocolVersion) -> HttpClient {
        HttpClient::new(region, version).unwrap()
    }

    /// Local copy of the range-spec formatting used when building `Range` headers.
    fn spec(start: u64, end: Option<u64>) -> String {
        end.map_or_else(|| format!("{start}-"), |end| format!("{start}-{end}"))
    }

    #[test]
    fn test_base_url_v1() {
        let cases = [
            (Region::US, "http://us.patch.battle.net:1119"),
            (Region::EU, "http://eu.patch.battle.net:1119"),
        ];
        for (region, expected) in cases {
            assert_eq!(make_client(region, ProtocolVersion::V1).base_url(), expected);
        }
    }

    #[test]
    fn test_base_url_v2() {
        let cases = [
            (Region::US, "https://us.version.battle.net"),
            (Region::EU, "https://eu.version.battle.net"),
        ];
        for (region, expected) in cases {
            assert_eq!(make_client(region, ProtocolVersion::V2).base_url(), expected);
        }
    }

    #[test]
    fn test_region_setting() {
        let mut client = make_client(Region::US, ProtocolVersion::V1);
        assert_eq!(client.region(), Region::US);

        client.set_region(Region::EU);
        assert_eq!(client.region(), Region::EU);
        assert_eq!(client.base_url(), "http://eu.patch.battle.net:1119");
    }

    #[test]
    fn test_retry_configuration() {
        let client = make_client(Region::US, ProtocolVersion::V1)
            .with_max_retries(3)
            .with_initial_backoff_ms(200)
            .with_max_backoff_ms(5000)
            .with_backoff_multiplier(1.5)
            .with_jitter_factor(0.2);

        assert_eq!(client.max_retries, 3);
        assert_eq!(client.initial_backoff_ms, 200);
        assert_eq!(client.max_backoff_ms, 5000);
        assert_eq!(client.backoff_multiplier, 1.5);
        assert_eq!(client.jitter_factor, 0.2);
    }

    #[test]
    fn test_jitter_factor_clamping() {
        // Out-of-range factors are clamped into [0.0, 1.0].
        for (requested, expected) in [(1.5, 1.0), (-0.5, 0.0)] {
            let client =
                make_client(Region::US, ProtocolVersion::V1).with_jitter_factor(requested);
            assert_eq!(client.jitter_factor, expected);
        }
    }

    #[test]
    fn test_backoff_calculation() {
        let client = make_client(Region::US, ProtocolVersion::V1)
            .with_initial_backoff_ms(100)
            .with_max_backoff_ms(1000)
            .with_backoff_multiplier(2.0)
            .with_jitter_factor(0.0);

        // Doubles from 100 ms until the 1000 ms cap applies.
        for (attempt, expected_ms) in [(0, 100), (1, 200), (2, 400), (5, 1000)] {
            assert_eq!(client.calculate_backoff(attempt).as_millis(), expected_ms);
        }
    }

    #[test]
    fn test_default_retry_configuration() {
        let client = make_client(Region::US, ProtocolVersion::V1);
        assert_eq!(client.max_retries, 0);
    }

    #[test]
    fn test_user_agent_configuration() {
        let client =
            make_client(Region::US, ProtocolVersion::V1).with_user_agent("MyCustomAgent/1.0");

        assert_eq!(client.user_agent, Some("MyCustomAgent/1.0".to_string()));
    }

    #[test]
    fn test_user_agent_default_none() {
        let client = make_client(Region::US, ProtocolVersion::V1);
        assert!(client.user_agent.is_none());
    }

    #[test]
    fn test_range_request_header_formatting() {
        assert_eq!(format!("bytes={}", spec(0, Some(1023))), "bytes=0-1023");
        assert_eq!(format!("bytes={}", spec(1024, None)), "bytes=1024-");
    }

    #[test]
    fn test_multirange_header_building() {
        let ranges = [(0, Some(31)), (64, Some(95)), (128, None)];
        let specs: Vec<String> = ranges.iter().map(|&(start, end)| spec(start, end)).collect();

        assert_eq!(format!("bytes={}", specs.join(", ")), "bytes=0-31, 64-95, 128-");
    }
}