Skip to main content

typst_kit/
downloader.rs

1//! Web requests with optional progress reporting.
2//!
3//! All `typst-kit` functionality that may trigger downloads goes through the
4//! [`Downloader`] trait. A built-in implementation is provided through the
5//! [`SystemDownloader`].
6//!
7//! Downloads can optionally be tracked by wrapping an existing downloader in a
8//! [`ProgressDownloader`]. All downloads are identified by a dynamic key, so
9//! that the progress downloader's underlying [`Progress`] reporter can decide
10//! whether it wants to display something.
11
12use std::any::Any;
13use std::collections::VecDeque;
14use std::fmt::{self, Debug, Display, Formatter};
15use std::io::{self, Cursor, ErrorKind, Read};
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18
19#[cfg(feature = "system-downloader")]
20use {
21    ecow::EcoString,
22    native_tls::{Certificate, TlsConnector},
23    once_cell::sync::OnceCell,
24    std::path::PathBuf,
25};
26
/// Downloads resources from the network.
///
/// If the remote returns a `404` status code, the implementation should return
/// an error with [`io::ErrorKind::NotFound`].
///
/// See the module-level docs and [`ProgressDownloader`] for more information on
/// the `key` argument.
pub trait Downloader: Send + Sync + 'static {
    /// Fetches the given URL, returning an optional size hint and a reader for
    /// the remote data.
    ///
    /// The size hint is advisory only: consumers must not rely on the reader
    /// yielding exactly that many bytes.
    fn stream(
        &self,
        key: &dyn Any,
        url: &str,
    ) -> io::Result<(Option<usize>, Box<dyn Read>)>;

    /// Fetches the given URL, returning the full data as a vector.
    ///
    /// This is optional to implement. A default implementation in terms of
    /// `stream` is provided.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `stream` or while reading from the
    /// returned reader.
    fn download(&self, key: &dyn Any, url: &str) -> io::Result<Vec<u8>> {
        // The size hint typically originates from a header controlled by the
        // remote. Reserving it verbatim would let a bogus value trigger a
        // capacity-overflow panic or an allocation failure, so only a bounded
        // amount is reserved up front; the vector still grows on demand.
        const MAX_PREALLOC: usize = 16 * 1024 * 1024;

        let (hint, mut reader) = self.stream(key, url)?;
        let mut buf = Vec::with_capacity(hint.map_or(0, |size| size.min(MAX_PREALLOC)));
        reader.read_to_end(&mut buf)?;
        Ok(buf)
    }
}
57
58impl<T: Downloader> Downloader for Box<T> {
59    fn stream(
60        &self,
61        key: &dyn Any,
62        url: &str,
63    ) -> io::Result<(Option<usize>, Box<dyn Read>)> {
64        (**self).stream(key, url)
65    }
66
67    fn download(&self, key: &dyn Any, url: &str) -> io::Result<Vec<u8>> {
68        (**self).download(key, url)
69    }
70}
71
72impl<T: Downloader> Downloader for Arc<T> {
73    fn stream(
74        &self,
75        key: &dyn Any,
76        url: &str,
77    ) -> io::Result<(Option<usize>, Box<dyn Read>)> {
78        (**self).stream(key, url)
79    }
80
81    fn download(&self, key: &dyn Any, url: &str) -> io::Result<Vec<u8>> {
82        (**self).download(key, url)
83    }
84}
85
/// A minimal HTTPS client for downloads.
///
/// Uses system-native TLS and respects proxying environment variables.
#[cfg(feature = "system-downloader")]
pub struct SystemDownloader {
    /// The user agent string sent with every request.
    user_agent: EcoString,
    /// Path to a PEM certificate file that is read on first use, if any.
    cert_path: Option<PathBuf>,
    /// The custom root certificate: either supplied directly at construction
    /// or loaded from `cert_path` and cached here.
    cert: OnceCell<Certificate>,
}
95
#[cfg(feature = "system-downloader")]
impl SystemDownloader {
    /// Creates a downloader that sends the given user agent and uses no
    /// custom certificate.
    pub fn new(user_agent: impl Into<EcoString>) -> Self {
        let user_agent = user_agent.into();
        Self { user_agent, cert_path: None, cert: OnceCell::new() }
    }

    /// Creates a downloader that sends the given user agent and trusts the
    /// given, already loaded certificate.
    pub fn with_cert(user_agent: impl Into<EcoString>, cert: Certificate) -> Self {
        let user_agent = user_agent.into();
        Self { user_agent, cert_path: None, cert: OnceCell::with_value(cert) }
    }

    /// Creates a downloader that sends the given user agent and trusts the
    /// PEM certificate stored at the given path.
    ///
    /// The file is not touched here; it is read lazily on first use. If it
    /// cannot be read or parsed at that point, the request that needed it
    /// fails with the loading error.
    pub fn with_cert_path(user_agent: impl Into<EcoString>, cert_path: PathBuf) -> Self {
        let user_agent = user_agent.into();
        Self {
            user_agent,
            cert_path: Some(cert_path),
            cert: OnceCell::new(),
        }
    }

    /// Returns the custom certificate of this client, loading it from the
    /// configured path on first access.
    ///
    /// - `None`: no certificate and no certificate path were configured.
    /// - `Some(Ok(cert))`: the certificate is available.
    /// - `Some(Err(err))`: reading or parsing the certificate file failed.
    ///   Nothing is cached in that case, so the next call tries again.
    fn cert(&self) -> Option<io::Result<&Certificate>> {
        match (self.cert.get(), self.cert_path.as_ref()) {
            // Already supplied at construction or loaded by an earlier call.
            (Some(loaded), _) => Some(Ok(loaded)),
            // Nothing configured at all.
            (None, None) => None,
            // Configured by path but not loaded yet.
            (None, Some(path)) => Some(self.cert.get_or_try_init(|| {
                let pem = std::fs::read(path)?;
                Certificate::from_pem(&pem).map_err(io::Error::other)
            })),
        }
    }
}
147
#[cfg(feature = "system-downloader")]
impl Downloader for SystemDownloader {
    /// Performs a `GET` request for `url` and hands back the response body.
    ///
    /// A fresh agent is configured for every call: user agent, proxy taken
    /// from the environment (if one applies to `url`), and native TLS with the
    /// optional custom root certificate. The size hint is the parsed
    /// `Content-Length` header, if present and numeric. The `key` is unused.
    ///
    /// A `404` response is mapped to [`io::ErrorKind::NotFound`]; every other
    /// request failure becomes a generic I/O error.
    fn stream(
        &self,
        _: &dyn Any,
        url: &str,
    ) -> io::Result<(Option<usize>, Box<dyn Read>)> {
        let agent = ureq::AgentBuilder::new().user_agent(&self.user_agent);

        // Honor the proxy settings from the environment; a proxy URL that
        // ureq cannot parse is skipped.
        let proxy = env_proxy::for_url_str(url)
            .to_url()
            .and_then(|proxy_url| ureq::Proxy::new(proxy_url).ok());
        let agent = match proxy {
            Some(proxy) => agent.proxy(proxy),
            None => agent,
        };

        // Native TLS, trusting the custom CA certificate in addition if one
        // is configured. A certificate that fails to load aborts the request.
        let mut tls = TlsConnector::builder();
        if let Some(loaded) = self.cert() {
            let cert = loaded?;
            tls.add_root_certificate(cert.clone());
        }
        let connector = tls.build().map_err(io::Error::other)?;
        let agent = agent.tls_connector(Arc::new(connector)).build();

        let response = match agent.get(url).call() {
            Ok(response) => response,
            Err(err) => {
                let kind = if matches!(err, ureq::Error::Status(404, _)) {
                    ErrorKind::NotFound
                } else {
                    ErrorKind::Other
                };
                return Err(io::Error::new(kind, err));
            }
        };

        let content_len = response
            .header("Content-Length")
            .and_then(|value| value.parse::<usize>().ok());

        Ok((content_len, response.into_reader()))
    }
}
190
#[cfg(feature = "system-downloader")]
impl Debug for SystemDownloader {
    /// Shows only the user agent; the certificate fields are elided.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("SystemDownloader");
        out.field("user_agent", &self.user_agent);
        out.finish_non_exhaustive()
    }
}
199
/// Wraps a downloader and adds progress reporting to it.
///
/// Needs
/// - an underlying downloader
/// - a factory function that creates an instance of type `P`, the progress
///   reporter used for a single download.
///
/// The factory function is passed an `&dyn Any` key. A key is provided for each
/// download and can be used to decide what to print. For instance, the CLI will
/// display downloads with a [`PackageSpec`](typst_syntax::package::PackageSpec)
/// key, but not with a `"package index"` key.
///
/// Keys used by functionality in `typst-kit` are documented with the respective
/// functionality.
pub struct ProgressDownloader<T, F, P>
where
    F: Fn(&dyn Any) -> P + Send + Sync + 'static,
{
    /// The downloader that performs the actual requests.
    inner: T,
    /// Factory that creates a progress reporter for a download key.
    progress: F,
    /// Minimum time between two progress updates; also the duration of one
    /// speed sample.
    period: Duration,
}
221
222impl<T, F, P> ProgressDownloader<T, F, P>
223where
224    T: Downloader,
225    F: Fn(&dyn Any) -> P + Send + Sync + 'static,
226    P: ProgressReporter + 'static,
227{
228    /// Creates a new progress downloader.
229    pub fn new(inner: T, progress: F) -> Self {
230        Self {
231            inner,
232            progress,
233            period: Duration::from_millis(100),
234        }
235    }
236
237    /// Creates a new progress downloader.
238    pub fn with_interval(inner: T, progress: F, period: Duration) -> Self {
239        Self { inner, progress, period }
240    }
241}
242
243impl<T, F, P> Downloader for ProgressDownloader<T, F, P>
244where
245    T: Downloader,
246    F: Fn(&dyn Any) -> P + Send + Sync + 'static,
247    P: ProgressReporter + 'static,
248{
249    fn download(&self, key: &dyn Any, url: &str) -> io::Result<Vec<u8>> {
250        let (len, reader) = self.inner.stream(key, url)?;
251        let mut progress = (self.progress)(key);
252        let data =
253            ProgressReader::new(len, reader, self.period, &mut progress).download()?;
254        Ok(data)
255    }
256
257    fn stream(
258        &self,
259        key: &dyn Any,
260        url: &str,
261    ) -> io::Result<(Option<usize>, Box<dyn Read>)> {
262        let data = self.inner.download(key, url)?;
263        Ok((Some(data.len()), Box::new(Cursor::new(data))))
264    }
265}
266
267impl<T, F, P> Debug for ProgressDownloader<T, F, P>
268where
269    T: Debug,
270    F: Fn(&dyn Any) -> P + Send + Sync + 'static,
271{
272    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
273        f.debug_struct("ProgressDownloader")
274            .field("inner", &self.inner)
275            .finish_non_exhaustive()
276    }
277}
278
/// Manages progress reporting for downloads.
///
/// Each callback receives the current [`Progress`] snapshot of the download.
pub trait ProgressReporter {
    /// Invoked once when a download is started, before any data was read.
    fn start(&mut self, progress: &Progress);

    /// Invoked repeatedly while a download is ongoing, at most once per
    /// reporting period.
    fn update(&mut self, progress: &Progress);

    /// Invoked once when a download has finished successfully.
    fn finish(&mut self, progress: &Progress);
}
290
/// The current progress of a download.
///
/// The [`Display`] implementation renders these statistics as a
/// human-readable status line.
#[derive(Debug)]
pub struct Progress {
    /// The download starting instant.
    pub start_time: Instant,
    /// The expected amount of bytes to download, `None` if the response header
    /// was not set.
    pub content_len: Option<usize>,
    /// The total amount of downloaded bytes until now.
    pub downloaded: usize,
    /// A backlog of the amount of downloaded bytes for each bucket, most
    /// recent bucket first.
    pub samples: VecDeque<usize>,
    /// The duration of each bucket (in samples).
    pub period: Duration,
}
306
307impl Display for Progress {
308    /// Formats several download statistics for display.
309    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
310        let len = self.samples.len();
311        let sum: usize = self.samples.iter().sum();
312        let bytes_per_period = sum.checked_div(len).or(self.content_len).unwrap_or(0);
313
314        let frequency: usize = Duration::from_secs(1)
315            .as_nanos()
316            .checked_div(self.period.as_nanos())
317            .and_then(|s| s.try_into().ok())
318            .unwrap_or(1);
319        let bytes_per_sec = bytes_per_period * frequency;
320
321        match self.content_len {
322            Some(content_len) => {
323                let ratio = self.downloaded as f64 / content_len as f64;
324                let remaining_bytes = content_len - self.downloaded;
325                let remaining_buckets: u32 = remaining_bytes
326                    .checked_div(bytes_per_period)
327                    .and_then(|c| c.try_into().ok())
328                    .unwrap_or(0);
329
330                let eta = self.period * remaining_buckets;
331
332                write!(
333                    f,
334                    "{downloaded} / {total} ({percent:3.0} %), {bytes}/s, ETA: {eta}",
335                    downloaded = format_byte_unit(self.downloaded),
336                    total = format_byte_unit(content_len),
337                    percent = 100.0 * ratio,
338                    bytes = format_byte_unit(bytes_per_sec),
339                    eta = format_seconds(eta),
340                )
341            }
342
343            None => write!(
344                f,
345                "{downloaded}, {bytes}/s",
346                downloaded = format_byte_unit(self.downloaded),
347                bytes = format_byte_unit(bytes_per_sec),
348            ),
349        }
350    }
351}
352
353/// Format a given size as a unit of bytes.
354fn format_byte_unit(size: usize) -> impl Display {
355    const KI: f64 = 1024.0;
356    const MI: f64 = KI * KI;
357    const GI: f64 = KI * KI * KI;
358
359    let size = size as f64;
360
361    typst_utils::display(move |f| {
362        if size >= GI {
363            write!(f, "{:5.1} GiB", size / GI)
364        } else if size >= MI {
365            write!(f, "{:5.1} MiB", size / MI)
366        } else if size >= KI {
367            write!(f, "{:5.1} KiB", size / KI)
368        } else {
369            write!(f, "{size:3} B")
370        }
371    })
372}
373
374/// Formats a duration with second precision.
375fn format_seconds(duration: Duration) -> impl Display {
376    typst_utils::display(move |f| write!(f, "{} s", duration.as_secs()))
377}
378
// Acknowledgment:
// The `ProgressReader` is closely modeled after rustup's `DownloadTracker`.
// https://github.com/rust-lang/rustup/blob/master/src/cli/download_tracker.rs
382
/// Keep track of this many download speed samples.
///
/// Once the backlog is full, the oldest sample is dropped for each new one.
const SAMPLES: usize = 25;
385
/// A wrapper around a response body reader that reads the body in chunks and
/// reports its progress.
struct ProgressReader<'p> {
    /// The reader that yields the data to download.
    reader: Box<dyn Read>,
    /// The download state, holding download metadata for progress reporting.
    state: Progress,
    /// The instant at which progress was last reported, `None` until the
    /// first chunk was read.
    last_progress: Option<Instant>,
    /// A trait object used to report download progress.
    progress: &'p mut dyn ProgressReporter,
}
398
399impl<'p> ProgressReader<'p> {
400    /// Wraps a [`ureq::Response`] and prepares it for downloading.
401    ///
402    /// The 'Content-Length' header is used as a size hint for read
403    /// optimization, if present.
404    fn new(
405        content_len: Option<usize>,
406        reader: Box<dyn Read>,
407        period: Duration,
408        progress: &'p mut dyn ProgressReporter,
409    ) -> Self {
410        Self {
411            reader,
412            last_progress: None,
413            state: Progress {
414                content_len,
415                downloaded: 0,
416                samples: VecDeque::with_capacity(SAMPLES),
417                start_time: Instant::now(),
418                period,
419            },
420            progress,
421        }
422    }
423
424    /// Download the body's content as raw bytes while reporting download
425    /// progress.
426    fn download(mut self) -> io::Result<Vec<u8>> {
427        let mut buffer = vec![0; 8192];
428        let mut data = match self.state.content_len {
429            Some(content_len) => Vec::with_capacity(content_len),
430            None => Vec::with_capacity(8192),
431        };
432
433        self.progress.start(&self.state);
434
435        let mut downloaded_this_period = 0;
436        loop {
437            let read = match self.reader.read(&mut buffer) {
438                Ok(0) => break,
439                Ok(n) => n,
440                // If the data is not yet ready but will be available eventually
441                // keep trying until we either get an actual error, receive data
442                // or an Ok(0).
443                Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
444                Err(e) => return Err(e),
445            };
446
447            data.extend(&buffer[..read]);
448
449            let last_printed = match self.last_progress {
450                Some(prev) => prev,
451                None => {
452                    let current_time = Instant::now();
453                    self.last_progress = Some(current_time);
454                    current_time
455                }
456            };
457            let elapsed = Instant::now().saturating_duration_since(last_printed);
458
459            downloaded_this_period += read;
460            self.state.downloaded += read;
461
462            if elapsed >= self.state.period {
463                if self.state.samples.len() == SAMPLES {
464                    self.state.samples.pop_back();
465                }
466
467                self.state.samples.push_front(downloaded_this_period);
468                downloaded_this_period = 0;
469
470                self.progress.update(&self.state);
471                self.last_progress = Some(Instant::now());
472            }
473        }
474
475        self.progress.finish(&self.state);
476        Ok(data)
477    }
478}