Skip to main content

rust_tg_bot_raw/request/
base.rs

1//! Abstract interface for making HTTP requests to the Telegram Bot API.
2//!
3//! This module mirrors `telegram.request.BaseRequest` from python-telegram-bot.
4//! The central piece is the [`BaseRequest`] trait; everything else in this
5//! module is shared infrastructure used by all implementations.
6
7use std::time::Duration;
8
9use serde_json::Value;
10use tracing::debug;
11
12use crate::error::{Result, TelegramError};
13
14use super::request_data::RequestData;
15
16// ---------------------------------------------------------------------------
17// Supporting types
18// ---------------------------------------------------------------------------
19
/// HTTP methods used when issuing requests to the Bot API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HttpMethod {
    /// `POST` — used for all Bot API method calls.
    Post,
    /// `GET` — used when downloading files from Telegram CDN URLs.
    Get,
}

impl HttpMethod {
    /// The method name as an uppercase string slice, ready to pass to reqwest.
    ///
    /// This is a `const fn`, so it can also be evaluated in constant contexts.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Post => "POST",
            Self::Get => "GET",
        }
    }
}

impl std::fmt::Display for HttpMethod {
    /// Writes the uppercase method name (see [`HttpMethod::as_str`]).
    ///
    /// `Formatter::pad` is used instead of `write_str` so that width, fill,
    /// alignment and precision flags (e.g. `format!("{:>6}", HttpMethod::Get)`)
    /// behave exactly as they do for a plain `&str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
44
/// Timeout overrides a caller can attach to a single request.
///
/// Every field is a two-level `Option<Option<Duration>>`:
/// - `None` — not specified; the implementation falls back to its own default.
/// - `Some(None)` — explicitly "no timeout".
/// - `Some(Some(d))` — explicitly use the duration `d`.
///
/// The outer `Option` plays the role of Python's `DEFAULT_NONE` /
/// `DefaultValue` sentinel: it separates "left unset" from an explicit `None`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TimeoutOverride {
    /// Override for how long to wait for a TCP connection to be established.
    pub connect: Option<Option<Duration>>,
    /// Override for how long to wait for the complete response.
    pub read: Option<Option<Duration>>,
    /// Override for how long to wait while sending the request body.
    pub write: Option<Option<Duration>>,
    /// Override for how long to wait for a free connection from the pool.
    pub pool: Option<Option<Duration>>,
}

impl TimeoutOverride {
    /// An override in which no timeout dimension has been specified.
    const UNSPECIFIED: Self = Self {
        connect: None,
        read: None,
        write: None,
        pool: None,
    };

    /// Returns an override with every field left unspecified (`None`), so the
    /// implementation's default applies to each timeout dimension.
    ///
    /// Equivalent to `TimeoutOverride::default()`, but usable in `const`
    /// contexts.
    pub const fn default_none() -> Self {
        Self::UNSPECIFIED
    }
}
78
/// The final, concrete timeouts an implementation applies to one request,
/// obtained by layering a caller's [`TimeoutOverride`] on top of the
/// implementation's own defaults.
///
/// In every field, `None` means "no limit" (wait forever).
#[derive(Debug)]
#[derive(Clone, Copy)]
pub struct ResolvedTimeouts {
    /// Connection-establishment timeout.
    pub connect: Option<Duration>,
    /// Full-response read timeout.
    pub read: Option<Duration>,
    /// Request-body write timeout.
    pub write: Option<Duration>,
    /// Connection-pool acquisition timeout.
    pub pool: Option<Duration>,
}
92
93// ---------------------------------------------------------------------------
94// The trait
95// ---------------------------------------------------------------------------
96
97/// Abstract interface for sending HTTP requests to the Telegram Bot API.
98///
99/// Implementors must provide:
100/// - [`BaseRequest::initialize`] — open connections / warm up resources.
101/// - [`BaseRequest::shutdown`] — close connections / release resources.
102/// - [`BaseRequest::do_request`] — the raw HTTP round-trip.
103/// - [`BaseRequest::do_request_json_bytes`] — POST pre-serialized JSON bytes.
104/// - [`BaseRequest::default_read_timeout`] — the default read timeout so that
105///   provided methods can forward it downstream.
106///
107/// All other methods are provided as default implementations that implementors
108/// may override.
109///
110/// # Context manager equivalent
111///
112/// The Python `async with request_object:` pattern maps to:
113///
114/// ```ignore
115/// request.initialize().await?;
116/// // ... work ...
117/// request.shutdown().await;
118/// ```
119#[async_trait::async_trait]
120pub trait BaseRequest: Send + Sync {
121    // ------------------------------------------------------------------
122    // Abstract methods
123    // ------------------------------------------------------------------
124
125    /// Open connections and allocate resources required by this implementation.
126    async fn initialize(&self) -> Result<()>;
127
128    /// Close connections and release resources held by this implementation.
129    ///
130    /// Must not return an error even if the implementation is already shut
131    /// down — log a debug message and return `Ok(())` instead.
132    async fn shutdown(&self) -> Result<()>;
133
134    /// The default read timeout used when the caller does not supply an
135    /// override.
136    fn default_read_timeout(&self) -> Option<Duration>;
137
138    /// Perform the actual HTTP round-trip.
139    ///
140    /// Returns `(status_code, response_body)`.
141    ///
142    /// Implementations MUST convert transport-level errors into
143    /// [`TelegramError::Network`] or [`TelegramError::TimedOut`] before
144    /// returning — they must never let `reqwest::Error` or similar leak out.
145    async fn do_request(
146        &self,
147        url: &str,
148        method: HttpMethod,
149        request_data: Option<&RequestData>,
150        timeouts: TimeoutOverride,
151    ) -> Result<(u16, bytes::Bytes)>;
152
153    /// POST pre-serialized JSON bytes directly, bypassing [`RequestData`]
154    /// construction.
155    ///
156    /// This eliminates the double-serialization overhead for text-only API
157    /// methods: the caller serializes a typed struct to `Vec<u8>` once via
158    /// `serde_json::to_vec`, and this method sends those bytes with
159    /// `Content-Type: application/json`.
160    ///
161    /// Returns `(status_code, response_body)`.
162    async fn do_request_json_bytes(
163        &self,
164        url: &str,
165        body: &[u8],
166        timeouts: TimeoutOverride,
167    ) -> Result<(u16, bytes::Bytes)>;
168
169    // ------------------------------------------------------------------
170    // Provided methods
171    // ------------------------------------------------------------------
172
173    /// High-level POST call used by `Bot` methods.
174    ///
175    /// Calls [`Self::request_wrapper`] and then extracts `result` from the
176    /// Telegram JSON envelope.
177    ///
178    /// Mirrors `BaseRequest.post` in Python.
179    async fn post(
180        &self,
181        url: &str,
182        request_data: Option<&RequestData>,
183        timeouts: TimeoutOverride,
184    ) -> Result<Value> {
185        let raw = self
186            .request_wrapper(url, HttpMethod::Post, request_data, timeouts)
187            .await?;
188        // Use the free-function variant so we are not constrained by Self: Sized.
189        let json_data = parse_json_payload_impl(&raw)?;
190        // https://core.telegram.org/bots/api#making-requests — successful
191        // responses always carry a "result" key.
192        json_data
193            .get("result")
194            .cloned()
195            .ok_or_else(|| TelegramError::Network("Missing 'result' field in API response".into()))
196    }
197
198    /// High-level POST call that sends pre-serialized JSON bytes.
199    ///
200    /// Eliminates double serialization for text-only API methods by sending
201    /// raw bytes directly with `Content-Type: application/json`, then
202    /// extracting `result` from the Telegram JSON envelope.
203    async fn post_json(&self, url: &str, body: &[u8], timeouts: TimeoutOverride) -> Result<Value> {
204        let (code, payload) = self.do_request_json_bytes(url, body, timeouts).await?;
205
206        if (200..=299).contains(&code) {
207            let json_data = parse_json_payload_impl(&payload)?;
208            return json_data.get("result").cloned().ok_or_else(|| {
209                TelegramError::Network("Missing 'result' field in API response".into())
210            });
211        }
212
213        // Reuse the same error-handling logic as request_wrapper.
214        let (message, migrate_chat_id, retry_after, extra_params) =
215            parse_error_body(&payload, code);
216
217        if let Some(new_chat_id) = migrate_chat_id {
218            return Err(TelegramError::ChatMigrated { new_chat_id });
219        }
220        if let Some(secs) = retry_after {
221            return Err(TelegramError::RetryAfter {
222                retry_after: Duration::from_secs(secs),
223            });
224        }
225
226        let full_message = if let Some(params) = extra_params {
227            format!("{message}. The server response contained unknown parameters: {params}")
228        } else {
229            message
230        };
231
232        let err = match code {
233            403 => TelegramError::Forbidden(full_message),
234            401 | 404 => TelegramError::InvalidToken(full_message),
235            400 => TelegramError::BadRequest(full_message),
236            409 => TelegramError::Conflict(full_message),
237            _ => TelegramError::Network(full_message),
238        };
239
240        Err(err)
241    }
242
243    /// File download helper — issues a GET request and returns raw bytes.
244    ///
245    /// Mirrors `BaseRequest.retrieve` in Python.
246    async fn retrieve(&self, url: &str, timeouts: TimeoutOverride) -> Result<bytes::Bytes> {
247        self.request_wrapper(url, HttpMethod::Get, None, timeouts)
248            .await
249    }
250
251    /// Wraps [`Self::do_request`], translating HTTP status codes into the
252    /// appropriate [`TelegramError`] variants.
253    ///
254    /// Mirrors `BaseRequest._request_wrapper` in Python.
255    async fn request_wrapper(
256        &self,
257        url: &str,
258        method: HttpMethod,
259        request_data: Option<&RequestData>,
260        timeouts: TimeoutOverride,
261    ) -> Result<bytes::Bytes> {
262        let (code, payload) = match self.do_request(url, method, request_data, timeouts).await {
263            Ok(pair) => pair,
264            // TelegramErrors that bubbled up from do_request are re-raised as-is.
265            Err(e) => return Err(e),
266        };
267
268        if (200..=299).contains(&code) {
269            return Ok(payload);
270        }
271
272        // Attempt to extract the Telegram error description from the JSON body.
273        let (message, migrate_chat_id, retry_after, extra_params) =
274            parse_error_body(&payload, code);
275
276        // Special-case response parameters before dispatching on status code.
277        if let Some(new_chat_id) = migrate_chat_id {
278            return Err(TelegramError::ChatMigrated { new_chat_id });
279        }
280        if let Some(secs) = retry_after {
281            return Err(TelegramError::RetryAfter {
282                retry_after: Duration::from_secs(secs),
283            });
284        }
285
286        let full_message = if let Some(params) = extra_params {
287            format!("{message}. The server response contained unknown parameters: {params}")
288        } else {
289            message
290        };
291
292        let err = match code {
293            403 => TelegramError::Forbidden(full_message),
294            // 401 Unauthorized and 404 Not Found both map to InvalidToken.
295            401 | 404 => TelegramError::InvalidToken(full_message),
296            400 => TelegramError::BadRequest(full_message),
297            409 => TelegramError::Conflict(full_message),
298            // 502 Bad Gateway and anything else are network errors.
299            _ => TelegramError::Network(full_message),
300        };
301
302        Err(err)
303    }
304
305    /// Parse a UTF-8 JSON payload returned by Telegram.
306    ///
307    /// Returns a [`TelegramError::Network`] when the bytes are not valid JSON,
308    /// mirroring the Python `TelegramError("Invalid server response")`.
309    ///
310    /// Implementors may override this method to use a custom JSON library.
311    ///
312    /// The default implementation delegates to [`parse_json_payload_impl`].
313    fn parse_json_payload(&self, payload: &[u8]) -> Result<Value> {
314        parse_json_payload_impl(payload)
315    }
316}
317
318// ---------------------------------------------------------------------------
319// Free-function helpers (callable without a receiver or Self: Sized bound)
320// ---------------------------------------------------------------------------
321
322/// Parse a UTF-8 byte slice as JSON, producing a [`TelegramError::Network`]
323/// error on failure.
324///
325/// ```
326/// use rust_tg_bot_raw::request::base::parse_json_payload_impl;
327///
328/// let raw = br#"{"ok":true,"result":42}"#;
329/// let v = parse_json_payload_impl(raw).unwrap();
330/// assert_eq!(v["result"], 42);
331/// ```
332pub fn parse_json_payload_impl(payload: &[u8]) -> Result<Value> {
333    // Decode with replacement characters on invalid UTF-8, matching the
334    // Python `errors="replace"` strategy.
335    let text = String::from_utf8_lossy(payload);
336    serde_json::from_str(&text).map_err(|e| {
337        debug!("Cannot parse server response as JSON: {e}  payload={text:?}");
338        TelegramError::Network(format!("Invalid server response: {e}"))
339    })
340}
341
342// ---------------------------------------------------------------------------
343// Internal helpers
344// ---------------------------------------------------------------------------
345
346/// Extract the human-readable message and any special parameters from an error
347/// body.
348///
349/// Returns `(message, migrate_to_chat_id, retry_after_secs, unknown_params)`.
350fn parse_error_body(
351    payload: &[u8],
352    code: u16,
353) -> (String, Option<i64>, Option<u64>, Option<String>) {
354    let fallback_message = http_status_phrase(code);
355
356    match parse_json_payload_impl(payload) {
357        Err(_) => {
358            // Body is not valid JSON — return a descriptive fallback.
359            let raw = String::from_utf8_lossy(payload);
360            let msg = format!("{fallback_message}. Parsing the server response {raw:?} failed");
361            (msg, None, None, None)
362        }
363        Ok(body) => {
364            let description = body
365                .get("description")
366                .and_then(Value::as_str)
367                .map(str::to_owned)
368                .unwrap_or(fallback_message);
369
370            let parameters = body.get("parameters");
371
372            let migrate_to_chat_id = parameters
373                .and_then(|p| p.get("migrate_to_chat_id"))
374                .and_then(Value::as_i64);
375
376            let retry_after = parameters
377                .and_then(|p| p.get("retry_after"))
378                .and_then(Value::as_u64);
379
380            // Any parameters that are neither migrate_to_chat_id nor
381            // retry_after are "unknown".
382            let extra = parameters.and_then(|p| {
383                if let Value::Object(map) = p {
384                    let unknown: serde_json::Map<String, Value> = map
385                        .iter()
386                        .filter(|(k, _)| {
387                            k.as_str() != "migrate_to_chat_id" && k.as_str() != "retry_after"
388                        })
389                        .map(|(k, v)| (k.clone(), v.clone()))
390                        .collect();
391                    if unknown.is_empty() {
392                        None
393                    } else {
394                        Some(Value::Object(unknown).to_string())
395                    }
396                } else {
397                    None
398                }
399            });
400
401            (description, migrate_to_chat_id, retry_after, extra)
402        }
403    }
404}
405
/// Best-effort label for an HTTP status code, formatted as
/// `"<reason phrase> (<code>)"`.
///
/// Only the codes listed in the table below are recognised; every other code
/// is labelled `"Unknown HTTP Error"`.
fn http_status_phrase(code: u16) -> String {
    const KNOWN_PHRASES: &[(u16, &str)] = &[
        (200, "OK"),
        (201, "Created"),
        (204, "No Content"),
        (400, "Bad Request"),
        (401, "Unauthorized"),
        (403, "Forbidden"),
        (404, "Not Found"),
        (409, "Conflict"),
        (420, "Enhance Your Calm"),
        (429, "Too Many Requests"),
        (500, "Internal Server Error"),
        (502, "Bad Gateway"),
        (503, "Service Unavailable"),
        (504, "Gateway Timeout"),
    ];

    let phrase = KNOWN_PHRASES
        .iter()
        .find(|&&(known, _)| known == code)
        .map_or("Unknown HTTP Error", |&(_, text)| text);

    format!("{phrase} ({code})")
}
427
428// ---------------------------------------------------------------------------
429// Re-export async_trait so that callers don't need to depend on it directly.
430// ---------------------------------------------------------------------------
431pub use async_trait::async_trait;
432
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // HttpMethod
    // ------------------------------------------------------------------

    #[test]
    fn http_method_as_str() {
        assert_eq!(HttpMethod::Post.as_str(), "POST");
        assert_eq!(HttpMethod::Get.as_str(), "GET");
    }

    #[test]
    fn http_method_display() {
        assert_eq!(HttpMethod::Post.to_string(), "POST");
        assert_eq!(HttpMethod::Get.to_string(), "GET");
    }

    // ------------------------------------------------------------------
    // TimeoutOverride
    // ------------------------------------------------------------------

    #[test]
    fn timeout_override_default_none_leaves_everything_unspecified() {
        for t in [TimeoutOverride::default_none(), TimeoutOverride::default()] {
            assert!(t.connect.is_none());
            assert!(t.read.is_none());
            assert!(t.write.is_none());
            assert!(t.pool.is_none());
        }
    }

    // ------------------------------------------------------------------
    // parse_json_payload_impl
    // ------------------------------------------------------------------

    #[test]
    fn parse_valid_json() {
        let raw = br#"{"ok":true,"result":{"id":1}}"#;
        let v = parse_json_payload_impl(raw).unwrap();
        assert_eq!(v["ok"], true);
        assert_eq!(v["result"]["id"], 1);
    }

    #[test]
    fn parse_invalid_json_returns_network_error() {
        let raw = b"not json {{";
        let err = parse_json_payload_impl(raw).unwrap_err();
        assert!(
            matches!(err, TelegramError::Network(_)),
            "expected Network, got {err:?}"
        );
    }

    #[test]
    fn parse_invalid_utf8_with_replacement() {
        // 0xFF 0xFE are not valid UTF-8. They decode to two U+FFFD characters,
        // which cannot start a JSON value, so parsing must fail gracefully
        // (no panic) with a Network error.
        let raw = b"\xff\xfe{\"ok\":true}";
        let err = parse_json_payload_impl(raw).unwrap_err();
        assert!(
            matches!(err, TelegramError::Network(_)),
            "expected Network, got {err:?}"
        );
    }

    #[test]
    fn parse_invalid_utf8_inside_string_is_replaced() {
        // An invalid byte inside a JSON string is replaced with U+FFFD rather
        // than rejecting the whole payload.
        let raw = b"{\"result\":\"a\xffb\"}";
        let v = parse_json_payload_impl(raw).unwrap();
        assert_eq!(v["result"], "a\u{FFFD}b");
    }

    // ------------------------------------------------------------------
    // parse_error_body
    // ------------------------------------------------------------------

    #[test]
    fn parse_error_body_extracts_description() {
        let body = br#"{"ok":false,"error_code":400,"description":"Bad Request: chat not found"}"#;
        let (msg, migrate, retry, extra) = parse_error_body(body, 400);
        assert_eq!(msg, "Bad Request: chat not found");
        assert!(migrate.is_none());
        assert!(retry.is_none());
        assert!(extra.is_none());
    }

    #[test]
    fn parse_error_body_missing_description_uses_status_phrase() {
        let (msg, migrate, retry, extra) = parse_error_body(br#"{"ok":false}"#, 404);
        assert_eq!(msg, "Not Found (404)");
        assert!(migrate.is_none());
        assert!(retry.is_none());
        assert!(extra.is_none());
    }

    #[test]
    fn parse_error_body_migrate_chat_id() {
        let body = br#"{"ok":false,"error_code":400,"description":"...","parameters":{"migrate_to_chat_id":-1001234567}}"#;
        let (_, migrate, _, _) = parse_error_body(body, 400);
        assert_eq!(migrate, Some(-1_001_234_567_i64));
    }

    #[test]
    fn parse_error_body_retry_after() {
        let body = br#"{"ok":false,"error_code":429,"description":"Too Many Requests","parameters":{"retry_after":30}}"#;
        let (_, _, retry, _) = parse_error_body(body, 429);
        assert_eq!(retry, Some(30));
    }

    #[test]
    fn parse_error_body_invalid_json() {
        let body = b"<html>502 Bad Gateway</html>";
        let (msg, migrate, retry, extra) = parse_error_body(body, 502);
        assert!(msg.starts_with("Bad Gateway (502)"), "got: {msg}");
        assert!(msg.contains("Parsing the server response"), "got: {msg}");
        assert!(migrate.is_none());
        assert!(retry.is_none());
        assert!(extra.is_none());
    }

    #[test]
    fn parse_error_body_unknown_parameters() {
        let body = br#"{"ok":false,"description":"err","parameters":{"some_future_field":1}}"#;
        let (msg, _, _, extra) = parse_error_body(body, 400);
        assert_eq!(msg, "err");
        assert_eq!(extra.as_deref(), Some(r#"{"some_future_field":1}"#));
    }

    #[test]
    fn parse_error_body_known_parameters_excluded_from_extra() {
        let body = br#"{"ok":false,"description":"err","parameters":{"retry_after":5,"foo":1}}"#;
        let (_, _, retry, extra) = parse_error_body(body, 429);
        assert_eq!(retry, Some(5));
        assert_eq!(extra.as_deref(), Some(r#"{"foo":1}"#));
    }

    // ------------------------------------------------------------------
    // http_status_phrase
    // ------------------------------------------------------------------

    #[test]
    fn known_status_codes() {
        assert_eq!(http_status_phrase(400), "Bad Request (400)");
        assert_eq!(http_status_phrase(403), "Forbidden (403)");
        assert_eq!(http_status_phrase(409), "Conflict (409)");
        assert_eq!(http_status_phrase(502), "Bad Gateway (502)");
    }

    #[test]
    fn unknown_status_code() {
        assert_eq!(http_status_phrase(418), "Unknown HTTP Error (418)");
    }
}