1use std::sync::atomic::{AtomicBool, Ordering};
32use std::sync::Arc;
33use std::time::Duration;
34
35use bytes::Bytes;
36use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE, USER_AGENT};
37use reqwest::multipart::{Form, Part};
38use reqwest::Client;
39use tracing::{debug, warn};
40
41use crate::error::{Result, TelegramError};
42
43use super::base::{async_trait, BaseRequest, HttpMethod, TimeoutOverride};
44use super::request_data::RequestData;
45
46pub const USER_AGENT_STRING: &str = concat!(
52 "rust-telegram-bot/",
53 env!("CARGO_PKG_VERSION"),
54 " (https://github.com/nicegram/rust-telegram-bot)"
55);
56
/// Default cap on idle connections kept per host by the underlying client
/// (forwarded to `pool_max_idle_per_host`).
pub const DEFAULT_CONNECTION_POOL_SIZE: usize = 256;

/// Default read timeout. reqwest only offers a single per-request timeout, so the
/// larger of the resolved read and write timeouts is what gets applied.
pub const DEFAULT_READ_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Default write timeout for requests that do not upload files.
pub const DEFAULT_WRITE_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Default client-wide connect timeout.
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Default pool timeout; the builder forwards it to the client's `pool_idle_timeout`.
pub const DEFAULT_POOL_TIMEOUT: Duration = Duration::from_millis(1_000);

/// Default write timeout used instead of [`DEFAULT_WRITE_TIMEOUT`] when a request
/// carries files.
pub const DEFAULT_MEDIA_WRITE_TIMEOUT: Duration = Duration::from_millis(20_000);
70
/// Builder for [`ReqwestRequest`].
///
/// Obtain one with [`ReqwestRequest::builder`] (or `Default::default()`), adjust it
/// with the chained setters, then call `build`. Every setter consumes and returns the
/// builder, so a builder that is dropped without `build` has no effect; hence
/// `#[must_use]`.
#[derive(Debug, Clone)]
#[must_use = "a ReqwestRequestBuilder does nothing until `build()` is called"]
pub struct ReqwestRequestBuilder {
    /// Maximum idle connections kept per host (`pool_max_idle_per_host`).
    connection_pool_size: usize,
    /// Default read timeout; `None` leaves it unset.
    read_timeout: Option<Duration>,
    /// Default write timeout for requests without file uploads; `None` leaves it unset.
    write_timeout: Option<Duration>,
    /// Client-wide connect timeout; `None` leaves reqwest's own default in place.
    connect_timeout: Option<Duration>,
    /// Pool timeout; forwarded to the client's `pool_idle_timeout`.
    pool_timeout: Option<Duration>,
    /// Write timeout used instead of `write_timeout` when the request carries files.
    media_write_timeout: Option<Duration>,
    /// Optional proxy URL applied to all traffic via `reqwest::Proxy::all`.
    proxy: Option<String>,
}
97
98impl Default for ReqwestRequestBuilder {
99 fn default() -> Self {
100 Self {
101 connection_pool_size: DEFAULT_CONNECTION_POOL_SIZE,
102 read_timeout: Some(DEFAULT_READ_TIMEOUT),
103 write_timeout: Some(DEFAULT_WRITE_TIMEOUT),
104 connect_timeout: Some(DEFAULT_CONNECT_TIMEOUT),
105 pool_timeout: Some(DEFAULT_POOL_TIMEOUT),
106 media_write_timeout: Some(DEFAULT_MEDIA_WRITE_TIMEOUT),
107 proxy: None,
108 }
109 }
110}
111
112impl ReqwestRequestBuilder {
113 pub fn connection_pool_size(mut self, size: usize) -> Self {
115 self.connection_pool_size = size;
116 self
117 }
118
119 pub fn read_timeout(mut self, t: Option<Duration>) -> Self {
121 self.read_timeout = t;
122 self
123 }
124
125 pub fn write_timeout(mut self, t: Option<Duration>) -> Self {
127 self.write_timeout = t;
128 self
129 }
130
131 pub fn connect_timeout(mut self, t: Option<Duration>) -> Self {
133 self.connect_timeout = t;
134 self
135 }
136
137 pub fn pool_timeout(mut self, t: Option<Duration>) -> Self {
139 self.pool_timeout = t;
140 self
141 }
142
143 pub fn media_write_timeout(mut self, t: Option<Duration>) -> Self {
145 self.media_write_timeout = t;
146 self
147 }
148
149 pub fn proxy(mut self, url: impl Into<String>) -> Self {
153 self.proxy = Some(url.into());
154 self
155 }
156
157 pub fn build(self) -> std::result::Result<ReqwestRequest, reqwest::Error> {
159 let headers = {
160 let mut h = HeaderMap::new();
161 h.insert(USER_AGENT, HeaderValue::from_static(USER_AGENT_STRING));
162 h
163 };
164
165 let client = build_client(
166 self.connection_pool_size,
167 self.connect_timeout,
168 self.pool_timeout,
169 headers,
170 self.proxy.as_deref(),
171 )?;
172
173 Ok(ReqwestRequest {
174 client,
175 defaults: Arc::new(DefaultTimeouts {
176 read: self.read_timeout,
177 write: self.write_timeout,
178 connect: self.connect_timeout,
179 pool: self.pool_timeout,
180 media_write: self.media_write_timeout,
181 }),
182 initialized: Arc::new(AtomicBool::new(false)),
183 })
184 }
185}
186
187fn build_client(
192 pool_size: usize,
193 connect_timeout: Option<Duration>,
194 pool_idle_timeout: Option<Duration>,
195 default_headers: HeaderMap,
196 proxy_url: Option<&str>,
197) -> std::result::Result<Client, reqwest::Error> {
198 let mut builder = Client::builder()
199 .default_headers(default_headers)
200 .pool_max_idle_per_host(pool_size);
201
202 if let Some(ct) = connect_timeout {
203 builder = builder.connect_timeout(ct);
204 }
205
206 if let Some(pit) = pool_idle_timeout {
207 builder = builder.pool_idle_timeout(pit);
208 }
209
210 if let Some(url) = proxy_url {
212 let proxy = reqwest::Proxy::all(url)?;
213 builder = builder.proxy(proxy);
214 }
215
216 builder.build()
217}
218
219#[derive(Debug, Clone, Copy)]
224struct DefaultTimeouts {
225 read: Option<Duration>,
226 write: Option<Duration>,
227 connect: Option<Duration>,
228 pool: Option<Duration>,
229 media_write: Option<Duration>,
230}
231
232#[derive(Debug, Clone, Copy)]
234struct ResolvedTimeouts {
235 read: Option<Duration>,
236 write: Option<Duration>,
237 #[allow(dead_code)]
240 connect: Option<Duration>,
241 #[allow(dead_code)]
243 pool: Option<Duration>,
244}
245
246impl DefaultTimeouts {
247 fn resolve(&self, override_: TimeoutOverride, has_files: bool) -> ResolvedTimeouts {
253 let write = match override_.write {
254 Some(v) => v,
255 None => {
256 if has_files {
257 self.media_write
258 } else {
259 self.write
260 }
261 }
262 };
263
264 ResolvedTimeouts {
265 read: override_.read.unwrap_or(self.read),
266 write,
267 connect: override_.connect.unwrap_or(self.connect),
268 pool: override_.pool.unwrap_or(self.pool),
269 }
270 }
271}
272
273#[derive(Clone)]
284pub struct ReqwestRequest {
285 client: Client,
287 defaults: Arc<DefaultTimeouts>,
289 initialized: Arc<AtomicBool>,
291}
292
293impl std::fmt::Debug for ReqwestRequest {
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 f.debug_struct("ReqwestRequest")
296 .field("defaults", &self.defaults)
297 .field("initialized", &self.initialized.load(Ordering::Relaxed))
298 .finish_non_exhaustive()
299 }
300}
301
302impl ReqwestRequest {
303 pub fn builder() -> ReqwestRequestBuilder {
305 ReqwestRequestBuilder::default()
306 }
307
308 pub fn new() -> std::result::Result<Self, reqwest::Error> {
310 Self::builder().build()
311 }
312
313 pub fn is_initialized(&self) -> bool {
315 self.initialized.load(Ordering::Relaxed)
316 }
317}
318
319#[async_trait]
324impl BaseRequest for ReqwestRequest {
325 async fn initialize(&self) -> Result<()> {
326 self.initialized.store(true, Ordering::Relaxed);
330 debug!("ReqwestRequest initialised");
331 Ok(())
332 }
333
334 async fn shutdown(&self) -> Result<()> {
335 if !self.initialized.load(Ordering::Relaxed) {
336 debug!("ReqwestRequest.shutdown called but already shut down — returning");
337 return Ok(());
338 }
339 self.initialized.store(false, Ordering::Relaxed);
342 debug!("ReqwestRequest shut down");
343 Ok(())
344 }
345
346 fn default_read_timeout(&self) -> Option<Duration> {
347 self.defaults.read
348 }
349
350 async fn do_request(
351 &self,
352 url: &str,
353 method: HttpMethod,
354 request_data: Option<&RequestData>,
355 timeouts: TimeoutOverride,
356 ) -> Result<(u16, Bytes)> {
357 let has_files = request_data.is_some_and(RequestData::contains_files);
358 let resolved = self.defaults.resolve(timeouts, has_files);
359
360 let mut req_builder = match method {
362 HttpMethod::Post => self.client.post(url),
363 HttpMethod::Get => self.client.get(url),
364 };
365
366 let effective_timeout = max_duration(resolved.read, resolved.write);
369 if let Some(t) = effective_timeout {
370 req_builder = req_builder.timeout(t);
371 }
372
373 req_builder = match request_data {
375 None => req_builder,
376 Some(data) if data.contains_files() => {
377 let form = build_multipart_form(data)?;
378 req_builder.multipart(form)
379 }
380 Some(data) => {
381 let params = data.json_parameters();
384 req_builder.form(¶ms)
385 }
386 };
387
388 let response = req_builder.send().await.map_err(map_reqwest_error)?;
390
391 let status = response.status().as_u16();
392 let body = response
393 .bytes()
394 .await
395 .map_err(|e| TelegramError::Network(format!("Failed to read response body: {e}")))?;
396
397 Ok((status, body))
398 }
399
400 async fn do_request_json_bytes(
401 &self,
402 url: &str,
403 body: &[u8],
404 timeouts: TimeoutOverride,
405 ) -> Result<(u16, Bytes)> {
406 let resolved = self.defaults.resolve(timeouts, false);
408
409 let mut req_builder = self
410 .client
411 .post(url)
412 .header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
413 .body(body.to_vec());
414
415 let effective_timeout = max_duration(resolved.read, resolved.write);
416 if let Some(t) = effective_timeout {
417 req_builder = req_builder.timeout(t);
418 }
419
420 let response = req_builder.send().await.map_err(map_reqwest_error)?;
421
422 let status = response.status().as_u16();
423 let resp_body = response
424 .bytes()
425 .await
426 .map_err(|e| TelegramError::Network(format!("Failed to read response body: {e}")))?;
427
428 Ok((status, resp_body))
429 }
430}
431
432fn map_reqwest_error(e: reqwest::Error) -> TelegramError {
442 if e.is_timeout() || e.is_connect() {
443 let msg = if e.is_timeout() {
447 format!("Request timed out: {e}")
448 } else {
449 format!("Connection error: {e}")
450 };
451 warn!("{msg}");
452 TelegramError::TimedOut(msg)
453 } else {
454 let msg = format!("reqwest error: {e}");
455 warn!("{msg}");
456 TelegramError::Network(msg)
457 }
458}
459
460fn build_multipart_form(data: &RequestData) -> Result<Form> {
462 let parts = data
463 .multipart_data()
464 .expect("called only when contains_files() is true");
465
466 let mut form = Form::new();
467
468 for (part_name, multipart_part) in &parts {
470 let bytes = multipart_part.bytes.clone();
471 let mut part = Part::bytes(bytes)
472 .mime_str(&multipart_part.mime_type)
473 .map_err(|e| {
474 TelegramError::Network(format!(
475 "Invalid MIME type '{}': {e}",
476 multipart_part.mime_type
477 ))
478 })?;
479
480 if let Some(ref fname) = multipart_part.file_name {
481 part = part.file_name(fname.clone());
482 }
483
484 form = form.part(part_name.clone(), part);
485 }
486
487 for (name, value) in data.json_parameters() {
489 form = form.text(name, value);
490 }
491
492 Ok(form)
493}
494
/// Returns the larger of two optional durations, treating `None` as "unset".
///
/// `Option`'s ordering places `None` below every `Some`. The result is therefore
/// `None` only when both inputs are `None`, is the lone `Some` when exactly one is
/// set, and is the larger value otherwise.
fn max_duration(a: Option<Duration>, b: Option<Duration>) -> Option<Duration> {
    a.max(b)
}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Timeout defaults shared by the `resolve_*` tests: 5 s read/write/connect,
    /// 1 s pool and 20 s media write.
    fn sample_defaults() -> DefaultTimeouts {
        DefaultTimeouts {
            read: Some(Duration::from_secs(5)),
            write: Some(Duration::from_secs(5)),
            connect: Some(Duration::from_secs(5)),
            pool: Some(Duration::from_secs(1)),
            media_write: Some(Duration::from_secs(20)),
        }
    }

    /// The all-defaults builder must produce a client.
    #[test]
    fn builder_defaults_produce_valid_client() {
        ReqwestRequest::new().expect("default client construction must succeed");
    }

    #[test]
    fn builder_custom_pool_size() {
        ReqwestRequest::builder()
            .connection_pool_size(4)
            .build()
            .expect("small pool size should be valid");
    }

    /// A well-formed proxy URL is accepted, and the result starts uninitialised.
    #[test]
    fn builder_with_proxy() {
        let req = ReqwestRequest::builder()
            .proxy("http://127.0.0.1:8080")
            .build()
            .expect("proxy builder should succeed");
        assert!(!req.is_initialized());
    }

    /// `initialize` flips the flag from `false` to `true`.
    #[tokio::test]
    async fn initialize_sets_initialized_flag() {
        let req = ReqwestRequest::new().unwrap();
        assert!(!req.is_initialized());
        req.initialize().await.unwrap();
        assert!(req.is_initialized());
    }

    #[tokio::test]
    async fn shutdown_clears_initialized_flag() {
        let req = ReqwestRequest::new().unwrap();
        req.initialize().await.unwrap();
        req.shutdown().await.unwrap();
        assert!(!req.is_initialized());
    }

    /// Repeated shutdowns without an `initialize` succeed and change nothing.
    #[tokio::test]
    async fn shutdown_idempotent() {
        let req = ReqwestRequest::new().unwrap();
        req.shutdown().await.unwrap();
        req.shutdown().await.unwrap();
        assert!(!req.is_initialized());
    }

    /// The `initialized` flag lives behind an `Arc`, so clones observe each other.
    #[tokio::test]
    async fn clones_share_initialized_flag() {
        let req = ReqwestRequest::new().unwrap();
        let twin = req.clone();
        twin.initialize().await.unwrap();
        assert!(req.is_initialized());
        req.shutdown().await.unwrap();
        assert!(!twin.is_initialized());
    }

    /// `default_read_timeout` reports the builder's read timeout.
    #[test]
    fn default_read_timeout_matches_builder() {
        let req = ReqwestRequest::builder()
            .read_timeout(Some(Duration::from_secs(99)))
            .build()
            .unwrap();
        assert_eq!(req.default_read_timeout(), Some(Duration::from_secs(99)));
    }

    #[test]
    fn default_read_timeout_none_when_unset() {
        let req = ReqwestRequest::builder()
            .read_timeout(None)
            .build()
            .unwrap();
        assert_eq!(req.default_read_timeout(), None);
    }

    /// `max_duration` treats `None` as "unset" and otherwise picks the larger value.
    #[test]
    fn max_duration_both_none() {
        assert_eq!(max_duration(None, None), None);
    }

    #[test]
    fn max_duration_left_some() {
        let d = Duration::from_secs(5);
        assert_eq!(max_duration(Some(d), None), Some(d));
    }

    #[test]
    fn max_duration_right_some() {
        let d = Duration::from_secs(3);
        assert_eq!(max_duration(None, Some(d)), Some(d));
    }

    #[test]
    fn max_duration_returns_larger() {
        let a = Duration::from_secs(5);
        let b = Duration::from_secs(20);
        assert_eq!(max_duration(Some(a), Some(b)), Some(b));
        assert_eq!(max_duration(Some(b), Some(a)), Some(b));
    }

    /// Without overrides, every field falls back to its default.
    #[test]
    fn resolve_uses_defaults_when_no_overrides() {
        let resolved = sample_defaults().resolve(TimeoutOverride::default_none(), false);
        assert_eq!(resolved.read, Some(Duration::from_secs(5)));
        assert_eq!(resolved.write, Some(Duration::from_secs(5)));
        assert_eq!(resolved.connect, Some(Duration::from_secs(5)));
        assert_eq!(resolved.pool, Some(Duration::from_secs(1)));
    }

    #[test]
    fn resolve_uses_media_write_timeout_when_has_files() {
        let resolved = sample_defaults().resolve(TimeoutOverride::default_none(), true);
        assert_eq!(resolved.write, Some(Duration::from_secs(20)));
    }

    #[test]
    fn resolve_caller_override_takes_precedence() {
        let overrides = TimeoutOverride {
            read: Some(Some(Duration::from_secs(30))),
            write: Some(None),
            ..TimeoutOverride::default_none()
        };
        let resolved = sample_defaults().resolve(overrides, false);
        assert_eq!(resolved.read, Some(Duration::from_secs(30)));
        assert_eq!(resolved.write, None);
    }

    #[test]
    fn resolve_explicit_none_overrides_media_timeout_even_with_files() {
        let overrides = TimeoutOverride {
            write: Some(None),
            ..TimeoutOverride::default_none()
        };
        let resolved = sample_defaults().resolve(overrides, true);
        assert_eq!(
            resolved.write, None,
            "explicit None must win over media_write"
        );
    }

    /// Connect and pool overrides follow the same two-level `Option` rules.
    #[test]
    fn resolve_connect_and_pool_follow_overrides() {
        let overrides = TimeoutOverride {
            connect: Some(Some(Duration::from_secs(9))),
            pool: Some(None),
            ..TimeoutOverride::default_none()
        };
        let resolved = sample_defaults().resolve(overrides, false);
        assert_eq!(resolved.connect, Some(Duration::from_secs(9)));
        assert_eq!(resolved.pool, None);
    }
}
691}