rust_rcs_client/messaging/ft_http/
download.rs

1// Copyright 2023 宋昊文
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    fmt::Display,
17    fs::{File, OpenOptions},
18    io::{self, Seek, Write},
19    pin::Pin,
20    sync::Arc,
21    task::{Context, Poll},
22    u32,
23};
24
25use futures::{future::BoxFuture, io::copy_buf, AsyncWrite, FutureExt};
26use rust_rcs_core::{
27    ffi::log::platform_log,
28    http::{
29        request::{Request, GET},
30        HttpClient,
31    },
32    internet::{header, Header},
33    io::ProgressReportingReader,
34    security::{
35        authentication::digest::DigestAnswerParams,
36        gba::{self, GbaContext},
37        SecurityContext,
38    },
39};
40use tokio::io::copy;
41use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt};
42use url::Url;
43
/// Log tag passed to `platform_log` for the log messages written by this file.
const LOG_TAG: &str = "fthttp";
45
/// Failure reasons reported by [`download_file`].
///
/// Implements [`std::error::Error`] so it can be propagated with `?` or boxed as
/// `Box<dyn Error>`, and `Debug` so it can be logged or asserted on directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileDownloadError {
    /// The server answered with a status code that is not handled as success.
    /// Carries the status code and the reason phrase (empty if it was not valid UTF-8).
    Http(u16, String),
    /// Writing the downloaded file failed, or its final size did not match the
    /// expected total.
    IO,
    /// The file URI was rejected before any connection was attempted.
    MalformedHost,
    /// The request did not produce a usable response (for example the connection
    /// or the send failed).
    NetworkIO,
}

impl FileDownloadError {
    /// Numeric code for this error: the HTTP status code for
    /// [`FileDownloadError::Http`], `0` for every other variant.
    pub fn error_code(&self) -> u16 {
        match self {
            FileDownloadError::Http(status_code, _) => *status_code,
            FileDownloadError::IO
            | FileDownloadError::MalformedHost
            | FileDownloadError::NetworkIO => 0,
        }
    }

    /// Short description: the reason phrase for HTTP errors, the variant name otherwise.
    pub fn error_string(&self) -> String {
        let text = match self {
            FileDownloadError::Http(_, reason_phrase) => reason_phrase.as_str(),
            FileDownloadError::IO => "IO",
            FileDownloadError::MalformedHost => "MalformedHost",
            FileDownloadError::NetworkIO => "NetworkIO",
        };
        String::from(text)
    }
}

impl Display for FileDownloadError {
    /// Formats HTTP errors as `Http <code> <reason>` and other variants by name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FileDownloadError::Http(status_code, reason_phrase) => {
                write!(f, "Http {} {}", status_code, reason_phrase)
            }
            FileDownloadError::IO => f.write_str("IO"),
            FileDownloadError::MalformedHost => f.write_str("MalformedHost"),
            FileDownloadError::NetworkIO => f.write_str("NetworkIO"),
        }
    }
}

impl std::error::Error for FileDownloadError {}
85
86struct FileOutput {
87    f: File,
88}
89
90impl AsyncWrite for FileOutput {
91    fn poll_write(
92        self: Pin<&mut Self>,
93        _cx: &mut Context<'_>,
94        buf: &[u8],
95    ) -> Poll<io::Result<usize>> {
96        let p = self.get_mut();
97        match p.f.write(buf) {
98            Ok(i) => Poll::Ready(Ok(i)),
99            Err(e) => match e.kind() {
100                io::ErrorKind::WouldBlock => Poll::Pending,
101                _ => Poll::Ready(Err(e)),
102            },
103        }
104    }
105
106    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
107        let p = self.get_mut();
108        match p.f.flush() {
109            Ok(()) => Poll::Ready(Ok(())),
110            Err(e) => match e.kind() {
111                io::ErrorKind::WouldBlock => Poll::Pending,
112                _ => Poll::Ready(Err(e)),
113            },
114        }
115    }
116
117    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
118        let p = self.get_mut();
119        match p.f.sync_all() {
120            Ok(()) => Poll::Ready(Ok(())),
121            Err(e) => match e.kind() {
122                io::ErrorKind::WouldBlock => Poll::Pending,
123                _ => Poll::Ready(Err(e)),
124            },
125        }
126    }
127}
128
129async fn download_file_inner<F>(
130    file_uri: &str,
131    download_path: &str,
132    start: usize,
133    total: Option<usize>,
134    msisdn: Option<&str>,
135    http_client: &Arc<HttpClient>,
136    gba_context: &Arc<GbaContext>,
137    security_context: &Arc<SecurityContext>,
138    digest_answer: Option<&DigestAnswerParams>,
139    progress_callback: F,
140) -> Result<(), FileDownloadError>
141where
142    F: Fn(u32, i32) + Send + Sync + 'static,
143{
144    if let Ok(url) = Url::parse(file_uri) {
145        if let Ok(conn) = http_client.connect(&url, false).await {
146            let host = url.host_str().unwrap();
147
148            let mut req = Request::new_with_default_headers(GET, host, url.path(), url.query());
149
150            if let Some(msisdn) = msisdn {
151                req.headers.push(Header::new(
152                    b"X-3GPP-Intended-Identity",
153                    format!("tel:{}", msisdn),
154                ));
155            }
156
157            let preloaded_answer = match digest_answer {
158                Some(_) => None,
159                None => {
160                    platform_log(LOG_TAG, "using stored authorization info");
161                    security_context.preload_auth(gba_context, host, conn.cipher_id(), GET, None)
162                }
163            };
164
165            let digest_answer = match digest_answer {
166                Some(digest_answer) => Some(digest_answer),
167                None => match &preloaded_answer {
168                    Some(preloaded_answer) => {
169                        platform_log(LOG_TAG, "using preloaded digest answer");
170                        Some(preloaded_answer)
171                    }
172                    None => None,
173                },
174            };
175
176            if let Some(digest_answer) = digest_answer {
177                if let Ok(authorization) = digest_answer.make_authorization_header(
178                    match &digest_answer.challenge {
179                        Some(challenge) => Some(&challenge.algorithm),
180                        None => None,
181                    },
182                    false,
183                    false,
184                ) {
185                    if let Ok(authorization) = String::from_utf8(authorization) {
186                        req.headers
187                            .push(Header::new(b"Authorization", String::from(authorization)));
188                    }
189                }
190            }
191
192            if start > 0 {
193                if let Some(total) = total {
194                    if total > start {
195                        req.headers.push(Header::new(
196                            b"Range",
197                            format!("bytes={}-{}", start, total - 1),
198                        ));
199                    }
200                } else {
201                    req.headers
202                        .push(Header::new(b"Range", format!("bytes={}-", start)));
203                }
204            }
205
206            if let Ok((resp, resp_stream)) = conn.send(req, |_| {}).await {
207                platform_log(
208                    LOG_TAG,
209                    format!(
210                        "download_file_inner resp.status_code = {}",
211                        resp.status_code
212                    ),
213                );
214
215                if resp.status_code == 200 {
216                    if let Some(authentication_info_header) =
217                        header::search(&resp.headers, b"Authentication-Info", false)
218                    {
219                        if let Some(digest_answer) = digest_answer {
220                            if let Some(challenge) = &digest_answer.challenge {
221                                security_context.update_auth_info(
222                                    authentication_info_header,
223                                    host,
224                                    b"\"",
225                                    challenge,
226                                    false,
227                                );
228                            }
229                        }
230                    }
231
232                    let progress_total = get_progress_total(total);
233
234                    if let Some(resp_stream) = resp_stream {
235                        let mut f = if start == 0 {
236                            match OpenOptions::new()
237                                .write(true)
238                                .create(true)
239                                .open(download_path)
240                            {
241                                Ok(f) => f,
242                                Err(e) => {
243                                    platform_log(LOG_TAG, format!("file create error: {}", e));
244                                    return Err(FileDownloadError::IO);
245                                }
246                            }
247                        } else {
248                            match OpenOptions::new()
249                                .write(true)
250                                .append(true)
251                                .open(download_path)
252                            {
253                                Ok(f) => f,
254                                Err(e) => {
255                                    platform_log(LOG_TAG, format!("file open error: {}", e));
256                                    match OpenOptions::new()
257                                        .write(true)
258                                        .create(true)
259                                        .open(download_path)
260                                    {
261                                        Ok(f) => f,
262                                        Err(e) => {
263                                            platform_log(
264                                                LOG_TAG,
265                                                format!("file create error: {}", e),
266                                            );
267                                            return Err(FileDownloadError::IO);
268                                        }
269                                    }
270                                }
271                            }
272                        };
273
274                        loop {
275                            match f.seek(io::SeekFrom::Current(0)) {
276                                Ok(i) => {
277                                    if let Ok(i) = usize::try_from(i) {
278                                        if i == start {
279                                            break;
280                                        }
281                                    }
282
283                                    platform_log(
284                                        LOG_TAG,
285                                        "underlying file size is inconsistent with provided value",
286                                    );
287                                }
288
289                                Err(e) => {
290                                    platform_log(LOG_TAG, format!("file seek error: {}", e));
291                                }
292                            }
293
294                            return Err(FileDownloadError::IO);
295                        }
296
297                        let f = FileOutput { f };
298
299                        let reader = ProgressReportingReader::new(resp_stream, move |read| {
300                            if let Ok(current) = u32::try_from(read) {
301                                progress_callback(current, progress_total);
302                            }
303                        });
304
305                        let mut rh = reader.compat();
306                        let mut wh = f.compat_write();
307
308                        match copy(&mut rh, &mut wh).await {
309                            Ok(i) => {
310                                platform_log(LOG_TAG, format!("bytes copied {}", i));
311                                let download_size_verified = if let Some(total) = total {
312                                    if let Ok(i) = usize::try_from(i) {
313                                        if start + i == total {
314                                            true
315                                        } else {
316                                            false
317                                        }
318                                    } else {
319                                        false
320                                    }
321                                } else {
322                                    true
323                                };
324                                if download_size_verified {
325                                    return Ok(());
326                                }
327                                platform_log(
328                                    LOG_TAG,
329                                    "inconsistent result of bytes copied and expected total",
330                                );
331                                return Err(FileDownloadError::IO);
332                            }
333                            Err(e) => {
334                                platform_log(
335                                    LOG_TAG,
336                                    format!("http stream copy failed with error: {}", e),
337                                );
338                                return Err(FileDownloadError::IO);
339                            }
340                        }
341                    }
342                } else if resp.status_code == 401 {
343                    if digest_answer.is_none() {
344                        if let Some(www_authenticate_header) =
345                            header::search(&resp.headers, b"WWW-Authenticate", false)
346                        {
347                            if let Some(Ok(answer)) = gba::try_process_401_response(
348                                gba_context,
349                                host.as_bytes(),
350                                conn.cipher_id(),
351                                GET,
352                                b"\"/\"",
353                                None,
354                                www_authenticate_header,
355                                http_client,
356                                security_context,
357                            )
358                            .await
359                            {
360                                return download_file(
361                                    file_uri,
362                                    download_path,
363                                    start,
364                                    total,
365                                    msisdn,
366                                    http_client,
367                                    gba_context,
368                                    security_context,
369                                    Some(&answer),
370                                    progress_callback,
371                                )
372                                .await;
373                            }
374                        }
375                    }
376                } else {
377                    return Err(FileDownloadError::Http(
378                        resp.status_code,
379                        match String::from_utf8(resp.reason_phrase) {
380                            Ok(reason_phrase) => reason_phrase,
381                            Err(_) => String::from(""),
382                        },
383                    ));
384                }
385            }
386        }
387
388        Err(FileDownloadError::NetworkIO)
389    } else {
390        Err(FileDownloadError::MalformedHost)
391    }
392}
393
394pub fn download_file<'a, 'b: 'a, F>(
395    file_uri: &'b str,
396    download_path: &'b str,
397    start: usize,
398    total: Option<usize>,
399    msisdn: Option<&'b str>,
400    http_client: &'b Arc<HttpClient>,
401    gba_context: &'b Arc<GbaContext>,
402    security_context: &'b Arc<SecurityContext>,
403    digest_answer: Option<&'a DigestAnswerParams>,
404    progress_callback: F,
405) -> BoxFuture<'a, Result<(), FileDownloadError>>
406where
407    F: Fn(u32, i32) + Send + Sync + 'static,
408{
409    async move {
410        download_file_inner(
411            file_uri,
412            download_path,
413            start,
414            total,
415            msisdn,
416            http_client,
417            gba_context,
418            security_context,
419            digest_answer,
420            progress_callback,
421        )
422        .await
423    }
424    .boxed()
425}
426
/// Converts the optional expected file size into the `i32` total handed to progress
/// callbacks, using -1 when the size is unknown or does not fit in an `i32`.
fn get_progress_total(total: Option<usize>) -> i32 {
    total
        .and_then(|size| i32::try_from(size).ok())
        .unwrap_or(-1)
}
438
/// Combines a resume offset (`start`) and a byte count (`i`) into a `u32` progress value.
///
/// Returns 0 when the sum cannot be represented as `u32`. The addition is done in `u64`
/// with overflow checking, so operands that each fit in `u32` but whose sum does not no
/// longer panic (debug) or wrap (release).
fn get_progress_current(start: usize, i: u64) -> u32 {
    u64::try_from(start)
        .ok()
        .and_then(|start| start.checked_add(i))
        .and_then(|current| u32::try_from(current).ok())
        .unwrap_or(0)
}