tor_dirclient/lib.rs

#![cfg_attr(docsrs, feature(doc_cfg))]
#![doc = include_str!("../README.md")]
// @@ begin lint list maintained by maint/add_warning @@
#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
#![warn(missing_docs)]
#![warn(noop_method_call)]
#![warn(unreachable_pub)]
#![warn(clippy::all)]
#![deny(clippy::await_holding_lock)]
#![deny(clippy::cargo_common_metadata)]
#![deny(clippy::cast_lossless)]
#![deny(clippy::checked_conversions)]
#![warn(clippy::cognitive_complexity)]
#![deny(clippy::debug_assert_with_mut_call)]
#![deny(clippy::exhaustive_enums)]
#![deny(clippy::exhaustive_structs)]
#![deny(clippy::expl_impl_clone_on_copy)]
#![deny(clippy::fallible_impl_from)]
#![deny(clippy::implicit_clone)]
#![deny(clippy::large_stack_arrays)]
#![warn(clippy::manual_ok_or)]
#![deny(clippy::missing_docs_in_private_items)]
#![warn(clippy::needless_borrow)]
#![warn(clippy::needless_pass_by_value)]
#![warn(clippy::option_option)]
#![deny(clippy::print_stderr)]
#![deny(clippy::print_stdout)]
#![warn(clippy::rc_buffer)]
#![deny(clippy::ref_option_ref)]
#![warn(clippy::semicolon_if_nothing_returned)]
#![warn(clippy::trait_duplication_in_bounds)]
#![deny(clippy::unchecked_time_subtraction)]
#![deny(clippy::unnecessary_wraps)]
#![warn(clippy::unseparated_literal_suffix)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::mod_module_files)]
#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
#![allow(clippy::result_large_err)] // temporary workaround for arti#587
#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
#![allow(clippy::needless_lifetimes)] // See arti#1765
#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->

// TODO probably remove this at some point - see tpo/core/arti#1060
#![cfg_attr(
    not(all(feature = "full", feature = "experimental")),
    allow(unused_imports)
)]

mod err;
pub mod request;
mod response;
mod util;

use tor_circmgr::{CircMgr, DirInfo};
use tor_error::bad_api_usage;
use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};

// Zlib is required; the others are optional.
#[cfg(feature = "xz")]
use async_compression::futures::bufread::XzDecoder;
use async_compression::futures::bufread::ZlibDecoder;
#[cfg(feature = "zstd")]
use async_compression::futures::bufread::ZstdDecoder;

use futures::FutureExt;
use futures::io::{
    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader,
};
use memchr::memchr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, instrument};

pub use err::{Error, RequestError, RequestFailedError};
pub use response::{DirResponse, SourceInfo};

/// Type for results returned in this crate.
pub type Result<T> = std::result::Result<T, Error>;

/// Type for internal results containing a [`RequestError`].
pub type RequestResult<T> = std::result::Result<T, RequestError>;

/// Flag to declare whether a request is anonymized or not.
///
/// Some requests (like those to download onion service descriptors) are always
/// anonymized, and should never be sent in a way that leaks information about
/// our settings or configuration.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AnonymizedRequest {
    /// This request should not leak any information about our configuration.
    Anonymized,
    /// This request is allowed to include information about our capabilities.
    Direct,
}

/// Fetch the resource described by `req` over the Tor network.
///
/// Circuits are built or found using `circ_mgr`, using paths
/// constructed using `dirinfo`.
///
/// For more fine-grained control over the circuit and stream used,
/// construct them yourself, and then call [`send_request`] instead.
///
/// # TODO
///
/// This is the only function in this crate that knows about CircMgr and
/// DirInfo.  Perhaps this function should move up a level into DirMgr?
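///
/// # Example
///
/// A minimal sketch of how a caller might invoke this function.  The
/// `dirinfo`, `runtime`, and `circ_mgr` values are assumed to be set up
/// elsewhere (by whatever component is driving directory downloads), and
/// [`request::ConsensusRequest`] is only one of the available request types.
///
/// ```ignore
/// use tor_dirclient::{get_resource, request::ConsensusRequest};
///
/// // Hypothetical setup: `dirinfo`, `runtime`, and `circ_mgr` come from the caller.
/// let req = ConsensusRequest::default();
/// let response = get_resource(&req, dirinfo, &runtime, circ_mgr).await?;
/// let body = response.into_output_unchecked();
/// // ... parse `body` as a consensus document ...
/// ```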
#[instrument(level = "trace", skip_all)]
pub async fn get_resource<CR, R, SP>(
    req: &CR,
    dirinfo: DirInfo<'_>,
    runtime: &SP,
    circ_mgr: Arc<CircMgr<R>>,
) -> Result<DirResponse>
where
    CR: request::Requestable + ?Sized,
    R: Runtime,
    SP: SleepProvider,
{
    let tunnel = circ_mgr.get_or_launch_dir(dirinfo).await?;

    if req.anonymized() == AnonymizedRequest::Anonymized {
        return Err(bad_api_usage!("Tried to use get_resource for an anonymized request").into());
    }

    // TODO(nickm) This should be an option, and is too long.
    let begin_timeout = Duration::from_secs(5);
    let source = match SourceInfo::from_tunnel(&tunnel) {
        Ok(source) => source,
        Err(e) => {
            return Err(Error::RequestFailed(RequestFailedError {
                source: None,
                error: e.into(),
            }));
        }
    };

    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: source.clone(),
            error,
        })
    };

    req.check_circuit(&tunnel).await.map_err(wrap_err)?;

    // Launch the stream.
    let mut stream = runtime
        .timeout(begin_timeout, tunnel.begin_dir_stream())
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?
        .map_err(RequestError::from)
        .map_err(wrap_err)?; // TODO(nickm) handle fatalities here too

    // TODO: Perhaps we want separate timeouts for each phase of this.
    // For now, we just use higher-level timeouts in `dirmgr`.
    let r = send_request(runtime, req, &mut stream, source.clone()).await;

    if should_retire_circ(&r) {
        retire_circ(&circ_mgr, &tunnel.unique_id(), "Partial response");
    }

    r
}

/// Return true if `result` holds an error indicating that we should retire the
/// circuit used for the corresponding request.
fn should_retire_circ(result: &Result<DirResponse>) -> bool {
    match result {
        Err(e) => e.should_retire_circ(),
        Ok(dr) => dr.error().map(RequestError::should_retire_circ) == Some(true),
    }
}

/// Fetch a Tor directory object from a provided stream.
#[deprecated(since = "0.8.1", note = "Use send_request instead.")]
pub async fn download<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    send_request(runtime, req, stream, source).await
}

/// Fetch or upload a Tor directory object using the provided stream.
///
/// To do this, we send a simple HTTP/1.0 request for the described
/// object in `req` over `stream`, and then wait for a response.  In
/// log messages, we describe the origin of the data as coming from
/// `source`.
///
/// # Notes
///
/// It's kind of bogus to have a 'source' field here at all; we may
/// eventually want to remove it.
///
/// This function doesn't close the stream; you may want to do that
/// yourself.
///
/// The only error variant returned is [`Error::RequestFailed`].
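///
/// # Example
///
/// A minimal sketch, assuming the caller already has a suitable runtime and
/// an open stream (anything implementing `AsyncRead + AsyncWrite + Unpin`):
///
/// ```ignore
/// use tor_dirclient::{send_request, request::ConsensusRequest};
///
/// let req = ConsensusRequest::default();
/// // `runtime` and `stream` are assumed to be provided by the caller.
/// let response = send_request(&runtime, &req, &mut stream, None).await?;
/// if response.status_code() == 200 {
///     let body = response.into_output_unchecked();
///     // ... handle the downloaded document ...
/// }
/// ```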
// TODO: should the error return type change to `RequestFailedError`?
// If so, that would simplify some code in `dirmgr::bridgedesc`.
pub async fn send_request<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: source.clone(),
            error,
        })
    };

    let partial_ok = req.partial_response_body_ok();
    let maxlen = req.max_response_len();
    let anonymized = req.anonymized();
    let req = req.make_request().map_err(wrap_err)?;
    let encoded = util::encode_request(&req);

    // Write the request.
    stream
        .write_all(encoded.as_bytes())
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;
    stream
        .flush()
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;

    let mut buffered = BufReader::new(stream);

    // Handle the response
    // TODO: should there be a separate timeout here?
    let header = read_headers(&mut buffered).await.map_err(wrap_err)?;
    if header.status != Some(200) {
        return Ok(DirResponse::new(
            header.status.unwrap_or(0),
            header.status_message,
            None,
            vec![],
            source,
        ));
    }

    let mut decoder =
        get_decoder(buffered, header.encoding.as_deref(), anonymized).map_err(wrap_err)?;

    let mut result = Vec::new();
    let ok = read_and_decompress(runtime, &mut decoder, maxlen, &mut result).await;

    let ok = match (partial_ok, ok, result.len()) {
        (true, Err(e), n) if n > 0 => {
            // Note that we _don't_ return here: we want the partial response.
            Err(e)
        }
        (_, Err(e), _) => {
            return Err(wrap_err(e));
        }
        (_, Ok(()), _) => Ok(()),
    };

    Ok(DirResponse::new(200, None, ok.err(), result, source))
}

/// Maximum length for the HTTP headers in a single request or response.
///
/// Chosen more or less arbitrarily.
const MAX_HEADERS_LEN: usize = 16384;

/// Read and parse HTTP/1 headers from `stream`.
async fn read_headers<S>(stream: &mut S) -> RequestResult<HeaderStatus>
where
    S: AsyncBufRead + Unpin,
{
    let mut buf = Vec::with_capacity(1024);

    loop {
        // TODO: it's inefficient to do this a line at a time; it would
        // probably be better to read until the CRLF CRLF ending of the
        // response.  But this should be fast enough.
        let n = read_until_limited(stream, b'\n', 2048, &mut buf).await?;

        // TODO(nickm): Better maximum and/or let this expand.
        let mut headers = [httparse::EMPTY_HEADER; 32];
        let mut response = httparse::Response::new(&mut headers);

        match response.parse(&buf[..])? {
            httparse::Status::Partial => {
                // We didn't get a whole response; we may need to try again.

                if n == 0 {
                    // We hit an EOF; no more progress can be made.
                    return Err(RequestError::TruncatedHeaders);
                }

                if buf.len() >= MAX_HEADERS_LEN {
                    return Err(RequestError::HeadersTooLong(buf.len()));
                }
            }
            httparse::Status::Complete(n_parsed) => {
                if response.code != Some(200) {
                    return Ok(HeaderStatus {
                        status: response.code,
                        status_message: response.reason.map(str::to_owned),
                        encoding: None,
                    });
                }
                let encoding = if let Some(enc) = response
                    .headers
                    .iter()
                    .find(|h| h.name == "Content-Encoding")
                {
                    Some(String::from_utf8(enc.value.to_vec())?)
                } else {
                    None
                };
                /*
                if let Some(clen) = response.headers.iter().find(|h| h.name == "Content-Length") {
                    let clen = std::str::from_utf8(clen.value)?;
                    length = Some(clen.parse()?);
                }
                 */
                assert!(n_parsed == buf.len());
                return Ok(HeaderStatus {
                    status: Some(200),
                    status_message: None,
                    encoding,
                });
            }
        }
        if n == 0 {
            return Err(RequestError::TruncatedHeaders);
        }
    }
}

/// Return value from read_headers
#[derive(Debug, Clone)]
struct HeaderStatus {
    /// HTTP status code.
    status: Option<u16>,
    /// HTTP status message associated with the status code.
    status_message: Option<String>,
    /// The Content-Encoding header, if any.
    encoding: Option<String>,
}

/// Helper: download directory information from `stream` and
/// decompress it into a result buffer.  Assumes that `result` is empty.
///
/// If we get more than `maxlen` bytes after decompression, give an error.
///
/// Returns the status of our download attempt, and stores any data that
/// we were able to download into `result`.  Existing contents of
/// `result` are overwritten.
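///
/// # Example
///
/// A small sketch (not a doctest, since this helper is private): byte
/// slices implement `AsyncRead`, and any `SleepProvider` (such as a mock
/// one) can supply the timeout.
///
/// ```ignore
/// let data: &[u8] = b"already-decompressed directory material";
/// let mut out = Vec::new();
/// read_and_decompress(&sleep_provider, data, 10 << 20, &mut out).await?;
/// assert_eq!(&out[..], data);
/// ```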
async fn read_and_decompress<S, SP>(
    runtime: &SP,
    mut stream: S,
    maxlen: usize,
    result: &mut Vec<u8>,
) -> RequestResult<()>
where
    S: AsyncRead + Unpin,
    SP: SleepProvider,
{
    let buffer_window_size = 1024;
    let mut written_total: usize = 0;
    // TODO(nickm): This should be an option, and is maybe too long.
    // Though for some users it may be too short?
    let read_timeout = Duration::from_secs(10);
    let timer = runtime.sleep(read_timeout).fuse();
    futures::pin_mut!(timer);

    loop {
        // allocate buffer for next read
        result.resize(written_total + buffer_window_size, 0);
        let buf: &mut [u8] = &mut result[written_total..written_total + buffer_window_size];

        let status = futures::select! {
            status = stream.read(buf).fuse() => status,
            _ = timer => {
                result.resize(written_total, 0); // truncate as needed
                return Err(RequestError::DirTimeout);
            }
        };
        let written_in_this_loop = match status {
            Ok(n) => n,
            Err(other) => {
                result.resize(written_total, 0); // truncate as needed
                return Err(other.into());
            }
        };

        written_total += written_in_this_loop;

        // exit conditions below

        if written_in_this_loop == 0 {
            /*
            in case we read less than `buffer_window_size` in last `read`
            we need to shrink result because otherwise we'll return those
            un-read 0s
            */
            if written_total < result.len() {
                result.resize(written_total, 0);
            }
            return Ok(());
        }

        // TODO: It would be good to detect compression bombs, but
        // that would require access to the internal stream, which
        // would in turn require some tricky programming.  For now, we
        // use the maximum length here to prevent an attacker from
        // filling our RAM.
        if written_total > maxlen {
            result.resize(maxlen, 0);
            return Err(RequestError::ResponseTooLong(written_total));
        }
    }
}

/// Retire a directory circuit because of an error we've encountered on it.
fn retire_circ<R>(circ_mgr: &Arc<CircMgr<R>>, id: &tor_proto::circuit::UniqId, error: &str)
where
    R: Runtime,
{
    info!(
        "{}: Retiring circuit because of directory failure: {}",
        &id, &error
    );
    circ_mgr.retire_circ(id);
}

/// As AsyncBufReadExt::read_until, but stops after reading `max` bytes.
///
/// Note that this function might not actually read any byte of value
/// `byte`, since EOF might occur, or we might fill the buffer.
///
/// A return value of 0 indicates an end-of-file.
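///
/// # Example
///
/// A small sketch (not a doctest, since this helper is private): byte
/// slices implement `AsyncBufRead`, so an in-memory buffer works as the
/// stream.
///
/// ```ignore
/// let mut out = Vec::new();
/// let mut stream = &b"first line\nsecond line\n"[..];
/// let n = read_until_limited(&mut stream, b'\n', 100, &mut out).await?;
/// assert_eq!(n, 11);
/// assert_eq!(&out[..], b"first line\n");
/// ```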
async fn read_until_limited<S>(
    stream: &mut S,
    byte: u8,
    max: usize,
    buf: &mut Vec<u8>,
) -> std::io::Result<usize>
where
    S: AsyncBufRead + Unpin,
{
    let mut n_added = 0;
    loop {
        let data = stream.fill_buf().await?;
        if data.is_empty() {
            // End-of-file has been reached.
            return Ok(n_added);
        }
        debug_assert!(n_added < max);
        let remaining_space = max - n_added;
        let (available, found_byte) = match memchr(byte, data) {
            Some(idx) => (idx + 1, true),
            None => (data.len(), false),
        };
        debug_assert!(available >= 1);
        let n_to_copy = std::cmp::min(remaining_space, available);
        buf.extend(&data[..n_to_copy]);
        stream.consume_unpin(n_to_copy);
        n_added += n_to_copy;
        if found_byte || n_added == max {
            return Ok(n_added);
        }
    }
}

/// Helper: Return a boxed decoder object that wraps the stream `$s`.
macro_rules! decoder {
    ($dec:ident, $s:expr) => {{
        let mut decoder = $dec::new($s);
        decoder.multiple_members(true);
        Ok(Box::new(decoder))
    }};
}

/// Wrap `stream` in an appropriate type to undo the content encoding
/// as described in `encoding`.
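///
/// For example, a `Content-Encoding` of `deflate` always yields a zlib
/// decoder, while an unrecognized encoding is rejected.  A small sketch
/// (not a doctest, since this helper is private):
///
/// ```ignore
/// let ok = get_decoder(&b""[..], Some("deflate"), AnonymizedRequest::Direct);
/// assert!(ok.is_ok());
/// let err = get_decoder(&b""[..], Some("x-proprietary-rle"), AnonymizedRequest::Direct);
/// assert!(matches!(err, Err(RequestError::ContentEncoding(_))));
/// ```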
fn get_decoder<'a, S: AsyncBufRead + Unpin + Send + 'a>(
    stream: S,
    encoding: Option<&str>,
    anonymized: AnonymizedRequest,
) -> RequestResult<Box<dyn AsyncRead + Unpin + Send + 'a>> {
    use AnonymizedRequest::Direct;
    match (encoding, anonymized) {
        (None | Some("identity"), _) => Ok(Box::new(stream)),
        (Some("deflate"), _) => decoder!(ZlibDecoder, stream),
        // We only admit to supporting these on a direct connection; otherwise,
        // a hostile directory could send them back even though we hadn't
        // requested them.
        #[cfg(feature = "xz")]
        (Some("x-tor-lzma"), Direct) => decoder!(XzDecoder, stream),
        #[cfg(feature = "zstd")]
        (Some("x-zstd"), Direct) => decoder!(ZstdDecoder, stream),
        (Some(other), _) => Err(RequestError::ContentEncoding(other.into())),
    }
}

#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use tor_rtmock::io::stream_pair;

    use tor_rtmock::simple_time::SimpleMockTimeProvider;

    use futures_await_test::async_test;

    #[async_test]
    async fn test_read_until_limited() -> RequestResult<()> {
        let mut out = Vec::new();
        let bytes = b"This line eventually ends\nthen comes another\n";

        // Case 1: find a whole line.
        let mut s = &bytes[..];
        let res = read_until_limited(&mut s, b'\n', 100, &mut out).await;
        assert_eq!(res?, 26);
        assert_eq!(&out[..], b"This line eventually ends\n");

        // Case 2: reach the limit.
        let mut s = &bytes[..];
        out.clear();
        let res = read_until_limited(&mut s, b'\n', 10, &mut out).await;
        assert_eq!(res?, 10);
        assert_eq!(&out[..], b"This line ");

        // Case 3: reach EOF.
        let mut s = &bytes[..];
        out.clear();
        let res = read_until_limited(&mut s, b'Z', 100, &mut out).await;
        assert_eq!(res?, 45);
        assert_eq!(&out[..], &bytes[..]);

        Ok(())
    }

    // Basic decompression wrapper.
    async fn decomp_basic(
        encoding: Option<&str>,
        data: &[u8],
        maxlen: usize,
    ) -> (RequestResult<()>, Vec<u8>) {
        // We don't need to do anything fancy here, since we aren't simulating
        // a timeout.
        #[allow(deprecated)] // TODO #1885
        let mock_time = SimpleMockTimeProvider::from_wallclock(std::time::SystemTime::now());

        let mut output = Vec::new();
        let mut stream = match get_decoder(data, encoding, AnonymizedRequest::Direct) {
            Ok(s) => s,
            Err(e) => return (Err(e), output),
        };

        let r = read_and_decompress(&mock_time, &mut stream, maxlen, &mut output).await;

        (r, output)
    }

    #[async_test]
    async fn decompress_identity() -> RequestResult<()> {
        let mut text = Vec::new();
        for _ in 0..1000 {
            text.extend(b"This is a string with a nontrivial length that we'll use to make sure that the loop is executed more than once.");
        }

        let limit = 10 << 20;
        let (s, r) = decomp_basic(None, &text[..], limit).await;
        s?;
        assert_eq!(r, text);

        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
        s?;
        assert_eq!(r, text);

        // Try truncated result
        let limit = 100;
        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
        assert!(s.is_err());
        assert_eq!(r, &text[..100]);

        Ok(())
    }

    #[async_test]
    async fn decomp_zlib() -> RequestResult<()> {
        let compressed =
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();

        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("deflate"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish");

        Ok(())
    }

    #[cfg(feature = "zstd")]
    #[async_test]
    async fn decomp_zstd() -> RequestResult<()> {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("x-zstd"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }

    #[cfg(feature = "xz")]
    #[async_test]
    async fn decomp_xz2() -> RequestResult<()> {
        // Not so good at tiny files...
        let compressed = hex::decode("fd377a585a000004e6d6b446020021011c00000010cf58cce00024001d5d00279b88a202ca8612cfb3c19c87c34248a570451e4851d3323d34ab8000000000000901af64854c91f600013925d6ec06651fb6f37d010000000004595a").unwrap();
        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("x-tor-lzma"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }

    #[async_test]
    async fn decomp_unknown() {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;
        let (s, _r) = decomp_basic(Some("x-proprietary-rle"), &compressed, limit).await;

        assert!(matches!(s, Err(RequestError::ContentEncoding(_))));
    }

    #[async_test]
    async fn decomp_bad_data() {
        let compressed = b"This is not good zlib data";
        let limit = 10 << 20;
        let (s, _r) = decomp_basic(Some("deflate"), compressed, limit).await;

        // This should possibly be a different type in the future.
        assert!(matches!(s, Err(RequestError::IoError(_))));
    }

    #[async_test]
    async fn headers_ok() -> RequestResult<()> {
        let text = b"HTTP/1.0 200 OK\r\nDate: ignored\r\nContent-Encoding: Waffles\r\n\r\n";

        let mut s = &text[..];
        let h = read_headers(&mut s).await?;

        assert_eq!(h.status, Some(200));
        assert_eq!(h.encoding.as_deref(), Some("Waffles"));

        // now try truncated
        let mut s = &text[..15];
        let h = read_headers(&mut s).await;
        assert!(matches!(h, Err(RequestError::TruncatedHeaders)));

        // now try with no encoding.
        let text = b"HTTP/1.0 404 Not found\r\n\r\n";
        let mut s = &text[..];
        let h = read_headers(&mut s).await?;

        assert_eq!(h.status, Some(404));
        assert!(h.encoding.is_none());

        Ok(())
    }

    #[async_test]
    async fn headers_bogus() -> Result<()> {
        let text = b"HTTP/999.0 WHAT EVEN\r\n\r\n";
        let mut s = &text[..];
        let h = read_headers(&mut s).await;

        assert!(h.is_err());
        assert!(matches!(h, Err(RequestError::HttparseError(_))));
        Ok(())
    }

    /// Run a trivial download example with a response provided as a binary
    /// string.
    ///
    /// Return the directory response (if any) and the request as encoded (if
    /// any).
    fn run_download_test<Req: request::Requestable>(
        req: Req,
        response: &[u8],
    ) -> (Result<DirResponse>, RequestResult<Vec<u8>>) {
        let (mut s1, s2) = stream_pair();
        let (mut s2_r, mut s2_w) = s2.split();

        tor_rtcompat::test_with_one_runtime!(|rt| async move {
            let rt2 = rt.clone();
            let (v1, v2, v3): (
                Result<DirResponse>,
                RequestResult<Vec<u8>>,
                RequestResult<()>,
            ) = futures::join!(
                async {
                    // Run the download function.
                    let r = send_request(&rt, &req, &mut s1, None).await;
                    s1.close().await.map_err(|error| {
                        Error::RequestFailed(RequestFailedError {
                            source: None,
                            error: error.into(),
                        })
                    })?;
                    r
                },
                async {
                    // Take the request from the client, and return it in "v2"
                    let mut v = Vec::new();
                    s2_r.read_to_end(&mut v).await?;
                    Ok(v)
                },
                async {
                    // Send back a response.
                    s2_w.write_all(response).await?;
                    // We wait a moment to give the other side time to notice it
                    // has data.
                    //
                    // (Tentative diagnosis: The `async-compression` crate seems to
                    // behave differently depending on whether the "close"
                    // comes right after the incomplete data or whether it comes
                    // after a delay.  If there's a delay, it notices the
                    // truncated data and tells us about it. But when there's
                    // _no_ delay, it treats the data as an error and doesn't
                    // tell our code.)

                    // TODO: sleeping in tests is not great.
                    rt2.sleep(Duration::from_millis(50)).await;
                    s2_w.close().await?;
                    Ok(())
                }
            );

            assert!(v3.is_ok());

            (v1, v2)
        })
    }

    #[test]
    fn test_send_request() -> RequestResult<()> {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();

        let (response, request) = run_download_test(
            req,
            b"HTTP/1.0 200 OK\r\n\r\nThis is where the descs would go.",
        );

        let request = request?;
        assert!(request[..].starts_with(
            b"GET /tor/micro/d/CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQk HTTP/1.0\r\n"
        ));

        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(!response.is_partial());
        assert!(response.error().is_none());
        assert!(response.source().is_none());
        let out_ref = response.output_unchecked();
        assert_eq!(out_ref, b"This is where the descs would go.");
        let out = response.into_output_unchecked();
        assert_eq!(&out, b"This is where the descs would go.");

        Ok(())
    }

    #[test]
    fn test_download_truncated() {
        // Request only one md, so "partial ok" will not be set.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut response_text: Vec<u8> =
            (*b"HTTP/1.0 200 OK\r\nContent-Encoding: deflate\r\n\r\n").into();
        // "One fish two fish" as above twice, but truncated the second time
        response_text.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap(),
        );
        response_text.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5").unwrap(),
        );
        let (response, request) = run_download_test(req, &response_text);
        assert!(request.is_ok());
        assert!(response.is_err()); // The whole download should fail, since partial_ok wasn't set.

        // request two microdescs, so "partial_ok" will be set.
        let req: request::MicrodescRequest = vec![[9; 32]; 2].into_iter().collect();

        let (response, request) = run_download_test(req, &response_text);
        assert!(request.is_ok());

        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(response.error().is_some());
        assert!(response.is_partial());
        assert!(response.output_unchecked().len() < 37 * 2);
        assert!(response.output_unchecked().starts_with(b"One fish"));
    }

    #[test]
    fn test_404() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"HTTP/1.0 418 I'm a teapot\r\n\r\n";
        let (response, _request) = run_download_test(req, response_text);

        assert_eq!(response.unwrap().status_code(), 418);
    }

    #[test]
    fn test_headers_truncated() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"HTTP/1.0 404 truncation happens here\r\n";
        let (response, _request) = run_download_test(req, response_text);

        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));

        // Try a completely empty response.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"";
        let (response, _request) = run_download_test(req, response_text);

        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));
    }

    #[test]
    fn test_headers_too_long() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut response_text: Vec<u8> = (*b"HTTP/1.0 418 I'm a teapot\r\nX-Too-Many-As: ").into();
        response_text.resize(16384, b'A');
        let (response, _request) = run_download_test(req, &response_text);

        assert!(response.as_ref().unwrap_err().should_retire_circ());
        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::HeadersTooLong(_),
                ..
            }))
        ));
    }

    // TODO: test with bad utf-8
}