Skip to main content

rust_tg_bot_raw/request/
reqwest_impl.rs

1//! [`reqwest`]-backed implementation of [`BaseRequest`].
2//!
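//! This mirrors `telegram.request.HTTPXRequest` from python-telegram-bot, which
//! uses `httpx` as its HTTP back-end.  Here we use `reqwest` instead and keep the
//! public contract and behaviour as close to the original as `reqwest` allows;
//! the timeout mapping below documents where the two necessarily differ.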
6//!
7//! # Single-client design
8//!
9//! A single `reqwest::Client` is used for both API calls and file downloads.
10//! File downloads use a longer per-request timeout via
11//! `reqwest::RequestBuilder::timeout` to avoid interfering with
12//! time-sensitive API calls while sharing the same connection pool.
13//!
14//! # Timeouts
15//!
16//! `reqwest` supports per-request timeouts via
17//! `reqwest::RequestBuilder::timeout`.  Since it does not expose separate
18//! connect / read / write / pool timeouts the way `httpx` does, we use the
19//! following mapping:
20//!
21//! | Python timeout | reqwest mapping |
22//! |---|---|
23//! | `connect_timeout` | `reqwest::ClientBuilder::connect_timeout` |
24//! | `read_timeout` | per-request `timeout` (overall) |
25//! | `write_timeout` | per-request `timeout` (overall) |
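//! | `pool_timeout` | `reqwest::ClientBuilder::pool_idle_timeout` (how long idle keep-alive connections are retained; reqwest has no "wait for a free connection" timeout) |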
27//!
28//! The effective per-request timeout is `max(read_timeout, write_timeout)` so
29//! that neither receiving nor sending is cut short prematurely.
30
31use std::sync::atomic::{AtomicBool, Ordering};
32use std::sync::Arc;
33use std::time::Duration;
34
35use bytes::Bytes;
36use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE, USER_AGENT};
37use reqwest::multipart::{Form, Part};
38use reqwest::Client;
39use tracing::{debug, warn};
40
41use crate::error::{Result, TelegramError};
42
43use super::base::{async_trait, BaseRequest, HttpMethod, TimeoutOverride};
44use super::request_data::RequestData;
45
46// ---------------------------------------------------------------------------
47// Public constants
48// ---------------------------------------------------------------------------
49
50/// User-agent header value sent with every request.
51pub const USER_AGENT_STRING: &str = concat!(
52    "rust-telegram-bot/",
53    env!("CARGO_PKG_VERSION"),
54    " (https://github.com/nicegram/rust-telegram-bot)"
55);
56
57/// Default connection pool size (matches `HTTPXRequest` default of 256).
58pub const DEFAULT_CONNECTION_POOL_SIZE: usize = 256;
59
60/// Default read timeout — 5 seconds.
61pub const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5);
62/// Default write timeout — 5 seconds.
63pub const DEFAULT_WRITE_TIMEOUT: Duration = Duration::from_secs(5);
64/// Default connect timeout — 5 seconds.
65pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
66/// Default pool idle-connection timeout — 1 second.
67pub const DEFAULT_POOL_TIMEOUT: Duration = Duration::from_secs(1);
68/// Default media (large-file upload) write timeout — 20 seconds.
69pub const DEFAULT_MEDIA_WRITE_TIMEOUT: Duration = Duration::from_secs(20);
70
71// ---------------------------------------------------------------------------
72// Builder
73// ---------------------------------------------------------------------------
74
/// Fluent configuration for a [`ReqwestRequest`].
///
/// Obtain one from [`ReqwestRequest::builder`], override only the settings you
/// need, then call `build()`.
///
/// ```rust,no_run
/// # use rust_tg_bot_raw::request::reqwest_impl::ReqwestRequest;
/// # use std::time::Duration;
/// let req = ReqwestRequest::builder()
///     .connection_pool_size(128)
///     .read_timeout(Some(Duration::from_secs(10)))
///     .build()
///     .expect("valid configuration");
/// ```
#[derive(Debug, Clone)]
pub struct ReqwestRequestBuilder {
    /// Idle connections retained per host.
    connection_pool_size: usize,
    /// Default read limit; `None` disables it.
    read_timeout: Option<Duration>,
    /// Default write limit for requests without files; `None` disables it.
    write_timeout: Option<Duration>,
    /// Connect limit applied when the client is built; `None` disables it.
    connect_timeout: Option<Duration>,
    /// Idle keep-alive limit for pooled connections.
    pool_timeout: Option<Duration>,
    /// Write limit used instead of `write_timeout` when files are attached.
    media_write_timeout: Option<Duration>,
    /// Optional proxy URL routed through for every request.
    proxy: Option<String>,
}
97
98impl Default for ReqwestRequestBuilder {
99    fn default() -> Self {
100        Self {
101            connection_pool_size: DEFAULT_CONNECTION_POOL_SIZE,
102            read_timeout: Some(DEFAULT_READ_TIMEOUT),
103            write_timeout: Some(DEFAULT_WRITE_TIMEOUT),
104            connect_timeout: Some(DEFAULT_CONNECT_TIMEOUT),
105            pool_timeout: Some(DEFAULT_POOL_TIMEOUT),
106            media_write_timeout: Some(DEFAULT_MEDIA_WRITE_TIMEOUT),
107            proxy: None,
108        }
109    }
110}
111
112impl ReqwestRequestBuilder {
113    /// Maximum number of idle connections kept in the pool per host.
114    pub fn connection_pool_size(mut self, size: usize) -> Self {
115        self.connection_pool_size = size;
116        self
117    }
118
119    /// Default read timeout (`None` = wait forever).
120    pub fn read_timeout(mut self, t: Option<Duration>) -> Self {
121        self.read_timeout = t;
122        self
123    }
124
125    /// Default write timeout (`None` = wait forever).
126    pub fn write_timeout(mut self, t: Option<Duration>) -> Self {
127        self.write_timeout = t;
128        self
129    }
130
131    /// Default connect timeout (`None` = wait forever).
132    pub fn connect_timeout(mut self, t: Option<Duration>) -> Self {
133        self.connect_timeout = t;
134        self
135    }
136
137    /// Default pool (idle-connection) timeout (`None` = wait forever).
138    pub fn pool_timeout(mut self, t: Option<Duration>) -> Self {
139        self.pool_timeout = t;
140        self
141    }
142
143    /// Write timeout used for large file uploads (`None` = wait forever).
144    pub fn media_write_timeout(mut self, t: Option<Duration>) -> Self {
145        self.media_write_timeout = t;
146        self
147    }
148
149    /// Set a proxy URL (e.g. `socks5://127.0.0.1:1080` or
150    /// `http://proxy.example.com:8080`).  The proxy is applied to both the API
151    /// client and the file-download client.
152    pub fn proxy(mut self, url: impl Into<String>) -> Self {
153        self.proxy = Some(url.into());
154        self
155    }
156
157    /// Consume the builder and produce a [`ReqwestRequest`].
158    pub fn build(self) -> std::result::Result<ReqwestRequest, reqwest::Error> {
159        let headers = {
160            let mut h = HeaderMap::new();
161            h.insert(USER_AGENT, HeaderValue::from_static(USER_AGENT_STRING));
162            h
163        };
164
165        let client = build_client(
166            self.connection_pool_size,
167            self.connect_timeout,
168            self.pool_timeout,
169            headers,
170            self.proxy.as_deref(),
171        )?;
172
173        Ok(ReqwestRequest {
174            client,
175            defaults: Arc::new(DefaultTimeouts {
176                read: self.read_timeout,
177                write: self.write_timeout,
178                connect: self.connect_timeout,
179                pool: self.pool_timeout,
180                media_write: self.media_write_timeout,
181            }),
182            initialized: Arc::new(AtomicBool::new(false)),
183        })
184    }
185}
186
187// ---------------------------------------------------------------------------
188// Internal helper: build a single reqwest::Client
189// ---------------------------------------------------------------------------
190
191fn build_client(
192    pool_size: usize,
193    connect_timeout: Option<Duration>,
194    pool_idle_timeout: Option<Duration>,
195    default_headers: HeaderMap,
196    proxy_url: Option<&str>,
197) -> std::result::Result<Client, reqwest::Error> {
198    let mut builder = Client::builder()
199        .default_headers(default_headers)
200        .pool_max_idle_per_host(pool_size);
201
202    if let Some(ct) = connect_timeout {
203        builder = builder.connect_timeout(ct);
204    }
205
206    if let Some(pit) = pool_idle_timeout {
207        builder = builder.pool_idle_timeout(pit);
208    }
209
210    // M6: apply proxy if configured.
211    if let Some(url) = proxy_url {
212        let proxy = reqwest::Proxy::all(url)?;
213        builder = builder.proxy(proxy);
214    }
215
216    builder.build()
217}
218
219// ---------------------------------------------------------------------------
220// Default timeouts store
221// ---------------------------------------------------------------------------
222
/// Timeout configuration captured from the builder.
///
/// Each `None` means "no limit" for that phase.
#[derive(Debug, Clone, Copy)]
struct DefaultTimeouts {
    /// Limit for receiving the response.
    read: Option<Duration>,
    /// Limit for sending requests that carry no files.
    write: Option<Duration>,
    /// Connect limit; mirrors the value given to the client at construction.
    connect: Option<Duration>,
    /// Idle keep-alive limit; mirrors the value given to the client.
    pool: Option<Duration>,
    /// Limit for sending requests that carry files.
    media_write: Option<Duration>,
}
231
/// Timeouts in effect for one request: caller overrides layered over the
/// builder defaults.
#[derive(Debug, Clone, Copy)]
struct ResolvedTimeouts {
    /// Effective read limit (`None` = unlimited).
    read: Option<Duration>,
    /// Effective write limit (`None` = unlimited).
    write: Option<Duration>,
    /// Resolved for completeness only. reqwest fixes the connect timeout when
    /// the client is built, so nothing reads this per request.
    #[allow(dead_code)]
    connect: Option<Duration>,
    /// Resolved for completeness only. reqwest fixes the pool setting when the
    /// client is built, so nothing reads this per request.
    #[allow(dead_code)]
    pool: Option<Duration>,
}
245
246impl DefaultTimeouts {
247    /// Resolve caller overrides against these defaults.
248    ///
249    /// - `Some(Some(d))` = caller explicitly set `d`.
250    /// - `Some(None)` = caller explicitly set "no timeout".
251    /// - `None` = caller did not specify (use our default).
252    fn resolve(&self, override_: TimeoutOverride, has_files: bool) -> ResolvedTimeouts {
253        let write = match override_.write {
254            Some(v) => v,
255            None => {
256                if has_files {
257                    self.media_write
258                } else {
259                    self.write
260                }
261            }
262        };
263
264        ResolvedTimeouts {
265            read: override_.read.unwrap_or(self.read),
266            write,
267            connect: override_.connect.unwrap_or(self.connect),
268            pool: override_.pool.unwrap_or(self.pool),
269        }
270    }
271}
272
273// ---------------------------------------------------------------------------
274// ReqwestRequest
275// ---------------------------------------------------------------------------
276
277/// `reqwest`-backed implementation of [`BaseRequest`].
278///
279/// Construct via [`ReqwestRequest::builder()`] or [`ReqwestRequest::new()`]
280/// for sensible defaults.
281///
282/// This type is `Clone` — cloning shares the same underlying connection pools.
283#[derive(Clone)]
284pub struct ReqwestRequest {
285    /// Shared HTTP client used for both API calls and file downloads.
286    client: Client,
287    /// Default timeout configuration.
288    defaults: Arc<DefaultTimeouts>,
289    /// Whether `initialize()` has been called at least once.
290    initialized: Arc<AtomicBool>,
291}
292
293impl std::fmt::Debug for ReqwestRequest {
294    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        f.debug_struct("ReqwestRequest")
296            .field("defaults", &self.defaults)
297            .field("initialized", &self.initialized.load(Ordering::Relaxed))
298            .finish_non_exhaustive()
299    }
300}
301
302impl ReqwestRequest {
303    /// Create a builder to customise the client.
304    pub fn builder() -> ReqwestRequestBuilder {
305        ReqwestRequestBuilder::default()
306    }
307
308    /// Create a [`ReqwestRequest`] with all default settings.
309    pub fn new() -> std::result::Result<Self, reqwest::Error> {
310        Self::builder().build()
311    }
312
313    /// `true` after [`BaseRequest::initialize`] has been called.
314    pub fn is_initialized(&self) -> bool {
315        self.initialized.load(Ordering::Relaxed)
316    }
317}
318
319// ---------------------------------------------------------------------------
320// BaseRequest implementation
321// ---------------------------------------------------------------------------
322
323#[async_trait]
324impl BaseRequest for ReqwestRequest {
325    async fn initialize(&self) -> Result<()> {
326        // reqwest::Client is ready immediately after construction.  We record
327        // the fact so that callers who check `is_initialized()` get a
328        // meaningful answer.
329        self.initialized.store(true, Ordering::Relaxed);
330        debug!("ReqwestRequest initialised");
331        Ok(())
332    }
333
334    async fn shutdown(&self) -> Result<()> {
335        if !self.initialized.load(Ordering::Relaxed) {
336            debug!("ReqwestRequest.shutdown called but already shut down — returning");
337            return Ok(());
338        }
339        // reqwest manages its own connection-pool lifecycle; there is no
340        // explicit close call.  We just mark the instance as shut down.
341        self.initialized.store(false, Ordering::Relaxed);
342        debug!("ReqwestRequest shut down");
343        Ok(())
344    }
345
346    fn default_read_timeout(&self) -> Option<Duration> {
347        self.defaults.read
348    }
349
350    async fn do_request(
351        &self,
352        url: &str,
353        method: HttpMethod,
354        request_data: Option<&RequestData>,
355        timeouts: TimeoutOverride,
356    ) -> Result<(u16, Bytes)> {
357        let has_files = request_data.is_some_and(RequestData::contains_files);
358        let resolved = self.defaults.resolve(timeouts, has_files);
359
360        // Build the reqwest request.
361        let mut req_builder = match method {
362            HttpMethod::Post => self.client.post(url),
363            HttpMethod::Get => self.client.get(url),
364        };
365
366        // Apply the effective per-request timeout.
367        // We use the max of read and write so neither operation is cut short.
368        let effective_timeout = max_duration(resolved.read, resolved.write);
369        if let Some(t) = effective_timeout {
370            req_builder = req_builder.timeout(t);
371        }
372
373        // Attach body.
374        req_builder = match request_data {
375            None => req_builder,
376            Some(data) if data.contains_files() => {
377                let form = build_multipart_form(data)?;
378                req_builder.multipart(form)
379            }
380            Some(data) => {
381                // JSON parameters sent as `application/x-www-form-urlencoded`
382                // body — matches what httpx sends when `data=` is passed.
383                let params = data.json_parameters();
384                req_builder.form(&params)
385            }
386        };
387
388        // Execute.
389        let response = req_builder.send().await.map_err(map_reqwest_error)?;
390
391        let status = response.status().as_u16();
392        let body = response
393            .bytes()
394            .await
395            .map_err(|e| TelegramError::Network(format!("Failed to read response body: {e}")))?;
396
397        Ok((status, body))
398    }
399
400    async fn do_request_json_bytes(
401        &self,
402        url: &str,
403        body: &[u8],
404        timeouts: TimeoutOverride,
405    ) -> Result<(u16, Bytes)> {
406        // Text-only requests never carry files, so use normal write timeout.
407        let resolved = self.defaults.resolve(timeouts, false);
408
409        let mut req_builder = self
410            .client
411            .post(url)
412            .header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
413            .body(body.to_vec());
414
415        let effective_timeout = max_duration(resolved.read, resolved.write);
416        if let Some(t) = effective_timeout {
417            req_builder = req_builder.timeout(t);
418        }
419
420        let response = req_builder.send().await.map_err(map_reqwest_error)?;
421
422        let status = response.status().as_u16();
423        let resp_body = response
424            .bytes()
425            .await
426            .map_err(|e| TelegramError::Network(format!("Failed to read response body: {e}")))?;
427
428        Ok((status, resp_body))
429    }
430}
431
432// ---------------------------------------------------------------------------
433// Internal helpers
434// ---------------------------------------------------------------------------
435
436/// Convert a `reqwest::Error` into an appropriate [`TelegramError`].
437///
438/// Mirrors the httpx error-mapping in `HTTPXRequest.do_request`:
439/// - Timeouts → [`TelegramError::TimedOut`]
440/// - Everything else → [`TelegramError::Network`]
441fn map_reqwest_error(e: reqwest::Error) -> TelegramError {
442    if e.is_timeout() || e.is_connect() {
443        // reqwest surfaces both timeout-during-connect and read-timeout as
444        // `is_timeout()`.  Pool-exhaustion manifests as a connect error with
445        // an "operation timed out" message.
446        let msg = if e.is_timeout() {
447            format!("Request timed out: {e}")
448        } else {
449            format!("Connection error: {e}")
450        };
451        warn!("{msg}");
452        TelegramError::TimedOut(msg)
453    } else {
454        let msg = format!("reqwest error: {e}");
455        warn!("{msg}");
456        TelegramError::Network(msg)
457    }
458}
459
460/// Build a `reqwest` multipart form from a [`RequestData`].
461fn build_multipart_form(data: &RequestData) -> Result<Form> {
462    let parts = data
463        .multipart_data()
464        .expect("called only when contains_files() is true");
465
466    let mut form = Form::new();
467
468    // Add all file parts first.
469    for (part_name, multipart_part) in &parts {
470        let bytes = multipart_part.bytes.clone();
471        let mut part = Part::bytes(bytes)
472            .mime_str(&multipart_part.mime_type)
473            .map_err(|e| {
474                TelegramError::Network(format!(
475                    "Invalid MIME type '{}': {e}",
476                    multipart_part.mime_type
477                ))
478            })?;
479
480        if let Some(ref fname) = multipart_part.file_name {
481            part = part.file_name(fname.clone());
482        }
483
484        form = form.part(part_name.clone(), part);
485    }
486
487    // Add all non-file (JSON) parameters as text parts.
488    for (name, value) in data.json_parameters() {
489        form = form.text(name, value);
490    }
491
492    Ok(form)
493}
494
/// Pick the larger of two optional durations, treating `None` as "absent".
///
/// `Option`'s ordering ranks `None` below every `Some`, so:
/// - `(None, None)` → `None`
/// - `(Some(a), None)` → `Some(a)`
/// - `(None, Some(b))` → `Some(b)`
/// - `(Some(a), Some(b))` → `Some(max(a, b))`
fn max_duration(a: Option<Duration>, b: Option<Duration>) -> Option<Duration> {
    Ord::max(a, b)
}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture shared by the `resolve` tests: 5 s read/write/connect, 1 s
    /// pool and 20 s media-upload write timeout.
    fn sample_defaults() -> DefaultTimeouts {
        DefaultTimeouts {
            read: Some(Duration::from_secs(5)),
            write: Some(Duration::from_secs(5)),
            connect: Some(Duration::from_secs(5)),
            pool: Some(Duration::from_secs(1)),
            media_write: Some(Duration::from_secs(20)),
        }
    }

    // ------------------------------------------------------------------
    // Builder / construction
    // ------------------------------------------------------------------

    #[test]
    fn builder_defaults_produce_valid_client() {
        let built = ReqwestRequest::new();
        built.expect("default client construction must succeed");
    }

    #[test]
    fn builder_custom_pool_size() {
        let built = ReqwestRequest::builder().connection_pool_size(4).build();
        built.expect("small pool size should be valid");
    }

    #[test]
    fn builder_with_proxy() {
        // Only the construction path is exercised here; verifying that traffic
        // really goes through the proxy would need a live server.
        let built = ReqwestRequest::builder()
            .proxy("http://127.0.0.1:8080")
            .build();
        let req = built.expect("proxy builder should succeed");
        assert!(!req.is_initialized());
    }

    // ------------------------------------------------------------------
    // initialize / shutdown
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn initialize_sets_initialized_flag() {
        let req = ReqwestRequest::new().unwrap();
        assert!(!req.is_initialized());

        req.initialize().await.unwrap();

        assert!(req.is_initialized());
    }

    #[tokio::test]
    async fn shutdown_clears_initialized_flag() {
        let req = ReqwestRequest::new().unwrap();
        req.initialize().await.unwrap();

        req.shutdown().await.unwrap();

        assert!(!req.is_initialized());
    }

    #[tokio::test]
    async fn shutdown_idempotent() {
        // Never initialised: repeated shutdowns must still return Ok.
        let req = ReqwestRequest::new().unwrap();
        for _ in 0..2 {
            req.shutdown().await.unwrap();
        }
    }

    // ------------------------------------------------------------------
    // default_read_timeout
    // ------------------------------------------------------------------

    #[test]
    fn default_read_timeout_matches_builder() {
        let ninety_nine = Duration::from_secs(99);
        let req = ReqwestRequest::builder()
            .read_timeout(Some(ninety_nine))
            .build()
            .unwrap();
        assert_eq!(req.default_read_timeout(), Some(ninety_nine));
    }

    #[test]
    fn default_read_timeout_none_when_unset() {
        let req = ReqwestRequest::builder().read_timeout(None).build().unwrap();
        assert_eq!(req.default_read_timeout(), None);
    }

    // ------------------------------------------------------------------
    // max_duration helper
    // ------------------------------------------------------------------

    #[test]
    fn max_duration_both_none() {
        assert_eq!(max_duration(None, None), None);
    }

    #[test]
    fn max_duration_left_some() {
        let only = Duration::from_secs(5);
        assert_eq!(max_duration(Some(only), None), Some(only));
    }

    #[test]
    fn max_duration_right_some() {
        let only = Duration::from_secs(3);
        assert_eq!(max_duration(None, Some(only)), Some(only));
    }

    #[test]
    fn max_duration_returns_larger() {
        let (short, long) = (Duration::from_secs(5), Duration::from_secs(20));
        assert_eq!(max_duration(Some(short), Some(long)), Some(long));
        assert_eq!(max_duration(Some(long), Some(short)), Some(long));
    }

    // ------------------------------------------------------------------
    // DefaultTimeouts::resolve
    // ------------------------------------------------------------------

    #[test]
    fn resolve_uses_defaults_when_no_overrides() {
        let resolved = sample_defaults().resolve(TimeoutOverride::default_none(), false);
        assert_eq!(resolved.read, Some(Duration::from_secs(5)));
        assert_eq!(resolved.write, Some(Duration::from_secs(5)));
    }

    #[test]
    fn resolve_uses_media_write_timeout_when_has_files() {
        let resolved = sample_defaults().resolve(TimeoutOverride::default_none(), true);
        assert_eq!(resolved.write, Some(Duration::from_secs(20)));
    }

    #[test]
    fn resolve_caller_override_takes_precedence() {
        let overrides = TimeoutOverride {
            read: Some(Some(Duration::from_secs(30))),
            write: Some(None), // explicit "no timeout"
            ..TimeoutOverride::default_none()
        };
        let resolved = sample_defaults().resolve(overrides, false);
        assert_eq!(resolved.read, Some(Duration::from_secs(30)));
        assert_eq!(resolved.write, None);
    }

    #[test]
    fn resolve_explicit_none_overrides_media_timeout_even_with_files() {
        // The caller disables the write timeout for this particular upload.
        let overrides = TimeoutOverride {
            write: Some(None),
            ..TimeoutOverride::default_none()
        };
        let resolved = sample_defaults().resolve(overrides, true);
        assert_eq!(
            resolved.write, None,
            "explicit None must win over media_write"
        );
    }
}