Skip to main content

tor_dirclient/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3// @@ begin lint list maintained by maint/add_warning @@
4#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
5#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
6#![warn(missing_docs)]
7#![warn(noop_method_call)]
8#![warn(unreachable_pub)]
9#![warn(clippy::all)]
10#![deny(clippy::await_holding_lock)]
11#![deny(clippy::cargo_common_metadata)]
12#![deny(clippy::cast_lossless)]
13#![deny(clippy::checked_conversions)]
14#![warn(clippy::cognitive_complexity)]
15#![deny(clippy::debug_assert_with_mut_call)]
16#![deny(clippy::exhaustive_enums)]
17#![deny(clippy::exhaustive_structs)]
18#![deny(clippy::expl_impl_clone_on_copy)]
19#![deny(clippy::fallible_impl_from)]
20#![deny(clippy::implicit_clone)]
21#![deny(clippy::large_stack_arrays)]
22#![warn(clippy::manual_ok_or)]
23#![deny(clippy::missing_docs_in_private_items)]
24#![warn(clippy::needless_borrow)]
25#![warn(clippy::needless_pass_by_value)]
26#![warn(clippy::option_option)]
27#![deny(clippy::print_stderr)]
28#![deny(clippy::print_stdout)]
29#![warn(clippy::rc_buffer)]
30#![deny(clippy::ref_option_ref)]
31#![warn(clippy::semicolon_if_nothing_returned)]
32#![warn(clippy::trait_duplication_in_bounds)]
33#![deny(clippy::unchecked_time_subtraction)]
34#![deny(clippy::unnecessary_wraps)]
35#![warn(clippy::unseparated_literal_suffix)]
36#![deny(clippy::unwrap_used)]
37#![deny(clippy::mod_module_files)]
38#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
39#![allow(clippy::uninlined_format_args)]
40#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
41#![allow(clippy::result_large_err)] // temporary workaround for arti#587
42#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
43#![allow(clippy::needless_lifetimes)] // See arti#1765
44#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
45#![allow(clippy::collapsible_if)] // See arti#2342
46#![deny(clippy::unused_async)]
47//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
48
49// TODO probably remove this at some point - see tpo/core/arti#1060
50#![cfg_attr(
51    not(all(feature = "full", feature = "experimental")),
52    allow(unused_imports)
53)]
54
55mod err;
56pub mod request;
57mod response;
58mod util;
59
60use tor_circmgr::{CircMgr, DirInfo};
61use tor_error::bad_api_usage;
62use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};
63
64// Zlib is required; the others are optional.
65#[cfg(feature = "xz")]
66use async_compression::futures::bufread::XzDecoder;
67use async_compression::futures::bufread::ZlibDecoder;
68#[cfg(feature = "zstd")]
69use async_compression::futures::bufread::ZstdDecoder;
70
71use futures::FutureExt;
72use futures::io::{
73    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader,
74};
75use memchr::memchr;
76use std::sync::Arc;
77use std::time::Duration;
78use tracing::{info, instrument};
79
80pub use err::{Error, RequestError, RequestFailedError};
81pub use response::{DirResponse, SourceInfo};
82
83/// Type for results returned in this crate.
84pub type Result<T> = std::result::Result<T, Error>;
85
/// Type for internal results containing a [`RequestError`].
87pub type RequestResult<T> = std::result::Result<T, RequestError>;
88
/// Whether a directory request has to be anonymized.
///
/// Some requests (for example, onion service descriptor downloads) are
/// always anonymized: they must never be sent in a way that leaks
/// information about our settings or configuration.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum AnonymizedRequest {
    /// The request must not leak any information about our configuration.
    Anonymized,
    /// The request may include information about our capabilities.
    Direct,
}
102
103/// Fetch the resource described by `req` over the Tor network.
104///
105/// Circuits are built or found using `circ_mgr`, using paths
106/// constructed using `dirinfo`.
107///
108/// For more fine-grained control over the circuit and stream used,
109/// construct them yourself, and then call [`send_request`] instead.
110///
111/// # TODO
112///
113/// This is the only function in this crate that knows about CircMgr and
114/// DirInfo.  Perhaps this function should move up a level into DirMgr?
115#[instrument(level = "trace", skip_all)]
116pub async fn get_resource<CR, R, SP>(
117    req: &CR,
118    dirinfo: DirInfo<'_>,
119    runtime: &SP,
120    circ_mgr: Arc<CircMgr<R>>,
121) -> Result<DirResponse>
122where
123    CR: request::Requestable + ?Sized,
124    R: Runtime,
125    SP: SleepProvider,
126{
127    let tunnel = circ_mgr.get_or_launch_dir(dirinfo).await?;
128
129    if req.anonymized() == AnonymizedRequest::Anonymized {
130        return Err(bad_api_usage!("Tried to use get_resource for an anonymized request").into());
131    }
132
133    // TODO(nickm) This should be an option, and is too long.
134    let begin_timeout = Duration::from_secs(5);
135    let source = match SourceInfo::from_tunnel(&tunnel) {
136        Ok(source) => source,
137        Err(e) => {
138            return Err(Error::RequestFailed(RequestFailedError {
139                source: None,
140                error: e.into(),
141            }));
142        }
143    };
144
145    let wrap_err = |error| {
146        Error::RequestFailed(RequestFailedError {
147            source: source.clone(),
148            error,
149        })
150    };
151
152    req.check_circuit(&tunnel).await.map_err(wrap_err)?;
153
154    // Launch the stream.
155    let mut stream = runtime
156        .timeout(begin_timeout, tunnel.begin_dir_stream())
157        .await
158        .map_err(RequestError::from)
159        .map_err(wrap_err)?
160        .map_err(RequestError::from)
161        .map_err(wrap_err)?; // TODO(nickm) handle fatalities here too
162
163    // TODO: Perhaps we want separate timeouts for each phase of this.
164    // For now, we just use higher-level timeouts in `dirmgr`.
165    let r = send_request(runtime, req, &mut stream, source.clone()).await;
166
167    if should_retire_circ(&r) {
168        retire_circ(&circ_mgr, &tunnel.unique_id(), "Partial response");
169    }
170
171    r
172}
173
174/// Return true if `result` holds an error indicating that we should retire the
175/// circuit used for the corresponding request.
176fn should_retire_circ(result: &Result<DirResponse>) -> bool {
177    match result {
178        Err(e) => e.should_retire_circ(),
179        Ok(dr) => dr.error().map(RequestError::should_retire_circ) == Some(true),
180    }
181}
182
183/// Fetch a Tor directory object from a provided stream.
184#[deprecated(since = "0.8.1", note = "Use send_request instead.")]
185pub async fn download<R, S, SP>(
186    runtime: &SP,
187    req: &R,
188    stream: &mut S,
189    source: Option<SourceInfo>,
190) -> Result<DirResponse>
191where
192    R: request::Requestable + ?Sized,
193    S: AsyncRead + AsyncWrite + Send + Unpin,
194    SP: SleepProvider,
195{
196    send_request(runtime, req, stream, source).await
197}
198
199/// Fetch or upload a Tor directory object using the provided stream.
200///
201/// To do this, we send a simple HTTP/1.0 request for the described
202/// object in `req` over `stream`, and then wait for a response.  In
203/// log messages, we describe the origin of the data as coming from
204/// `source`.
205///
206/// # Notes
207///
208/// It's kind of bogus to have a 'source' field here at all; we may
209/// eventually want to remove it.
210///
211/// This function doesn't close the stream; you may want to do that
212/// yourself.
213///
214/// The only error variant returned is [`Error::RequestFailed`].
215// TODO: should the error return type change to `RequestFailedError`?
216// If so, that would simplify some code in_dirmgr::bridgedesc.
217pub async fn send_request<R, S, SP>(
218    runtime: &SP,
219    req: &R,
220    stream: &mut S,
221    source: Option<SourceInfo>,
222) -> Result<DirResponse>
223where
224    R: request::Requestable + ?Sized,
225    S: AsyncRead + AsyncWrite + Send + Unpin,
226    SP: SleepProvider,
227{
228    let wrap_err = |error| {
229        Error::RequestFailed(RequestFailedError {
230            source: source.clone(),
231            error,
232        })
233    };
234
235    let partial_ok = req.partial_response_body_ok();
236    let maxlen = req.max_response_len();
237    let anonymized = req.anonymized();
238    let req = req.make_request().map_err(wrap_err)?;
239    let encoded = util::encode_request(&req);
240
241    // Write the request.
242    stream
243        .write_all(encoded.as_bytes())
244        .await
245        .map_err(RequestError::from)
246        .map_err(wrap_err)?;
247    stream
248        .flush()
249        .await
250        .map_err(RequestError::from)
251        .map_err(wrap_err)?;
252
253    let mut buffered = BufReader::new(stream);
254
255    // Handle the response
256    // TODO: should there be a separate timeout here?
257    let header = read_headers(&mut buffered).await.map_err(wrap_err)?;
258    if header.status != Some(200) {
259        return Ok(DirResponse::new(
260            header.status.unwrap_or(0),
261            header.status_message,
262            None,
263            vec![],
264            source,
265        ));
266    }
267
268    let mut decoder =
269        get_decoder(buffered, header.encoding.as_deref(), anonymized).map_err(wrap_err)?;
270
271    let mut result = Vec::new();
272    let ok = read_and_decompress(runtime, &mut decoder, maxlen, &mut result).await;
273
274    let ok = match (partial_ok, ok, result.len()) {
275        (true, Err(e), n) if n > 0 => {
276            // Note that we _don't_ return here: we want the partial response.
277            Err(e)
278        }
279        (_, Err(e), _) => {
280            return Err(wrap_err(e));
281        }
282        (_, Ok(()), _) => Ok(()),
283    };
284
285    Ok(DirResponse::new(200, None, ok.err(), result, source))
286}
287
/// Upper bound, in bytes, on the length of the HTTP headers in a single
/// request or response.
///
/// The value was chosen more or less arbitrarily.
const MAX_HEADERS_LEN: usize = 16 * 1024;
292
293/// Read and parse HTTP/1 headers from `stream`.
294async fn read_headers<S>(stream: &mut S) -> RequestResult<HeaderStatus>
295where
296    S: AsyncBufRead + Unpin,
297{
298    let mut buf = Vec::with_capacity(1024);
299
300    loop {
301        // TODO: it's inefficient to do this a line at a time; it would
302        // probably be better to read until the CRLF CRLF ending of the
303        // response.  But this should be fast enough.
304        let n = read_until_limited(stream, b'\n', 2048, &mut buf).await?;
305
306        // TODO(nickm): Better maximum and/or let this expand.
307        let mut headers = [httparse::EMPTY_HEADER; 32];
308        let mut response = httparse::Response::new(&mut headers);
309
310        match response.parse(&buf[..])? {
311            httparse::Status::Partial => {
312                // We didn't get a whole response; we may need to try again.
313
314                if n == 0 {
315                    // We hit an EOF; no more progress can be made.
316                    return Err(RequestError::TruncatedHeaders);
317                }
318
319                if buf.len() >= MAX_HEADERS_LEN {
320                    return Err(RequestError::HeadersTooLong(buf.len()));
321                }
322            }
323            httparse::Status::Complete(n_parsed) => {
324                if response.code != Some(200) {
325                    return Ok(HeaderStatus {
326                        status: response.code,
327                        status_message: response.reason.map(str::to_owned),
328                        encoding: None,
329                    });
330                }
331                let encoding = if let Some(enc) = response
332                    .headers
333                    .iter()
334                    .find(|h| h.name == "Content-Encoding")
335                {
336                    Some(String::from_utf8(enc.value.to_vec())?)
337                } else {
338                    None
339                };
340                /*
341                if let Some(clen) = response.headers.iter().find(|h| h.name == "Content-Length") {
342                    let clen = std::str::from_utf8(clen.value)?;
343                    length = Some(clen.parse()?);
344                }
345                 */
346                assert!(n_parsed == buf.len());
347                return Ok(HeaderStatus {
348                    status: Some(200),
349                    status_message: None,
350                    encoding,
351                });
352            }
353        }
354        if n == 0 {
355            return Err(RequestError::TruncatedHeaders);
356        }
357    }
358}
359
/// What [`read_headers`] learned from a set of HTTP response headers.
#[derive(Clone, Debug)]
struct HeaderStatus {
    /// The HTTP status code, if the response had one.
    status: Option<u16>,
    /// The HTTP status message that accompanied the status code.
    status_message: Option<String>,
    /// The value of the Content-Encoding header, if any.
    encoding: Option<String>,
}
370
371/// Helper: download directory information from `stream` and
372/// decompress it into a result buffer.  Assumes that `buf` is empty.
373///
374/// If we get more than maxlen bytes after decompression, give an error.
375///
376/// Returns the status of our download attempt, stores any data that
377/// we were able to download into `result`.  Existing contents of
378/// `result` are overwritten.
379async fn read_and_decompress<S, SP>(
380    runtime: &SP,
381    mut stream: S,
382    maxlen: usize,
383    result: &mut Vec<u8>,
384) -> RequestResult<()>
385where
386    S: AsyncRead + Unpin,
387    SP: SleepProvider,
388{
389    let buffer_window_size = 1024;
390    let mut written_total: usize = 0;
391    // TODO(nickm): This should be an option, and is maybe too long.
392    // Though for some users it may be too short?
393    let read_timeout = Duration::from_secs(10);
394    let timer = runtime.sleep(read_timeout).fuse();
395    futures::pin_mut!(timer);
396
397    loop {
398        // allocate buffer for next read
399        result.resize(written_total + buffer_window_size, 0);
400        let buf: &mut [u8] = &mut result[written_total..written_total + buffer_window_size];
401
402        let status = futures::select! {
403            status = stream.read(buf).fuse() => status,
404            _ = timer => {
405                result.resize(written_total, 0); // truncate as needed
406                return Err(RequestError::DirTimeout);
407            }
408        };
409        let written_in_this_loop = match status {
410            Ok(n) => n,
411            Err(other) => {
412                result.resize(written_total, 0); // truncate as needed
413                return Err(other.into());
414            }
415        };
416
417        written_total += written_in_this_loop;
418
419        // exit conditions below
420
421        if written_in_this_loop == 0 {
422            /*
423            in case we read less than `buffer_window_size` in last `read`
424            we need to shrink result because otherwise we'll return those
425            un-read 0s
426            */
427            if written_total < result.len() {
428                result.resize(written_total, 0);
429            }
430            return Ok(());
431        }
432
433        // TODO: It would be good to detect compression bombs, but
434        // that would require access to the internal stream, which
435        // would in turn require some tricky programming.  For now, we
436        // use the maximum length here to prevent an attacker from
437        // filling our RAM.
438        if written_total > maxlen {
439            result.resize(maxlen, 0);
440            return Err(RequestError::ResponseTooLong(written_total));
441        }
442    }
443}
444
445/// Retire a directory circuit because of an error we've encountered on it.
446fn retire_circ<R>(circ_mgr: &Arc<CircMgr<R>>, id: &tor_proto::circuit::UniqId, error: &str)
447where
448    R: Runtime,
449{
450    info!(
451        "{}: Retiring circuit because of directory failure: {}",
452        &id, &error
453    );
454    circ_mgr.retire_circ(id);
455}
456
457/// As AsyncBufReadExt::read_until, but stops after reading `max` bytes.
458///
459/// Note that this function might not actually read any byte of value
460/// `byte`, since EOF might occur, or we might fill the buffer.
461///
462/// A return value of 0 indicates an end-of-file.
463async fn read_until_limited<S>(
464    stream: &mut S,
465    byte: u8,
466    max: usize,
467    buf: &mut Vec<u8>,
468) -> std::io::Result<usize>
469where
470    S: AsyncBufRead + Unpin,
471{
472    let mut n_added = 0;
473    loop {
474        let data = stream.fill_buf().await?;
475        if data.is_empty() {
476            // End-of-file has been reached.
477            return Ok(n_added);
478        }
479        debug_assert!(n_added < max);
480        let remaining_space = max - n_added;
481        let (available, found_byte) = match memchr(byte, data) {
482            Some(idx) => (idx + 1, true),
483            None => (data.len(), false),
484        };
485        debug_assert!(available >= 1);
486        let n_to_copy = std::cmp::min(remaining_space, available);
487        buf.extend(&data[..n_to_copy]);
488        stream.consume_unpin(n_to_copy);
489        n_added += n_to_copy;
490        if found_byte || n_added == max {
491            return Ok(n_added);
492        }
493    }
494}
495
/// Helper: wrap the stream `$s` in a new decoder of type `$dec`, with
/// support for multiple members turned on, and return it boxed in an `Ok`.
macro_rules! decoder {
    ($dec:ident, $s:expr) => {{
        let mut wrapped = $dec::new($s);
        wrapped.multiple_members(true);
        Ok(Box::new(wrapped))
    }};
}
504
505/// Wrap `stream` in an appropriate type to undo the content encoding
506/// as described in `encoding`.
507fn get_decoder<'a, S: AsyncBufRead + Unpin + Send + 'a>(
508    stream: S,
509    encoding: Option<&str>,
510    anonymized: AnonymizedRequest,
511) -> RequestResult<Box<dyn AsyncRead + Unpin + Send + 'a>> {
512    use AnonymizedRequest::Direct;
513    match (encoding, anonymized) {
514        (None | Some("identity"), _) => Ok(Box::new(stream)),
515        (Some("deflate"), _) => decoder!(ZlibDecoder, stream),
516        // We only admit to supporting these on a direct connection; otherwise,
517        // a hostile directory could send them back even though we hadn't
518        // requested them.
519        #[cfg(feature = "xz")]
520        (Some("x-tor-lzma"), Direct) => decoder!(XzDecoder, stream),
521        #[cfg(feature = "zstd")]
522        (Some("x-zstd"), Direct) => decoder!(ZstdDecoder, stream),
523        (Some(other), _) => Err(RequestError::ContentEncoding(other.into())),
524    }
525}
526
527#[cfg(test)]
528mod test {
529    // @@ begin test lint list maintained by maint/add_warning @@
530    #![allow(clippy::bool_assert_comparison)]
531    #![allow(clippy::clone_on_copy)]
532    #![allow(clippy::dbg_macro)]
533    #![allow(clippy::mixed_attributes_style)]
534    #![allow(clippy::print_stderr)]
535    #![allow(clippy::print_stdout)]
536    #![allow(clippy::single_char_pattern)]
537    #![allow(clippy::unwrap_used)]
538    #![allow(clippy::unchecked_time_subtraction)]
539    #![allow(clippy::useless_vec)]
540    #![allow(clippy::needless_pass_by_value)]
541    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
542    use super::*;
543    use tor_rtmock::io::stream_pair;
544
545    use tor_rtmock::simple_time::SimpleMockTimeProvider;
546
547    use futures_await_test::async_test;
548
549    #[async_test]
550    async fn test_read_until_limited() -> RequestResult<()> {
551        let mut out = Vec::new();
552        let bytes = b"This line eventually ends\nthen comes another\n";
553
554        // Case 1: find a whole line.
555        let mut s = &bytes[..];
556        let res = read_until_limited(&mut s, b'\n', 100, &mut out).await;
557        assert_eq!(res?, 26);
558        assert_eq!(&out[..], b"This line eventually ends\n");
559
560        // Case 2: reach the limit.
561        let mut s = &bytes[..];
562        out.clear();
563        let res = read_until_limited(&mut s, b'\n', 10, &mut out).await;
564        assert_eq!(res?, 10);
565        assert_eq!(&out[..], b"This line ");
566
567        // Case 3: reach EOF.
568        let mut s = &bytes[..];
569        out.clear();
570        let res = read_until_limited(&mut s, b'Z', 100, &mut out).await;
571        assert_eq!(res?, 45);
572        assert_eq!(&out[..], &bytes[..]);
573
574        Ok(())
575    }
576
577    // Basic decompression wrapper.
578    async fn decomp_basic(
579        encoding: Option<&str>,
580        data: &[u8],
581        maxlen: usize,
582    ) -> (RequestResult<()>, Vec<u8>) {
583        // We don't need to do anything fancy here, since we aren't simulating
584        // a timeout.
585        #[allow(deprecated)] // TODO #1885
586        let mock_time = SimpleMockTimeProvider::from_wallclock(std::time::SystemTime::now());
587
588        let mut output = Vec::new();
589        let mut stream = match get_decoder(data, encoding, AnonymizedRequest::Direct) {
590            Ok(s) => s,
591            Err(e) => return (Err(e), output),
592        };
593
594        let r = read_and_decompress(&mock_time, &mut stream, maxlen, &mut output).await;
595
596        (r, output)
597    }
598
599    #[async_test]
600    async fn decompress_identity() -> RequestResult<()> {
601        let mut text = Vec::new();
602        for _ in 0..1000 {
603            text.extend(b"This is a string with a nontrivial length that we'll use to make sure that the loop is executed more than once.");
604        }
605
606        let limit = 10 << 20;
607        let (s, r) = decomp_basic(None, &text[..], limit).await;
608        s?;
609        assert_eq!(r, text);
610
611        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
612        s?;
613        assert_eq!(r, text);
614
615        // Try truncated result
616        let limit = 100;
617        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
618        assert!(s.is_err());
619        assert_eq!(r, &text[..100]);
620
621        Ok(())
622    }
623
624    #[async_test]
625    async fn decomp_zlib() -> RequestResult<()> {
626        let compressed =
627            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();
628
629        let limit = 10 << 20;
630        let (s, r) = decomp_basic(Some("deflate"), &compressed, limit).await;
631        s?;
632        assert_eq!(r, b"One fish Two fish Red fish Blue fish");
633
634        Ok(())
635    }
636
637    #[cfg(feature = "zstd")]
638    #[async_test]
639    async fn decomp_zstd() -> RequestResult<()> {
640        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
641        let limit = 10 << 20;
642        let (s, r) = decomp_basic(Some("x-zstd"), &compressed, limit).await;
643        s?;
644        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");
645
646        Ok(())
647    }
648
649    #[cfg(feature = "xz")]
650    #[async_test]
651    async fn decomp_xz2() -> RequestResult<()> {
652        // Not so good at tiny files...
653        let compressed = hex::decode("fd377a585a000004e6d6b446020021011c00000010cf58cce00024001d5d00279b88a202ca8612cfb3c19c87c34248a570451e4851d3323d34ab8000000000000901af64854c91f600013925d6ec06651fb6f37d010000000004595a").unwrap();
654        let limit = 10 << 20;
655        let (s, r) = decomp_basic(Some("x-tor-lzma"), &compressed, limit).await;
656        s?;
657        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");
658
659        Ok(())
660    }
661
662    #[async_test]
663    async fn decomp_unknown() {
664        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
665        let limit = 10 << 20;
666        let (s, _r) = decomp_basic(Some("x-proprietary-rle"), &compressed, limit).await;
667
668        assert!(matches!(s, Err(RequestError::ContentEncoding(_))));
669    }
670
671    #[async_test]
672    async fn decomp_bad_data() {
673        let compressed = b"This is not good zlib data";
674        let limit = 10 << 20;
675        let (s, _r) = decomp_basic(Some("deflate"), compressed, limit).await;
676
677        // This should possibly be a different type in the future.
678        assert!(matches!(s, Err(RequestError::IoError(_))));
679    }
680
681    #[async_test]
682    async fn headers_ok() -> RequestResult<()> {
683        let text = b"HTTP/1.0 200 OK\r\nDate: ignored\r\nContent-Encoding: Waffles\r\n\r\n";
684
685        let mut s = &text[..];
686        let h = read_headers(&mut s).await?;
687
688        assert_eq!(h.status, Some(200));
689        assert_eq!(h.encoding.as_deref(), Some("Waffles"));
690
691        // now try truncated
692        let mut s = &text[..15];
693        let h = read_headers(&mut s).await;
694        assert!(matches!(h, Err(RequestError::TruncatedHeaders)));
695
696        // now try with no encoding.
697        let text = b"HTTP/1.0 404 Not found\r\n\r\n";
698        let mut s = &text[..];
699        let h = read_headers(&mut s).await?;
700
701        assert_eq!(h.status, Some(404));
702        assert!(h.encoding.is_none());
703
704        Ok(())
705    }
706
707    #[async_test]
708    async fn headers_bogus() -> Result<()> {
709        let text = b"HTTP/999.0 WHAT EVEN\r\n\r\n";
710        let mut s = &text[..];
711        let h = read_headers(&mut s).await;
712
713        assert!(h.is_err());
714        assert!(matches!(h, Err(RequestError::HttparseError(_))));
715        Ok(())
716    }
717
718    /// Run a trivial download example with a response provided as a binary
719    /// string.
720    ///
721    /// Return the directory response (if any) and the request as encoded (if
722    /// any.)
723    fn run_download_test<Req: request::Requestable>(
724        req: Req,
725        response: &[u8],
726    ) -> (Result<DirResponse>, RequestResult<Vec<u8>>) {
727        let (mut s1, s2) = stream_pair();
728        let (mut s2_r, mut s2_w) = s2.split();
729
730        tor_rtcompat::test_with_one_runtime!(|rt| async move {
731            let rt2 = rt.clone();
732            let (v1, v2, v3): (
733                Result<DirResponse>,
734                RequestResult<Vec<u8>>,
735                RequestResult<()>,
736            ) = futures::join!(
737                async {
738                    // Run the download function.
739                    let r = send_request(&rt, &req, &mut s1, None).await;
740                    s1.close().await.map_err(|error| {
741                        Error::RequestFailed(RequestFailedError {
742                            source: None,
743                            error: error.into(),
744                        })
745                    })?;
746                    r
747                },
748                async {
749                    // Take the request from the client, and return it in "v2"
750                    let mut v = Vec::new();
751                    s2_r.read_to_end(&mut v).await?;
752                    Ok(v)
753                },
754                async {
755                    // Send back a response.
756                    s2_w.write_all(response).await?;
757                    // We wait a moment to give the other side time to notice it
758                    // has data.
759                    //
760                    // (Tentative diagnosis: The `async-compress` crate seems to
761                    // be behave differently depending on whether the "close"
762                    // comes right after the incomplete data or whether it comes
763                    // after a delay.  If there's a delay, it notices the
764                    // truncated data and tells us about it. But when there's
765                    // _no_delay, it treats the data as an error and doesn't
766                    // tell our code.)
767
768                    // TODO: sleeping in tests is not great.
769                    rt2.sleep(Duration::from_millis(50)).await;
770                    s2_w.close().await?;
771                    Ok(())
772                }
773            );
774
775            assert!(v3.is_ok());
776
777            (v1, v2)
778        })
779    }
780
781    #[test]
782    fn test_send_request() -> RequestResult<()> {
783        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
784
785        let (response, request) = run_download_test(
786            req,
787            b"HTTP/1.0 200 OK\r\n\r\nThis is where the descs would go.",
788        );
789
790        let request = request?;
791        assert!(request[..].starts_with(
792            b"GET /tor/micro/d/CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQk HTTP/1.0\r\n"
793        ));
794
795        let response = response.unwrap();
796        assert_eq!(response.status_code(), 200);
797        assert!(!response.is_partial());
798        assert!(response.error().is_none());
799        assert!(response.source().is_none());
800        let out_ref = response.output_unchecked();
801        assert_eq!(out_ref, b"This is where the descs would go.");
802        let out = response.into_output_unchecked();
803        assert_eq!(&out, b"This is where the descs would go.");
804
805        Ok(())
806    }
807
808    #[test]
809    fn test_download_truncated() {
810        // Request only one md, so "partial ok" will not be set.
811        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
812        let mut response_text: Vec<u8> =
813            (*b"HTTP/1.0 200 OK\r\nContent-Encoding: deflate\r\n\r\n").into();
814        // "One fish two fish" as above twice, but truncated the second time
815        response_text.extend(
816            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap(),
817        );
818        response_text.extend(
819            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5").unwrap(),
820        );
821        let (response, request) = run_download_test(req, &response_text);
822        assert!(request.is_ok());
823        assert!(response.is_err()); // The whole download should fail, since partial_ok wasn't set.
824
825        // request two microdescs, so "partial_ok" will be set.
826        let req: request::MicrodescRequest = vec![[9; 32]; 2].into_iter().collect();
827
828        let (response, request) = run_download_test(req, &response_text);
829        assert!(request.is_ok());
830
831        let response = response.unwrap();
832        assert_eq!(response.status_code(), 200);
833        assert!(response.error().is_some());
834        assert!(response.is_partial());
835        assert!(response.output_unchecked().len() < 37 * 2);
836        assert!(response.output_unchecked().starts_with(b"One fish"));
837    }
838
839    #[test]
840    fn test_404() {
841        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
842        let response_text = b"HTTP/1.0 418 I'm a teapot\r\n\r\n";
843        let (response, _request) = run_download_test(req, response_text);
844
845        assert_eq!(response.unwrap().status_code(), 418);
846    }
847
848    #[test]
849    fn test_headers_truncated() {
850        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
851        let response_text = b"HTTP/1.0 404 truncation happens here\r\n";
852        let (response, _request) = run_download_test(req, response_text);
853
854        assert!(matches!(
855            response,
856            Err(Error::RequestFailed(RequestFailedError {
857                error: RequestError::TruncatedHeaders,
858                ..
859            }))
860        ));
861
862        // Try a completely empty response.
863        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
864        let response_text = b"";
865        let (response, _request) = run_download_test(req, response_text);
866
867        assert!(matches!(
868            response,
869            Err(Error::RequestFailed(RequestFailedError {
870                error: RequestError::TruncatedHeaders,
871                ..
872            }))
873        ));
874    }
875
876    #[test]
877    fn test_headers_too_long() {
878        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
879        let mut response_text: Vec<u8> = (*b"HTTP/1.0 418 I'm a teapot\r\nX-Too-Many-As: ").into();
880        response_text.resize(16384, b'A');
881        let (response, _request) = run_download_test(req, &response_text);
882
883        assert!(response.as_ref().unwrap_err().should_retire_circ());
884        assert!(matches!(
885            response,
886            Err(Error::RequestFailed(RequestFailedError {
887                error: RequestError::HeadersTooLong(_),
888                ..
889            }))
890        ));
891    }
892
893    // TODO: test with bad utf-8
894}