tor_dirclient/lib.rs

#![cfg_attr(docsrs, feature(doc_cfg))]
#![doc = include_str!("../README.md")]
// @@ begin lint list maintained by maint/add_warning @@
#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
#![warn(missing_docs)]
#![warn(noop_method_call)]
#![warn(unreachable_pub)]
#![warn(clippy::all)]
#![deny(clippy::await_holding_lock)]
#![deny(clippy::cargo_common_metadata)]
#![deny(clippy::cast_lossless)]
#![deny(clippy::checked_conversions)]
#![warn(clippy::cognitive_complexity)]
#![deny(clippy::debug_assert_with_mut_call)]
#![deny(clippy::exhaustive_enums)]
#![deny(clippy::exhaustive_structs)]
#![deny(clippy::expl_impl_clone_on_copy)]
#![deny(clippy::fallible_impl_from)]
#![deny(clippy::implicit_clone)]
#![deny(clippy::large_stack_arrays)]
#![warn(clippy::manual_ok_or)]
#![deny(clippy::missing_docs_in_private_items)]
#![warn(clippy::needless_borrow)]
#![warn(clippy::needless_pass_by_value)]
#![warn(clippy::option_option)]
#![deny(clippy::print_stderr)]
#![deny(clippy::print_stdout)]
#![warn(clippy::rc_buffer)]
#![deny(clippy::ref_option_ref)]
#![warn(clippy::semicolon_if_nothing_returned)]
#![warn(clippy::trait_duplication_in_bounds)]
#![deny(clippy::unchecked_time_subtraction)]
#![deny(clippy::unnecessary_wraps)]
#![warn(clippy::unseparated_literal_suffix)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::mod_module_files)]
#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
#![allow(clippy::result_large_err)] // temporary workaround for arti#587
#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
#![allow(clippy::needless_lifetimes)] // See arti#1765
#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
#![deny(clippy::unused_async)]
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->

// TODO probably remove this at some point - see tpo/core/arti#1060
#![cfg_attr(
    not(all(feature = "full", feature = "experimental")),
    allow(unused_imports)
)]

mod err;
pub mod request;
mod response;
mod util;

use tor_circmgr::{CircMgr, DirInfo};
use tor_error::bad_api_usage;
use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};

// Zlib is required; the others are optional.
#[cfg(feature = "xz")]
use async_compression::futures::bufread::XzDecoder;
use async_compression::futures::bufread::ZlibDecoder;
#[cfg(feature = "zstd")]
use async_compression::futures::bufread::ZstdDecoder;

use futures::FutureExt;
use futures::io::{
    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader,
};
use memchr::memchr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, instrument};

pub use err::{Error, RequestError, RequestFailedError};
pub use response::{DirResponse, SourceInfo};

/// Type for results returned in this crate.
pub type Result<T> = std::result::Result<T, Error>;

/// Type for internal results containing a RequestError.
pub type RequestResult<T> = std::result::Result<T, RequestError>;

/// Flag to declare whether a request is anonymized or not.
///
/// Some requests (like those to download onion service descriptors) are always
/// anonymized, and should never be sent in a way that leaks information about
/// our settings or configuration.
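///
/// # Example
///
/// A minimal, non-compiled sketch of how this flag is consulted, mirroring the
/// check in [`get_resource`] (which refuses to handle anonymized requests);
/// `req` stands for any [`request::Requestable`] value supplied by the caller:
///
/// ```ignore
/// if req.anonymized() == AnonymizedRequest::Anonymized {
///     // This code path must not leak our settings or capabilities;
///     // refuse, or take an anonymized code path instead.
/// }
/// ```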
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AnonymizedRequest {
    /// This request should not leak any information about our configuration.
    Anonymized,
    /// This request is allowed to include information about our capabilities.
    Direct,
}

/// Fetch the resource described by `req` over the Tor network.
///
/// Circuits are built or found using `circ_mgr`, using paths
/// constructed using `dirinfo`.
///
/// For more fine-grained control over the circuit and stream used,
/// construct them yourself, and then call [`send_request`] instead.
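///
/// # Example
///
/// A minimal, non-compiled sketch of a direct (non-anonymized) fetch.  The
/// `digest`, `dirinfo`, `runtime`, and `circ_mgr` values are assumed to be
/// supplied by the caller (typically a directory manager):
///
/// ```ignore
/// use tor_dirclient::{get_resource, request::MicrodescRequest};
///
/// // Build a request for a single microdescriptor by digest.
/// let req: MicrodescRequest = [digest].into_iter().collect();
///
/// let response = get_resource(&req, dirinfo, &runtime, circ_mgr).await?;
/// let body = response.into_output_unchecked();
/// ```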
///
/// # TODO
///
/// This is the only function in this crate that knows about CircMgr and
/// DirInfo.  Perhaps this function should move up a level into DirMgr?
#[instrument(level = "trace", skip_all)]
pub async fn get_resource<CR, R, SP>(
    req: &CR,
    dirinfo: DirInfo<'_>,
    runtime: &SP,
    circ_mgr: Arc<CircMgr<R>>,
) -> Result<DirResponse>
where
    CR: request::Requestable + ?Sized,
    R: Runtime,
    SP: SleepProvider,
{
    let tunnel = circ_mgr.get_or_launch_dir(dirinfo).await?;

    if req.anonymized() == AnonymizedRequest::Anonymized {
        return Err(bad_api_usage!("Tried to use get_resource for an anonymized request").into());
    }

    // TODO(nickm) This should be an option, and is too long.
    let begin_timeout = Duration::from_secs(5);
    let source = match SourceInfo::from_tunnel(&tunnel) {
        Ok(source) => source,
        Err(e) => {
            return Err(Error::RequestFailed(RequestFailedError {
                source: None,
                error: e.into(),
            }));
        }
    };

    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: source.clone(),
            error,
        })
    };

    req.check_circuit(&tunnel).await.map_err(wrap_err)?;

    // Launch the stream.
    let mut stream = runtime
        .timeout(begin_timeout, tunnel.begin_dir_stream())
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?
        .map_err(RequestError::from)
        .map_err(wrap_err)?; // TODO(nickm) handle fatalities here too

    // TODO: Perhaps we want separate timeouts for each phase of this.
    // For now, we just use higher-level timeouts in `dirmgr`.
    let r = send_request(runtime, req, &mut stream, source.clone()).await;

    if should_retire_circ(&r) {
        retire_circ(&circ_mgr, &tunnel.unique_id(), "Partial response");
    }

    r
}

/// Return true if `result` holds an error indicating that we should retire the
/// circuit used for the corresponding request.
fn should_retire_circ(result: &Result<DirResponse>) -> bool {
    match result {
        Err(e) => e.should_retire_circ(),
        Ok(dr) => dr.error().map(RequestError::should_retire_circ) == Some(true),
    }
}

/// Fetch a Tor directory object from a provided stream.
#[deprecated(since = "0.8.1", note = "Use send_request instead.")]
pub async fn download<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    send_request(runtime, req, stream, source).await
}

/// Fetch or upload a Tor directory object using the provided stream.
///
/// To do this, we send a simple HTTP/1.0 request for the described
/// object in `req` over `stream`, and then wait for a response.  In
/// log messages, we describe the origin of the data as coming from
/// `source`.
///
/// # Notes
///
/// It's kind of bogus to have a 'source' field here at all; we may
/// eventually want to remove it.
///
/// This function doesn't close the stream; you may want to do that
/// yourself.
///
/// The only error variant returned is [`Error::RequestFailed`].
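///
/// # Example
///
/// A minimal, non-compiled sketch based on this crate's own tests, assuming
/// the caller already has an open directory `stream` (anything that is
/// `AsyncRead + AsyncWrite`), a `runtime`, and a microdescriptor `digest`:
///
/// ```ignore
/// use tor_dirclient::{send_request, request::MicrodescRequest};
///
/// let req: MicrodescRequest = [digest].into_iter().collect();
/// let response = send_request(&runtime, &req, &mut stream, None).await?;
/// assert_eq!(response.status_code(), 200);
/// ```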
// TODO: should the error return type change to `RequestFailedError`?
// If so, that would simplify some code in `dirmgr::bridgedesc`.
pub async fn send_request<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: source.clone(),
            error,
        })
    };

    let partial_ok = req.partial_response_body_ok();
    let maxlen = req.max_response_len();
    let anonymized = req.anonymized();
    let req = req.make_request().map_err(wrap_err)?;
    let encoded = util::encode_request(&req);

    // Write the request.
    stream
        .write_all(encoded.as_bytes())
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;
    stream
        .flush()
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;

    let mut buffered = BufReader::new(stream);

    // Handle the response
    // TODO: should there be a separate timeout here?
    let header = read_headers(&mut buffered).await.map_err(wrap_err)?;
    if header.status != Some(200) {
        return Ok(DirResponse::new(
            header.status.unwrap_or(0),
            header.status_message,
            None,
            vec![],
            source,
        ));
    }

    let mut decoder =
        get_decoder(buffered, header.encoding.as_deref(), anonymized).map_err(wrap_err)?;

    let mut result = Vec::new();
    let ok = read_and_decompress(runtime, &mut decoder, maxlen, &mut result).await;

    let ok = match (partial_ok, ok, result.len()) {
        (true, Err(e), n) if n > 0 => {
            // Note that we _don't_ return here: we want the partial response.
            Err(e)
        }
        (_, Err(e), _) => {
            return Err(wrap_err(e));
        }
        (_, Ok(()), _) => Ok(()),
    };

    Ok(DirResponse::new(200, None, ok.err(), result, source))
}

/// Maximum length for the HTTP headers in a single request or response.
///
/// Chosen more or less arbitrarily.
const MAX_HEADERS_LEN: usize = 16384;

/// Read and parse HTTP/1 headers from `stream`.
async fn read_headers<S>(stream: &mut S) -> RequestResult<HeaderStatus>
where
    S: AsyncBufRead + Unpin,
{
    let mut buf = Vec::with_capacity(1024);

    loop {
        // TODO: it's inefficient to do this a line at a time; it would
        // probably be better to read until the CRLF CRLF ending of the
        // response.  But this should be fast enough.
        let n = read_until_limited(stream, b'\n', 2048, &mut buf).await?;

        // TODO(nickm): Better maximum and/or let this expand.
        let mut headers = [httparse::EMPTY_HEADER; 32];
        let mut response = httparse::Response::new(&mut headers);

        match response.parse(&buf[..])? {
            httparse::Status::Partial => {
                // We didn't get a whole response; we may need to try again.

                if n == 0 {
                    // We hit an EOF; no more progress can be made.
                    return Err(RequestError::TruncatedHeaders);
                }

                if buf.len() >= MAX_HEADERS_LEN {
                    return Err(RequestError::HeadersTooLong(buf.len()));
                }
            }
            httparse::Status::Complete(n_parsed) => {
                if response.code != Some(200) {
                    return Ok(HeaderStatus {
                        status: response.code,
                        status_message: response.reason.map(str::to_owned),
                        encoding: None,
                    });
                }
                let encoding = if let Some(enc) = response
                    .headers
                    .iter()
                    .find(|h| h.name == "Content-Encoding")
                {
                    Some(String::from_utf8(enc.value.to_vec())?)
                } else {
                    None
                };
                /*
                if let Some(clen) = response.headers.iter().find(|h| h.name == "Content-Length") {
                    let clen = std::str::from_utf8(clen.value)?;
                    length = Some(clen.parse()?);
                }
                 */
                assert!(n_parsed == buf.len());
                return Ok(HeaderStatus {
                    status: Some(200),
                    status_message: None,
                    encoding,
                });
            }
        }
        if n == 0 {
            return Err(RequestError::TruncatedHeaders);
        }
    }
}

/// Return value from read_headers
#[derive(Debug, Clone)]
struct HeaderStatus {
    /// HTTP status code.
    status: Option<u16>,
    /// HTTP status message associated with the status code.
    status_message: Option<String>,
    /// The Content-Encoding header, if any.
    encoding: Option<String>,
}

/// Helper: download directory information from `stream` and
/// decompress it into a result buffer.  Assumes that `result` is empty.
///
/// If we get more than `maxlen` bytes after decompression, give an error.
///
/// Returns the status of our download attempt, and stores any data that
/// we were able to download into `result`.  Existing contents of
/// `result` are overwritten.
async fn read_and_decompress<S, SP>(
    runtime: &SP,
    mut stream: S,
    maxlen: usize,
    result: &mut Vec<u8>,
) -> RequestResult<()>
where
    S: AsyncRead + Unpin,
    SP: SleepProvider,
{
    let buffer_window_size = 1024;
    let mut written_total: usize = 0;
    // TODO(nickm): This should be an option, and is maybe too long.
    // Though for some users it may be too short?
    let read_timeout = Duration::from_secs(10);
    let timer = runtime.sleep(read_timeout).fuse();
    futures::pin_mut!(timer);

    loop {
        // allocate buffer for next read
        result.resize(written_total + buffer_window_size, 0);
        let buf: &mut [u8] = &mut result[written_total..written_total + buffer_window_size];

        let status = futures::select! {
            status = stream.read(buf).fuse() => status,
            _ = timer => {
                result.resize(written_total, 0); // truncate as needed
                return Err(RequestError::DirTimeout);
            }
        };
        let written_in_this_loop = match status {
            Ok(n) => n,
            Err(other) => {
                result.resize(written_total, 0); // truncate as needed
                return Err(other.into());
            }
        };

        written_total += written_in_this_loop;

        // exit conditions below

        if written_in_this_loop == 0 {
            /*
            in case we read less than `buffer_window_size` in last `read`
            we need to shrink result because otherwise we'll return those
            un-read 0s
            */
            if written_total < result.len() {
                result.resize(written_total, 0);
            }
            return Ok(());
        }

        // TODO: It would be good to detect compression bombs, but
        // that would require access to the internal stream, which
        // would in turn require some tricky programming.  For now, we
        // use the maximum length here to prevent an attacker from
        // filling our RAM.
        if written_total > maxlen {
            result.resize(maxlen, 0);
            return Err(RequestError::ResponseTooLong(written_total));
        }
    }
}

/// Retire a directory circuit because of an error we've encountered on it.
fn retire_circ<R>(circ_mgr: &Arc<CircMgr<R>>, id: &tor_proto::circuit::UniqId, error: &str)
where
    R: Runtime,
{
    info!(
        "{}: Retiring circuit because of directory failure: {}",
        &id, &error
    );
    circ_mgr.retire_circ(id);
}

/// As AsyncBufReadExt::read_until, but stops after reading `max` bytes.
///
/// Note that this function might not actually read any byte of value
/// `byte`, since EOF might occur, or we might fill the buffer.
///
/// A return value of 0 indicates an end-of-file.
async fn read_until_limited<S>(
    stream: &mut S,
    byte: u8,
    max: usize,
    buf: &mut Vec<u8>,
) -> std::io::Result<usize>
where
    S: AsyncBufRead + Unpin,
{
    let mut n_added = 0;
    loop {
        let data = stream.fill_buf().await?;
        if data.is_empty() {
            // End-of-file has been reached.
            return Ok(n_added);
        }
        debug_assert!(n_added < max);
        let remaining_space = max - n_added;
        let (available, found_byte) = match memchr(byte, data) {
            Some(idx) => (idx + 1, true),
            None => (data.len(), false),
        };
        debug_assert!(available >= 1);
        let n_to_copy = std::cmp::min(remaining_space, available);
        buf.extend(&data[..n_to_copy]);
        stream.consume_unpin(n_to_copy);
        n_added += n_to_copy;
        if found_byte || n_added == max {
            return Ok(n_added);
        }
    }
}

/// Helper: Return a boxed decoder object that wraps the stream `$s`.
macro_rules! decoder {
    ($dec:ident, $s:expr) => {{
        let mut decoder = $dec::new($s);
        decoder.multiple_members(true);
        Ok(Box::new(decoder))
    }};
}

/// Wrap `stream` in an appropriate type to undo the content encoding
/// as described in `encoding`.
fn get_decoder<'a, S: AsyncBufRead + Unpin + Send + 'a>(
    stream: S,
    encoding: Option<&str>,
    anonymized: AnonymizedRequest,
) -> RequestResult<Box<dyn AsyncRead + Unpin + Send + 'a>> {
    use AnonymizedRequest::Direct;
    match (encoding, anonymized) {
        (None | Some("identity"), _) => Ok(Box::new(stream)),
        (Some("deflate"), _) => decoder!(ZlibDecoder, stream),
        // We only admit to supporting these on a direct connection; otherwise,
        // a hostile directory could send them back even though we hadn't
        // requested them.
        #[cfg(feature = "xz")]
        (Some("x-tor-lzma"), Direct) => decoder!(XzDecoder, stream),
        #[cfg(feature = "zstd")]
        (Some("x-zstd"), Direct) => decoder!(ZstdDecoder, stream),
        (Some(other), _) => Err(RequestError::ContentEncoding(other.into())),
    }
}

#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use tor_rtmock::io::stream_pair;

    use tor_rtmock::simple_time::SimpleMockTimeProvider;

    use futures_await_test::async_test;

    #[async_test]
    async fn test_read_until_limited() -> RequestResult<()> {
        let mut out = Vec::new();
        let bytes = b"This line eventually ends\nthen comes another\n";

        // Case 1: find a whole line.
        let mut s = &bytes[..];
        let res = read_until_limited(&mut s, b'\n', 100, &mut out).await;
        assert_eq!(res?, 26);
        assert_eq!(&out[..], b"This line eventually ends\n");

        // Case 2: reach the limit.
        let mut s = &bytes[..];
        out.clear();
        let res = read_until_limited(&mut s, b'\n', 10, &mut out).await;
        assert_eq!(res?, 10);
        assert_eq!(&out[..], b"This line ");

        // Case 3: reach EOF.
        let mut s = &bytes[..];
        out.clear();
        let res = read_until_limited(&mut s, b'Z', 100, &mut out).await;
        assert_eq!(res?, 45);
        assert_eq!(&out[..], &bytes[..]);

        Ok(())
    }

    // Basic decompression wrapper.
    async fn decomp_basic(
        encoding: Option<&str>,
        data: &[u8],
        maxlen: usize,
    ) -> (RequestResult<()>, Vec<u8>) {
        // We don't need to do anything fancy here, since we aren't simulating
        // a timeout.
        #[allow(deprecated)] // TODO #1885
        let mock_time = SimpleMockTimeProvider::from_wallclock(std::time::SystemTime::now());

        let mut output = Vec::new();
        let mut stream = match get_decoder(data, encoding, AnonymizedRequest::Direct) {
            Ok(s) => s,
            Err(e) => return (Err(e), output),
        };

        let r = read_and_decompress(&mock_time, &mut stream, maxlen, &mut output).await;

        (r, output)
    }

    #[async_test]
    async fn decompress_identity() -> RequestResult<()> {
        let mut text = Vec::new();
        for _ in 0..1000 {
            text.extend(b"This is a string with a nontrivial length that we'll use to make sure that the loop is executed more than once.");
        }

        let limit = 10 << 20;
        let (s, r) = decomp_basic(None, &text[..], limit).await;
        s?;
        assert_eq!(r, text);

        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
        s?;
        assert_eq!(r, text);

        // Try truncated result
        let limit = 100;
        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
        assert!(s.is_err());
        assert_eq!(r, &text[..100]);

        Ok(())
    }

    #[async_test]
    async fn decomp_zlib() -> RequestResult<()> {
        let compressed =
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();

        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("deflate"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish");

        Ok(())
    }

    #[cfg(feature = "zstd")]
    #[async_test]
    async fn decomp_zstd() -> RequestResult<()> {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("x-zstd"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }

    #[cfg(feature = "xz")]
    #[async_test]
    async fn decomp_xz2() -> RequestResult<()> {
        // Not so good at tiny files...
        let compressed = hex::decode("fd377a585a000004e6d6b446020021011c00000010cf58cce00024001d5d00279b88a202ca8612cfb3c19c87c34248a570451e4851d3323d34ab8000000000000901af64854c91f600013925d6ec06651fb6f37d010000000004595a").unwrap();
        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("x-tor-lzma"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }

    #[async_test]
    async fn decomp_unknown() {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;
        let (s, _r) = decomp_basic(Some("x-proprietary-rle"), &compressed, limit).await;

        assert!(matches!(s, Err(RequestError::ContentEncoding(_))));
    }

    #[async_test]
    async fn decomp_bad_data() {
        let compressed = b"This is not good zlib data";
        let limit = 10 << 20;
        let (s, _r) = decomp_basic(Some("deflate"), compressed, limit).await;

        // This should possibly be a different type in the future.
        assert!(matches!(s, Err(RequestError::IoError(_))));
    }

    #[async_test]
    async fn headers_ok() -> RequestResult<()> {
        let text = b"HTTP/1.0 200 OK\r\nDate: ignored\r\nContent-Encoding: Waffles\r\n\r\n";

        let mut s = &text[..];
        let h = read_headers(&mut s).await?;

        assert_eq!(h.status, Some(200));
        assert_eq!(h.encoding.as_deref(), Some("Waffles"));

        // now try truncated
        let mut s = &text[..15];
        let h = read_headers(&mut s).await;
        assert!(matches!(h, Err(RequestError::TruncatedHeaders)));

        // now try with no encoding.
        let text = b"HTTP/1.0 404 Not found\r\n\r\n";
        let mut s = &text[..];
        let h = read_headers(&mut s).await?;

        assert_eq!(h.status, Some(404));
        assert!(h.encoding.is_none());

        Ok(())
    }

    #[async_test]
    async fn headers_bogus() -> Result<()> {
        let text = b"HTTP/999.0 WHAT EVEN\r\n\r\n";
        let mut s = &text[..];
        let h = read_headers(&mut s).await;

        assert!(h.is_err());
        assert!(matches!(h, Err(RequestError::HttparseError(_))));
        Ok(())
    }

    /// Run a trivial download example with a response provided as a binary
    /// string.
    ///
    /// Return the directory response (if any) and the request as encoded (if
    /// any).
    fn run_download_test<Req: request::Requestable>(
        req: Req,
        response: &[u8],
    ) -> (Result<DirResponse>, RequestResult<Vec<u8>>) {
        let (mut s1, s2) = stream_pair();
        let (mut s2_r, mut s2_w) = s2.split();

        tor_rtcompat::test_with_one_runtime!(|rt| async move {
            let rt2 = rt.clone();
            let (v1, v2, v3): (
                Result<DirResponse>,
                RequestResult<Vec<u8>>,
                RequestResult<()>,
            ) = futures::join!(
                async {
                    // Run the download function.
                    let r = send_request(&rt, &req, &mut s1, None).await;
                    s1.close().await.map_err(|error| {
                        Error::RequestFailed(RequestFailedError {
                            source: None,
                            error: error.into(),
                        })
                    })?;
                    r
                },
                async {
                    // Take the request from the client, and return it in "v2"
                    let mut v = Vec::new();
                    s2_r.read_to_end(&mut v).await?;
                    Ok(v)
                },
                async {
                    // Send back a response.
                    s2_w.write_all(response).await?;
                    // We wait a moment to give the other side time to notice it
                    // has data.
                    //
                    // (Tentative diagnosis: The `async-compression` crate seems to
                    // behave differently depending on whether the "close"
                    // comes right after the incomplete data or whether it comes
                    // after a delay.  If there's a delay, it notices the
                    // truncated data and tells us about it. But when there's
                    // _no_ delay, it treats the data as an error and doesn't
                    // tell our code.)

                    // TODO: sleeping in tests is not great.
                    rt2.sleep(Duration::from_millis(50)).await;
                    s2_w.close().await?;
                    Ok(())
                }
            );

            assert!(v3.is_ok());

            (v1, v2)
        })
    }

    #[test]
    fn test_send_request() -> RequestResult<()> {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();

        let (response, request) = run_download_test(
            req,
            b"HTTP/1.0 200 OK\r\n\r\nThis is where the descs would go.",
        );

        let request = request?;
        assert!(request[..].starts_with(
            b"GET /tor/micro/d/CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQk HTTP/1.0\r\n"
        ));

        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(!response.is_partial());
        assert!(response.error().is_none());
        assert!(response.source().is_none());
        let out_ref = response.output_unchecked();
        assert_eq!(out_ref, b"This is where the descs would go.");
        let out = response.into_output_unchecked();
        assert_eq!(&out, b"This is where the descs would go.");

        Ok(())
    }

    #[test]
    fn test_download_truncated() {
        // Request only one md, so "partial ok" will not be set.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut response_text: Vec<u8> =
            (*b"HTTP/1.0 200 OK\r\nContent-Encoding: deflate\r\n\r\n").into();
        // "One fish two fish" as above twice, but truncated the second time
        response_text.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap(),
        );
        response_text.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5").unwrap(),
        );
        let (response, request) = run_download_test(req, &response_text);
        assert!(request.is_ok());
        assert!(response.is_err()); // The whole download should fail, since partial_ok wasn't set.

        // request two microdescs, so "partial_ok" will be set.
        let req: request::MicrodescRequest = vec![[9; 32]; 2].into_iter().collect();

        let (response, request) = run_download_test(req, &response_text);
        assert!(request.is_ok());

        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(response.error().is_some());
        assert!(response.is_partial());
        assert!(response.output_unchecked().len() < 37 * 2);
        assert!(response.output_unchecked().starts_with(b"One fish"));
    }

    #[test]
    fn test_404() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"HTTP/1.0 418 I'm a teapot\r\n\r\n";
        let (response, _request) = run_download_test(req, response_text);

        assert_eq!(response.unwrap().status_code(), 418);
    }

    #[test]
    fn test_headers_truncated() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"HTTP/1.0 404 truncation happens here\r\n";
        let (response, _request) = run_download_test(req, response_text);

        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));

        // Try a completely empty response.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"";
        let (response, _request) = run_download_test(req, response_text);

        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));
    }

    #[test]
    fn test_headers_too_long() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut response_text: Vec<u8> = (*b"HTTP/1.0 418 I'm a teapot\r\nX-Too-Many-As: ").into();
        response_text.resize(16384, b'A');
        let (response, _request) = run_download_test(req, &response_text);

        assert!(response.as_ref().unwrap_err().should_retire_circ());
        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::HeadersTooLong(_),
                ..
            }))
        ));
    }

    // TODO: test with bad utf-8
}