// telegram_bot/connector/hyper.rs

1use std::io::{Cursor, Read};
2use std::path::Path;
3use std::pin::Pin;
4use std::str::FromStr;
5
6use bytes::Bytes;
7use futures::{Future, FutureExt};
8use hyper::{
9    body::to_bytes,
10    client::{connect::Connect, Client},
11    header::CONTENT_TYPE,
12    http::Error as HttpError,
13    Method, Request, Uri,
14};
15#[cfg(feature = "rustls")]
16use hyper_rustls::HttpsConnector;
17#[cfg(feature = "openssl")]
18use hyper_tls::HttpsConnector;
19use multipart::client::lazy::Multipart;
20use telegram_bot_raw::{
21    Body as TelegramBody, HttpRequest, HttpResponse, Method as TelegramMethod, MultipartValue, Text,
22};
23
24use super::Connector;
25use crate::errors::{Error, ErrorKind};
26
27#[derive(Debug)]
28pub struct HyperConnector<C>(Client<C>);
29
30enum MultipartTemporaryValue {
31    Text(Text),
32    Data { file_name: Text, data: Bytes },
33}
34
35impl<C> HyperConnector<C> {
36    pub fn new(client: Client<C>) -> Self {
37        HyperConnector(client)
38    }
39}
40
41impl<C: Connect + std::fmt::Debug + 'static + Clone + Send + Sync> Connector for HyperConnector<C> {
42    fn request(
43        &self,
44        token: &str,
45        req: HttpRequest,
46    ) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Error>> + Send>> {
47        let uri = Uri::from_str(&req.url.url(token));
48        let client = self.0.clone();
49
50        let future = async move {
51            let uri = uri.map_err(HttpError::from).map_err(ErrorKind::from)?;
52
53            let method = match req.method {
54                TelegramMethod::Get => Method::GET,
55                TelegramMethod::Post => Method::POST,
56            };
57
58            let mut http_request = Request::builder().method(method).uri(uri);
59
60            let request = match req.body {
61                TelegramBody::Empty => http_request.body(Into::<hyper::Body>::into(vec![])),
62                TelegramBody::Json(body) => {
63                    let content_type = "application/json"
64                        .parse()
65                        .map_err(HttpError::from)
66                        .map_err(ErrorKind::from)?;
67                    http_request
68                        .headers_mut()
69                        .map(move |headers| headers.insert(CONTENT_TYPE, content_type));
70                    http_request.body(Into::<hyper::Body>::into(body))
71                }
72                TelegramBody::Multipart(parts) => {
73                    let mut fields = Vec::new();
74                    for (key, value) in parts {
75                        match value {
76                            MultipartValue::Text(text) => {
77                                fields.push((key, MultipartTemporaryValue::Text(text)))
78                            }
79                            MultipartValue::Path { file_name, path } => {
80                                let file_name = file_name
81                                    .or_else(|| {
82                                        AsRef::<Path>::as_ref(&path)
83                                            .file_name()
84                                            .and_then(|s| s.to_str())
85                                            .map(Into::into)
86                                    })
87                                    .ok_or(ErrorKind::InvalidMultipartFilename)?;
88
89                                let data = tokio::fs::read(path).await.map_err(ErrorKind::from)?;
90                                fields.push((
91                                    key,
92                                    MultipartTemporaryValue::Data {
93                                        file_name,
94                                        data: data.into(),
95                                    },
96                                ))
97                            }
98                            MultipartValue::Data { file_name, data } => fields
99                                .push((key, MultipartTemporaryValue::Data { file_name, data })),
100                        }
101                    }
102
103                    let mut prepared = {
104                        let mut part = Multipart::new();
105                        for (key, value) in &fields {
106                            match value {
107                                MultipartTemporaryValue::Text(text) => {
108                                    part.add_text(*key, text.as_str());
109                                }
110                                MultipartTemporaryValue::Data { file_name, data } => {
111                                    part.add_stream(
112                                        *key,
113                                        Cursor::new(data),
114                                        Some(file_name.as_str()),
115                                        None,
116                                    );
117                                }
118                            }
119                        }
120                        part.prepare().map_err(|err| err.error)
121                    }
122                    .map_err(ErrorKind::from)?;
123
124                    let boundary = prepared.boundary();
125
126                    let content_type =
127                        format!("multipart/form-data;boundary={bound}", bound = boundary)
128                            .parse()
129                            .map_err(HttpError::from)
130                            .map_err(ErrorKind::from)?;
131                    http_request.headers_mut().map(move |headers| {
132                        headers.insert(CONTENT_TYPE, content_type);
133                    });
134
135                    let mut bytes = Vec::new();
136                    prepared.read_to_end(&mut bytes).map_err(ErrorKind::from)?;
137                    http_request.body(bytes.into())
138                }
139                body => panic!("Unknown body type {:?}", body),
140            }
141            .map_err(ErrorKind::from)?;
142
143            let response = client.request(request).await.map_err(ErrorKind::from)?;
144            let whole_chunk = to_bytes(response.into_body()).await;
145
146            let body = whole_chunk
147                .iter()
148                .fold(vec![], |mut acc, chunk| -> Vec<u8> {
149                    acc.extend_from_slice(&chunk);
150                    acc
151                });
152
153            Ok::<HttpResponse, Error>(HttpResponse { body: Some(body) })
154        };
155
156        future.boxed()
157    }
158}
159
160pub fn default_connector() -> Result<Box<dyn Connector>, Error> {
161    #[cfg(feature = "rustls")]
162    let connector = HttpsConnector::new();
163
164    #[cfg(feature = "openssl")]
165    let connector = HttpsConnector::new();
166
167    Ok(Box::new(HyperConnector::new(
168        Client::builder().build(connector),
169    )))
170}