Skip to main content

tor_dirclient/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3// @@ begin lint list maintained by maint/add_warning @@
4#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
5#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
6#![warn(missing_docs)]
7#![warn(noop_method_call)]
8#![warn(unreachable_pub)]
9#![warn(clippy::all)]
10#![deny(clippy::await_holding_lock)]
11#![deny(clippy::cargo_common_metadata)]
12#![deny(clippy::cast_lossless)]
13#![deny(clippy::checked_conversions)]
14#![warn(clippy::cognitive_complexity)]
15#![deny(clippy::debug_assert_with_mut_call)]
16#![deny(clippy::exhaustive_enums)]
17#![deny(clippy::exhaustive_structs)]
18#![deny(clippy::expl_impl_clone_on_copy)]
19#![deny(clippy::fallible_impl_from)]
20#![deny(clippy::implicit_clone)]
21#![deny(clippy::large_stack_arrays)]
22#![warn(clippy::manual_ok_or)]
23#![deny(clippy::missing_docs_in_private_items)]
24#![warn(clippy::needless_borrow)]
25#![warn(clippy::needless_pass_by_value)]
26#![warn(clippy::option_option)]
27#![deny(clippy::print_stderr)]
28#![deny(clippy::print_stdout)]
29#![warn(clippy::rc_buffer)]
30#![deny(clippy::ref_option_ref)]
31#![warn(clippy::semicolon_if_nothing_returned)]
32#![warn(clippy::trait_duplication_in_bounds)]
33#![deny(clippy::unchecked_time_subtraction)]
34#![deny(clippy::unnecessary_wraps)]
35#![warn(clippy::unseparated_literal_suffix)]
36#![deny(clippy::unwrap_used)]
37#![deny(clippy::mod_module_files)]
38#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
39#![allow(clippy::uninlined_format_args)]
40#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
41#![allow(clippy::result_large_err)] // temporary workaround for arti#587
42#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
43#![allow(clippy::needless_lifetimes)] // See arti#1765
44#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
45#![allow(clippy::collapsible_if)] // See arti#2342
46#![deny(clippy::unused_async)]
47#![deny(clippy::string_slice)] // See arti#2571
48//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
49
50// TODO probably remove this at some point - see tpo/core/arti#1060
51#![cfg_attr(
52    not(all(feature = "full", feature = "experimental")),
53    allow(unused_imports)
54)]
55
56mod body;
57mod err;
58pub mod request;
59mod response;
60mod util;
61
62use tor_circmgr::{CircMgr, DirInfo};
63use tor_error::bad_api_usage;
64use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};
65
66// Zlib is required; the others are optional.
67#[cfg(feature = "xz")]
68use async_compression::futures::bufread::XzDecoder;
69use async_compression::futures::bufread::ZlibDecoder;
70#[cfg(feature = "zstd")]
71use async_compression::futures::bufread::ZstdDecoder;
72
73use futures::FutureExt;
74use futures::io::{
75    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader,
76};
77use memchr::memchr;
78use std::sync::Arc;
79use std::time::Duration;
80use tracing::{info, instrument};
81
82pub use err::{Error, RequestError, RequestFailedError};
83pub use response::{DirResponse, SourceInfo};
84
85/// Type for results returned in this crate.
86pub type Result<T> = std::result::Result<T, Error>;
87
88/// Type for internal results  containing a RequestError.
89pub type RequestResult<T> = std::result::Result<T, RequestError>;
90
91/// Flag to declare whether a request is always anonymized or not.
92///
93/// This is used by tor-dirclient to control whether *other* deanonymizing metadata
94/// might be added to the request (eg in request headers):
95/// Some requests (like those to download onion service descriptors) are always
96/// anonymized, and should never be sent in a way that leaks information about
97/// our settings or configuration.
98///
99/// It is up to the *caller* of `tor-dirclient` to ensure that
100///
101///   - every request whose anonymization status is `AnonymizedRequest::Direct`
102///     is sent only over non-anonymous connections.
103///
104///     (Sending an `AnonymizedRequest::Direct` request over an anonymized connection
105///     would weaken the connection's anonymity, and can therefore weaken the anonymity
106///     of user traffic sharing the same circuit.)
107///
108///   - every request whose anonymization status is `AnonymizedRequest::Anonymized`
109///     is sent over only anonymous connections (ie, multi-hop circuits).
110///
111///     (Sending an `AnonymizedRequest::Anonymized` request over a direct connection
112///     would directly reveal user behaviour data to the directory server.)
113///
114/// TODO the calling code cannot easily be sure to get this right this because
115/// the anonymization status is a run-time property and the choice of connection kind
116/// is statically defined in the calling code.  (Perhaps this could be checked in tests?)
117#[derive(Copy, Clone, Debug, Eq, PartialEq)]
118#[non_exhaustive]
119pub enum AnonymizedRequest {
120    /// This request's content or semantics reveals or is correlated with sensitive information.
121    ///
122    /// For example, requests for hidden service descriptors reveal which hidden services
123    /// the client is connecting to.
124    ///
125    /// The request must be sent over an anonymous circuit by the caller
126    /// and no additional deanonymizing information should be added to it by `tor-dirclient`.
127    /// (For example, no client-version-specific information should be
128    /// sent in HTTP headers when the request is made.)
129    Anonymized,
130
131    /// Making this request does not reveal anything sensitive, nor any user behaviour.
132    ///
133    /// The request body is uncorrelated with such things as the websites the user might visit,
134    /// the onion services the user is visiting or running, etc.
135    ///
136    /// For example, requests for all router microdescriptors are made by all clients,
137    /// so which microdescriptor(s) are requested reveals nothing to any attacker.
138    ///
139    /// tor-dirclient is allowed to add include information about our capabilities
140    /// when sending this request.
141    /// The request must *not* be sent over an anonymous circuit by the caller
142    /// (at least, not one used for anything else).
143    Direct,
144}
145
146/// Fetch the resource described by `req` over the Tor network.
147///
148/// Circuits are built or found using `circ_mgr`, using paths
149/// constructed using `dirinfo`.
150///
151/// For more fine-grained control over the circuit and stream used,
152/// construct them yourself, and then call [`send_request`] instead.
153///
154/// # TODO
155///
156/// This is the only function in this crate that knows about CircMgr and
157/// DirInfo.  Perhaps this function should move up a level into DirMgr?
158#[instrument(level = "trace", skip_all)]
159pub async fn get_resource<CR, R, SP>(
160    req: &CR,
161    dirinfo: DirInfo<'_>,
162    runtime: &SP,
163    circ_mgr: Arc<CircMgr<R>>,
164) -> Result<DirResponse>
165where
166    CR: request::Requestable + ?Sized,
167    R: Runtime,
168    SP: SleepProvider,
169{
170    let tunnel = circ_mgr.get_or_launch_dir(dirinfo).await?;
171
172    if req.anonymized() == AnonymizedRequest::Anonymized {
173        return Err(bad_api_usage!("Tried to use get_resource for an anonymized request").into());
174    }
175
176    // TODO(nickm) This should be an option, and is too long.
177    let begin_timeout = Duration::from_secs(5);
178    let source = match SourceInfo::from_tunnel(&tunnel) {
179        Ok(source) => source,
180        Err(e) => {
181            return Err(Error::RequestFailed(RequestFailedError {
182                source: None,
183                error: e.into(),
184            }));
185        }
186    };
187
188    let wrap_err = |error| {
189        Error::RequestFailed(RequestFailedError {
190            source: source.clone(),
191            error,
192        })
193    };
194
195    req.check_circuit(&tunnel).await.map_err(wrap_err)?;
196
197    // Launch the stream.
198    let mut stream = runtime
199        .timeout(begin_timeout, tunnel.begin_dir_stream())
200        .await
201        .map_err(RequestError::from)
202        .map_err(wrap_err)?
203        .map_err(RequestError::from)
204        .map_err(wrap_err)?; // TODO(nickm) handle fatalities here too
205
206    // TODO: Perhaps we want separate timeouts for each phase of this.
207    // For now, we just use higher-level timeouts in `dirmgr`.
208    let r = send_request(runtime, req, &mut stream, source.clone()).await;
209
210    if should_retire_circ(&r) {
211        retire_circ(&circ_mgr, &tunnel.unique_id(), "Partial response");
212    }
213
214    r
215}
216
217/// Return true if `result` holds an error indicating that we should retire the
218/// circuit used for the corresponding request.
219fn should_retire_circ(result: &Result<DirResponse>) -> bool {
220    match result {
221        Err(e) => e.should_retire_circ(),
222        Ok(dr) => dr.error().map(RequestError::should_retire_circ) == Some(true),
223    }
224}
225
226/// Fetch a Tor directory object from a provided stream.
227#[deprecated(since = "0.8.1", note = "Use send_request instead.")]
228pub async fn download<R, S, SP>(
229    runtime: &SP,
230    req: &R,
231    stream: &mut S,
232    source: Option<SourceInfo>,
233) -> Result<DirResponse>
234where
235    R: request::Requestable + ?Sized,
236    S: AsyncRead + AsyncWrite + Send + Unpin,
237    SP: SleepProvider,
238{
239    send_request(runtime, req, stream, source).await
240}
241
242/// Fetch or upload a Tor directory object using the provided stream.
243///
244/// To do this, we send a simple HTTP/1.0 request for the described
245/// object in `req` over `stream`, and then wait for a response.  In
246/// log messages, we describe the origin of the data as coming from
247/// `source`.
248///
249/// # Notes
250///
251/// It's kind of bogus to have a 'source' field here at all; we may
252/// eventually want to remove it.
253///
254/// This function doesn't close the stream; you may want to do that
255/// yourself.
256///
257/// The only error variant returned is [`Error::RequestFailed`].
258// TODO: should the error return type change to `RequestFailedError`?
259// If so, that would simplify some code in_dirmgr::bridgedesc.
260pub async fn send_request<R, S, SP>(
261    runtime: &SP,
262    req: &R,
263    stream: &mut S,
264    source: Option<SourceInfo>,
265) -> Result<DirResponse>
266where
267    R: request::Requestable + ?Sized,
268    S: AsyncRead + AsyncWrite + Send + Unpin,
269    SP: SleepProvider,
270{
271    let wrap_err = |error| {
272        Error::RequestFailed(RequestFailedError {
273            source: source.clone(),
274            error,
275        })
276    };
277
278    let partial_ok = req.partial_response_body_ok();
279    let maxlen = req.max_response_len();
280    let anonymized = req.anonymized();
281    let req = req.make_request().map_err(wrap_err)?;
282    let method = req.method().clone();
283    let encoded = util::encode_request(&req);
284
285    // Write the request.
286    for chunk in encoded.iter() {
287        stream
288            .write_all(chunk)
289            .await
290            .map_err(RequestError::from)
291            .map_err(wrap_err)?;
292    }
293    stream
294        .flush()
295        .await
296        .map_err(RequestError::from)
297        .map_err(wrap_err)?;
298
299    let mut buffered = BufReader::new(stream);
300
301    // Handle the response
302    // TODO: should there be a separate timeout here?
303    let header = read_headers(&mut buffered).await.map_err(wrap_err)?;
304    if header.status != Some(200) {
305        return Ok(DirResponse::new(
306            method,
307            header.status.unwrap_or(0),
308            header.status_message,
309            None,
310            vec![],
311            source,
312        ));
313    }
314
315    let mut decoder =
316        get_decoder(buffered, header.encoding.as_deref(), anonymized).map_err(wrap_err)?;
317
318    let mut result = Vec::new();
319    let ok = read_and_decompress(runtime, &mut decoder, maxlen, &mut result).await;
320
321    let ok = match (partial_ok, ok, result.len()) {
322        (true, Err(e), n) if n > 0 => {
323            // Note that we _don't_ return here: we want the partial response.
324            Err(e)
325        }
326        (_, Err(e), _) => {
327            return Err(wrap_err(e));
328        }
329        (_, Ok(()), _) => Ok(()),
330    };
331
332    Ok(DirResponse::new(
333        method,
334        200,
335        None,
336        ok.err(),
337        result,
338        source,
339    ))
340}
341
342/// Maximum length for the HTTP headers in a single request or response.
343///
344/// Chosen more or less arbitrarily.
345const MAX_HEADERS_LEN: usize = 16384;
346
347/// Read and parse HTTP/1 headers from `stream`.
348async fn read_headers<S>(stream: &mut S) -> RequestResult<HeaderStatus>
349where
350    S: AsyncBufRead + Unpin,
351{
352    let mut buf = Vec::with_capacity(1024);
353
354    loop {
355        // TODO: it's inefficient to do this a line at a time; it would
356        // probably be better to read until the CRLF CRLF ending of the
357        // response.  But this should be fast enough.
358        let n = read_until_limited(stream, b'\n', 2048, &mut buf).await?;
359
360        // TODO(nickm): Better maximum and/or let this expand.
361        let mut headers = [httparse::EMPTY_HEADER; 32];
362        let mut response = httparse::Response::new(&mut headers);
363
364        match response.parse(&buf[..])? {
365            httparse::Status::Partial => {
366                // We didn't get a whole response; we may need to try again.
367
368                if n == 0 {
369                    // We hit an EOF; no more progress can be made.
370                    return Err(RequestError::TruncatedHeaders);
371                }
372
373                if buf.len() >= MAX_HEADERS_LEN {
374                    return Err(RequestError::HeadersTooLong(buf.len()));
375                }
376            }
377            httparse::Status::Complete(n_parsed) => {
378                if response.code != Some(200) {
379                    return Ok(HeaderStatus {
380                        status: response.code,
381                        status_message: response.reason.map(str::to_owned),
382                        encoding: None,
383                    });
384                }
385                let encoding = if let Some(enc) = response
386                    .headers
387                    .iter()
388                    .find(|h| h.name == "Content-Encoding")
389                {
390                    Some(String::from_utf8(enc.value.to_vec())?)
391                } else {
392                    None
393                };
394                /*
395                if let Some(clen) = response.headers.iter().find(|h| h.name == "Content-Length") {
396                    let clen = std::str::from_utf8(clen.value)?;
397                    length = Some(clen.parse()?);
398                }
399                 */
400                assert!(n_parsed == buf.len());
401                return Ok(HeaderStatus {
402                    status: Some(200),
403                    status_message: None,
404                    encoding,
405                });
406            }
407        }
408        if n == 0 {
409            return Err(RequestError::TruncatedHeaders);
410        }
411    }
412}
413
414/// Return value from read_headers
415#[derive(Debug, Clone)]
416struct HeaderStatus {
417    /// HTTP status code.
418    status: Option<u16>,
419    /// HTTP status message associated with the status code.
420    status_message: Option<String>,
421    /// The Content-Encoding header, if any.
422    encoding: Option<String>,
423}
424
425/// Helper: download directory information from `stream` and
426/// decompress it into a result buffer.  Assumes that `buf` is empty.
427///
428/// If we get more than maxlen bytes after decompression, give an error.
429///
430/// Returns the status of our download attempt, stores any data that
431/// we were able to download into `result`.  Existing contents of
432/// `result` are overwritten.
433async fn read_and_decompress<S, SP>(
434    runtime: &SP,
435    mut stream: S,
436    maxlen: usize,
437    result: &mut Vec<u8>,
438) -> RequestResult<()>
439where
440    S: AsyncRead + Unpin,
441    SP: SleepProvider,
442{
443    let buffer_window_size = 1024;
444    let mut written_total: usize = 0;
445    // TODO(nickm): This should be an option, and is maybe too long.
446    // Though for some users it may be too short?
447    let read_timeout = Duration::from_secs(10);
448    let timer = runtime.sleep(read_timeout).fuse();
449    futures::pin_mut!(timer);
450
451    loop {
452        // allocate buffer for next read
453        result.resize(written_total + buffer_window_size, 0);
454        let buf: &mut [u8] = &mut result[written_total..written_total + buffer_window_size];
455
456        let status = futures::select! {
457            status = stream.read(buf).fuse() => status,
458            _ = timer => {
459                result.resize(written_total, 0); // truncate as needed
460                return Err(RequestError::DirTimeout);
461            }
462        };
463        let written_in_this_loop = match status {
464            Ok(n) => n,
465            Err(other) => {
466                result.resize(written_total, 0); // truncate as needed
467                return Err(other.into());
468            }
469        };
470
471        written_total += written_in_this_loop;
472
473        // exit conditions below
474
475        if written_in_this_loop == 0 {
476            /*
477            in case we read less than `buffer_window_size` in last `read`
478            we need to shrink result because otherwise we'll return those
479            un-read 0s
480            */
481            if written_total < result.len() {
482                result.resize(written_total, 0);
483            }
484            return Ok(());
485        }
486
487        // TODO: It would be good to detect compression bombs, but
488        // that would require access to the internal stream, which
489        // would in turn require some tricky programming.  For now, we
490        // use the maximum length here to prevent an attacker from
491        // filling our RAM.
492        if written_total > maxlen {
493            result.resize(maxlen, 0);
494            return Err(RequestError::ResponseTooLong(written_total));
495        }
496    }
497}
498
499/// Retire a directory circuit because of an error we've encountered on it.
500fn retire_circ<R>(circ_mgr: &Arc<CircMgr<R>>, id: &tor_proto::circuit::UniqId, error: &str)
501where
502    R: Runtime,
503{
504    info!(
505        "{}: Retiring circuit because of directory failure: {}",
506        &id, &error
507    );
508    circ_mgr.retire_circ(id);
509}
510
511/// As AsyncBufReadExt::read_until, but stops after reading `max` bytes.
512///
513/// Note that this function might not actually read any byte of value
514/// `byte`, since EOF might occur, or we might fill the buffer.
515///
516/// A return value of 0 indicates an end-of-file.
517async fn read_until_limited<S>(
518    stream: &mut S,
519    byte: u8,
520    max: usize,
521    buf: &mut Vec<u8>,
522) -> std::io::Result<usize>
523where
524    S: AsyncBufRead + Unpin,
525{
526    let mut n_added = 0;
527    loop {
528        let data = stream.fill_buf().await?;
529        if data.is_empty() {
530            // End-of-file has been reached.
531            return Ok(n_added);
532        }
533        debug_assert!(n_added < max);
534        let remaining_space = max - n_added;
535        let (available, found_byte) = match memchr(byte, data) {
536            Some(idx) => (idx + 1, true),
537            None => (data.len(), false),
538        };
539        debug_assert!(available >= 1);
540        let n_to_copy = std::cmp::min(remaining_space, available);
541        buf.extend(&data[..n_to_copy]);
542        stream.consume_unpin(n_to_copy);
543        n_added += n_to_copy;
544        if found_byte || n_added == max {
545            return Ok(n_added);
546        }
547    }
548}
549
550/// Helper: Return a boxed decoder object that wraps the stream  $s.
551macro_rules! decoder {
552    ($dec:ident, $s:expr) => {{
553        let mut decoder = $dec::new($s);
554        decoder.multiple_members(true);
555        Ok(Box::new(decoder))
556    }};
557}
558
559/// Wrap `stream` in an appropriate type to undo the content encoding
560/// as described in `encoding`.
561fn get_decoder<'a, S: AsyncBufRead + Unpin + Send + 'a>(
562    stream: S,
563    encoding: Option<&str>,
564    anonymized: AnonymizedRequest,
565) -> RequestResult<Box<dyn AsyncRead + Unpin + Send + 'a>> {
566    use AnonymizedRequest::Direct;
567    match (encoding, anonymized) {
568        (None | Some("identity"), _) => Ok(Box::new(stream)),
569        (Some("deflate"), _) => decoder!(ZlibDecoder, stream),
570        // We only admit to supporting these on a direct connection; otherwise,
571        // a hostile directory could send them back even though we hadn't
572        // requested them.
573        #[cfg(feature = "xz")]
574        (Some("x-tor-lzma"), Direct) => decoder!(XzDecoder, stream),
575        #[cfg(feature = "zstd")]
576        (Some("x-zstd"), Direct) => decoder!(ZstdDecoder, stream),
577        (Some(other), _) => Err(RequestError::ContentEncoding(other.into())),
578    }
579}
580
581#[cfg(test)]
582mod test {
583    // @@ begin test lint list maintained by maint/add_warning @@
584    #![allow(clippy::bool_assert_comparison)]
585    #![allow(clippy::clone_on_copy)]
586    #![allow(clippy::dbg_macro)]
587    #![allow(clippy::mixed_attributes_style)]
588    #![allow(clippy::print_stderr)]
589    #![allow(clippy::print_stdout)]
590    #![allow(clippy::single_char_pattern)]
591    #![allow(clippy::unwrap_used)]
592    #![allow(clippy::unchecked_time_subtraction)]
593    #![allow(clippy::useless_vec)]
594    #![allow(clippy::needless_pass_by_value)]
595    #![allow(clippy::string_slice)] // See arti#2571
596    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
597    use super::*;
598    use tor_rtmock::io::stream_pair;
599
600    use tor_rtmock::simple_time::SimpleMockTimeProvider;
601    use web_time_compat::{SystemTime, SystemTimeExt};
602
603    use futures_await_test::async_test;
604
605    #[async_test]
606    async fn test_read_until_limited() -> RequestResult<()> {
607        let mut out = Vec::new();
608        let bytes = b"This line eventually ends\nthen comes another\n";
609
610        // Case 1: find a whole line.
611        let mut s = &bytes[..];
612        let res = read_until_limited(&mut s, b'\n', 100, &mut out).await;
613        assert_eq!(res?, 26);
614        assert_eq!(&out[..], b"This line eventually ends\n");
615
616        // Case 2: reach the limit.
617        let mut s = &bytes[..];
618        out.clear();
619        let res = read_until_limited(&mut s, b'\n', 10, &mut out).await;
620        assert_eq!(res?, 10);
621        assert_eq!(&out[..], b"This line ");
622
623        // Case 3: reach EOF.
624        let mut s = &bytes[..];
625        out.clear();
626        let res = read_until_limited(&mut s, b'Z', 100, &mut out).await;
627        assert_eq!(res?, 45);
628        assert_eq!(&out[..], &bytes[..]);
629
630        Ok(())
631    }
632
633    // Basic decompression wrapper.
634    async fn decomp_basic(
635        encoding: Option<&str>,
636        data: &[u8],
637        maxlen: usize,
638    ) -> (RequestResult<()>, Vec<u8>) {
639        // We don't need to do anything fancy here, since we aren't simulating
640        // a timeout.
641        #[allow(deprecated)] // TODO #1885
642        let mock_time = SimpleMockTimeProvider::from_wallclock(SystemTime::get());
643
644        let mut output = Vec::new();
645        let mut stream = match get_decoder(data, encoding, AnonymizedRequest::Direct) {
646            Ok(s) => s,
647            Err(e) => return (Err(e), output),
648        };
649
650        let r = read_and_decompress(&mock_time, &mut stream, maxlen, &mut output).await;
651
652        (r, output)
653    }
654
655    #[async_test]
656    async fn decompress_identity() -> RequestResult<()> {
657        let mut text = Vec::new();
658        for _ in 0..1000 {
659            text.extend(b"This is a string with a nontrivial length that we'll use to make sure that the loop is executed more than once.");
660        }
661
662        let limit = 10 << 20;
663        let (s, r) = decomp_basic(None, &text[..], limit).await;
664        s?;
665        assert_eq!(r, text);
666
667        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
668        s?;
669        assert_eq!(r, text);
670
671        // Try truncated result
672        let limit = 100;
673        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
674        assert!(s.is_err());
675        assert_eq!(r, &text[..100]);
676
677        Ok(())
678    }
679
680    #[async_test]
681    async fn decomp_zlib() -> RequestResult<()> {
682        let compressed =
683            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();
684
685        let limit = 10 << 20;
686        let (s, r) = decomp_basic(Some("deflate"), &compressed, limit).await;
687        s?;
688        assert_eq!(r, b"One fish Two fish Red fish Blue fish");
689
690        Ok(())
691    }
692
693    #[cfg(feature = "zstd")]
694    #[async_test]
695    async fn decomp_zstd() -> RequestResult<()> {
696        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
697        let limit = 10 << 20;
698        let (s, r) = decomp_basic(Some("x-zstd"), &compressed, limit).await;
699        s?;
700        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");
701
702        Ok(())
703    }
704
705    #[cfg(feature = "xz")]
706    #[async_test]
707    async fn decomp_xz2() -> RequestResult<()> {
708        // Not so good at tiny files...
709        let compressed = hex::decode("fd377a585a000004e6d6b446020021011c00000010cf58cce00024001d5d00279b88a202ca8612cfb3c19c87c34248a570451e4851d3323d34ab8000000000000901af64854c91f600013925d6ec06651fb6f37d010000000004595a").unwrap();
710        let limit = 10 << 20;
711        let (s, r) = decomp_basic(Some("x-tor-lzma"), &compressed, limit).await;
712        s?;
713        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");
714
715        Ok(())
716    }
717
718    #[async_test]
719    async fn decomp_unknown() {
720        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
721        let limit = 10 << 20;
722        let (s, _r) = decomp_basic(Some("x-proprietary-rle"), &compressed, limit).await;
723
724        assert!(matches!(s, Err(RequestError::ContentEncoding(_))));
725    }
726
727    #[async_test]
728    async fn decomp_bad_data() {
729        let compressed = b"This is not good zlib data";
730        let limit = 10 << 20;
731        let (s, _r) = decomp_basic(Some("deflate"), compressed, limit).await;
732
733        // This should possibly be a different type in the future.
734        assert!(matches!(s, Err(RequestError::IoError(_))));
735    }
736
737    #[async_test]
738    async fn headers_ok() -> RequestResult<()> {
739        let text = b"HTTP/1.0 200 OK\r\nDate: ignored\r\nContent-Encoding: Waffles\r\n\r\n";
740
741        let mut s = &text[..];
742        let h = read_headers(&mut s).await?;
743
744        assert_eq!(h.status, Some(200));
745        assert_eq!(h.encoding.as_deref(), Some("Waffles"));
746
747        // now try truncated
748        let mut s = &text[..15];
749        let h = read_headers(&mut s).await;
750        assert!(matches!(h, Err(RequestError::TruncatedHeaders)));
751
752        // now try with no encoding.
753        let text = b"HTTP/1.0 404 Not found\r\n\r\n";
754        let mut s = &text[..];
755        let h = read_headers(&mut s).await?;
756
757        assert_eq!(h.status, Some(404));
758        assert!(h.encoding.is_none());
759
760        Ok(())
761    }
762
763    #[async_test]
764    async fn headers_bogus() -> Result<()> {
765        let text = b"HTTP/999.0 WHAT EVEN\r\n\r\n";
766        let mut s = &text[..];
767        let h = read_headers(&mut s).await;
768
769        assert!(h.is_err());
770        assert!(matches!(h, Err(RequestError::HttparseError(_))));
771        Ok(())
772    }
773
774    /// Run a trivial download example with a response provided as a binary
775    /// string.
776    ///
777    /// Return the directory response (if any) and the request as encoded (if
778    /// any.)
779    fn run_download_test<Req: request::Requestable>(
780        req: Req,
781        response: &[u8],
782    ) -> (Result<DirResponse>, RequestResult<Vec<u8>>) {
783        let (mut s1, s2) = stream_pair();
784        let (mut s2_r, mut s2_w) = s2.split();
785
786        tor_rtcompat::test_with_one_runtime!(|rt| async move {
787            let rt2 = rt.clone();
788            let (v1, v2, v3): (
789                Result<DirResponse>,
790                RequestResult<Vec<u8>>,
791                RequestResult<()>,
792            ) = futures::join!(
793                async {
794                    // Run the download function.
795                    let r = send_request(&rt, &req, &mut s1, None).await;
796                    s1.close().await.map_err(|error| {
797                        Error::RequestFailed(RequestFailedError {
798                            source: None,
799                            error: error.into(),
800                        })
801                    })?;
802                    r
803                },
804                async {
805                    // Take the request from the client, and return it in "v2"
806                    let mut v = Vec::new();
807                    s2_r.read_to_end(&mut v).await?;
808                    Ok(v)
809                },
810                async {
811                    // Send back a response.
812                    s2_w.write_all(response).await?;
813                    // We wait a moment to give the other side time to notice it
814                    // has data.
815                    //
816                    // (Tentative diagnosis: The `async-compress` crate seems to
817                    // be behave differently depending on whether the "close"
818                    // comes right after the incomplete data or whether it comes
819                    // after a delay.  If there's a delay, it notices the
820                    // truncated data and tells us about it. But when there's
821                    // _no_delay, it treats the data as an error and doesn't
822                    // tell our code.)
823
824                    // TODO: sleeping in tests is not great.
825                    rt2.sleep(Duration::from_millis(50)).await;
826                    s2_w.close().await?;
827                    Ok(())
828                }
829            );
830
831            assert!(v3.is_ok());
832
833            (v1, v2)
834        })
835    }
836
837    #[test]
838    fn test_send_request() -> RequestResult<()> {
839        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
840
841        let (response, request) = run_download_test(
842            req,
843            b"HTTP/1.0 200 OK\r\n\r\nThis is where the descs would go.",
844        );
845
846        let request = request?;
847        assert!(request[..].starts_with(
848            b"GET /tor/micro/d/CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQk HTTP/1.0\r\n"
849        ));
850
851        let response = response.unwrap();
852        assert_eq!(response.status_code(), 200);
853        assert!(!response.is_partial());
854        assert!(response.error().is_none());
855        assert!(response.source().is_none());
856        let out_ref = response.output_unchecked();
857        assert_eq!(out_ref, b"This is where the descs would go.");
858        let out = response.into_output_unchecked();
859        assert_eq!(&out, b"This is where the descs would go.");
860
861        Ok(())
862    }
863
864    #[test]
865    fn test_download_truncated() {
866        // Request only one md, so "partial ok" will not be set.
867        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
868        let mut response_text: Vec<u8> =
869            (*b"HTTP/1.0 200 OK\r\nContent-Encoding: deflate\r\n\r\n").into();
870        // "One fish two fish" as above twice, but truncated the second time
871        response_text.extend(
872            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap(),
873        );
874        response_text.extend(
875            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5").unwrap(),
876        );
877        let (response, request) = run_download_test(req, &response_text);
878        assert!(request.is_ok());
879        assert!(response.is_err()); // The whole download should fail, since partial_ok wasn't set.
880
881        // request two microdescs, so "partial_ok" will be set.
882        let req: request::MicrodescRequest = vec![[9; 32]; 2].into_iter().collect();
883
884        let (response, request) = run_download_test(req, &response_text);
885        assert!(request.is_ok());
886
887        let response = response.unwrap();
888        assert_eq!(response.status_code(), 200);
889        assert!(response.error().is_some());
890        assert!(response.is_partial());
891        assert!(response.output_unchecked().len() < 37 * 2);
892        assert!(response.output_unchecked().starts_with(b"One fish"));
893    }
894
895    #[test]
896    fn test_404() {
897        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
898        let response_text = b"HTTP/1.0 418 I'm a teapot\r\n\r\n";
899        let (response, _request) = run_download_test(req, response_text);
900
901        assert_eq!(response.unwrap().status_code(), 418);
902    }
903
904    #[test]
905    fn test_headers_truncated() {
906        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
907        let response_text = b"HTTP/1.0 404 truncation happens here\r\n";
908        let (response, _request) = run_download_test(req, response_text);
909
910        assert!(matches!(
911            response,
912            Err(Error::RequestFailed(RequestFailedError {
913                error: RequestError::TruncatedHeaders,
914                ..
915            }))
916        ));
917
918        // Try a completely empty response.
919        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
920        let response_text = b"";
921        let (response, _request) = run_download_test(req, response_text);
922
923        assert!(matches!(
924            response,
925            Err(Error::RequestFailed(RequestFailedError {
926                error: RequestError::TruncatedHeaders,
927                ..
928            }))
929        ));
930    }
931
932    #[test]
933    fn test_headers_too_long() {
934        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
935        let mut response_text: Vec<u8> = (*b"HTTP/1.0 418 I'm a teapot\r\nX-Too-Many-As: ").into();
936        response_text.resize(16384, b'A');
937        let (response, _request) = run_download_test(req, &response_text);
938
939        assert!(response.as_ref().unwrap_err().should_retire_circ());
940        assert!(matches!(
941            response,
942            Err(Error::RequestFailed(RequestFailedError {
943                error: RequestError::HeadersTooLong(_),
944                ..
945            }))
946        ));
947    }
948
949    // TODO: test with bad utf-8
950}