tor_client_lib/
control_connection.rs

1use crate::{
2    auth::TorAuthentication,
3    error::TorError,
4    key::{TorEd25519SigningKey, TorServiceId},
5};
6use futures::{SinkExt, StreamExt};
7use lazy_static::lazy_static;
8use log::info;
9use regex::{Captures, Regex};
10use serde::{Deserialize, Serialize};
11use std::fmt::{Display, Error, Formatter};
12use std::net::{AddrParseError, SocketAddr as TcpSocketAddr};
13use std::os::unix::net::SocketAddr as UnixSocketAddr;
14use std::path::Path;
15use std::pin::Pin;
16use std::str::FromStr;
17use std::task::{Context, Poll};
18use tokio::{
19    io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf},
20    net::{TcpListener, TcpStream, ToSocketAddrs, UnixListener, UnixStream},
21};
22use tokio_socks::{IntoTargetAddr, TargetAddr};
23use tokio_stream::wrappers::{TcpListenerStream, UnixListenerStream};
24use tokio_util::codec::{FramedRead, FramedWrite, LinesCodec, LinesCodecError};
25
26/// Generalization of the [std::net::SocketAddr] for Tor communication.
27/// Clients can communicate with the Tor server either through the standard TCP connection, or
28/// through a Unix socket.
29#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Hash, Ord, Deserialize, Serialize)]
30pub enum TorSocketAddr {
31    Tcp(TcpSocketAddr),
32    Unix(String),
33}
34
35impl TorSocketAddr {
36    /// Create the socket address from a TCP address string of the form "<ip>:<port>"
37    fn from_tcp_string(address: &str) -> Result<Self, AddrParseError> {
38        Ok(Self::Tcp(TcpSocketAddr::from_str(address)?))
39    }
40
41    /// Create the socket address from the path to the unix socket
42    fn from_unix_string<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
43        Ok(Self::Unix(
44            UnixSocketAddr::from_pathname(path)?
45                .as_pathname()
46                .unwrap()
47                .to_str()
48                .unwrap()
49                .to_string(),
50        ))
51    }
52}
53
54/// Convert from a [std::net::SocketAddr] to this
55impl From<TcpSocketAddr> for TorSocketAddr {
56    fn from(socket_addr: TcpSocketAddr) -> Self {
57        Self::Tcp(socket_addr)
58    }
59}
60
61impl Display for TorSocketAddr {
62    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
63        match self {
64            Self::Tcp(sock_addr) => write!(f, "{}", sock_addr),
65            Self::Unix(sock_addr) => write!(f, "unix:{:?}", sock_addr),
66        }
67    }
68}
69
/// Error produced when a listen-address string cannot be parsed into a [TorSocketAddr].
#[derive(Debug)]
pub enum ListenAddressParseError {
    /// The string was not a valid `<ip>:<port>` TCP address
    TcpParseError(AddrParseError),
    /// The string had a `unix:` prefix, but the path was not a usable Unix socket address
    UnixParseError(std::io::Error),
}
76
77impl std::error::Error for ListenAddressParseError {}
78
79impl Display for ListenAddressParseError {
80    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
81        match self {
82            Self::TcpParseError(error) => write!(f, "Error parsing TCP address: {}", error),
83            Self::UnixParseError(error) => write!(f, "Error parsing Unix address:{}", error),
84        }
85    }
86}
87
88impl From<AddrParseError> for ListenAddressParseError {
89    fn from(err: AddrParseError) -> Self {
90        Self::TcpParseError(err)
91    }
92}
93
94impl From<std::io::Error> for ListenAddressParseError {
95    fn from(err: std::io::Error) -> Self {
96        Self::UnixParseError(err)
97    }
98}
99
100impl FromStr for TorSocketAddr {
101    type Err = ListenAddressParseError;
102
103    fn from_str(s: &str) -> Result<Self, Self::Err> {
104        if let Some(path) = s.strip_prefix("unix:") {
105            Ok(Self::from_unix_string(path)?)
106        } else {
107            Ok(Self::from_tcp_string(s)?)
108        }
109    }
110}
111
112/// You can listen for data for an onion service either through TCP or a unix socket
113#[derive(Debug)]
114pub enum OnionServiceListener {
115    Tcp(TcpListener),
116    Unix(UnixListener),
117}
118
119impl OnionServiceListener {
120    /// Bind to the given socket address for listening
121    pub async fn bind(socket_addr: TorSocketAddr) -> Result<OnionServiceListener, std::io::Error> {
122        match socket_addr {
123            TorSocketAddr::Tcp(socket_addr) => Ok(OnionServiceListener::Tcp(
124                TcpListener::bind(socket_addr).await?,
125            )),
126            TorSocketAddr::Unix(path) => Ok(OnionServiceListener::Unix(UnixListener::bind(path)?)),
127        }
128    }
129
130    /// Accept an incoming connection from the listener
131    pub async fn accept(&self) -> Result<(OnionServiceStream, TorSocketAddr), std::io::Error> {
132        match self {
133            Self::Tcp(listener) => {
134                let (stream, socket) = listener.accept().await?;
135                Ok((OnionServiceStream::Tcp(stream), socket.into()))
136            }
137            Self::Unix(listener) => {
138                let (stream, socket) = listener.accept().await?;
139                Ok((
140                    OnionServiceStream::Unix(stream),
141                    TorSocketAddr::Unix(
142                        socket.as_pathname().unwrap().to_string_lossy().to_string(),
143                    ),
144                ))
145            }
146        }
147    }
148
149    pub fn as_stream(self) -> OnionServiceListenerStream {
150        match self {
151            OnionServiceListener::Tcp(listener) => {
152                OnionServiceListenerStream::Tcp(TcpListenerStream::new(listener))
153            }
154            OnionServiceListener::Unix(listener) => {
155                OnionServiceListenerStream::Unix(UnixListenerStream::new(listener))
156            }
157        }
158    }
159}
160
161pub enum OnionServiceListenerStream {
162    Tcp(TcpListenerStream),
163    Unix(UnixListenerStream),
164}
165
166/// A stream of data from an accepted listener socket
167pub enum OnionServiceStream {
168    Tcp(TcpStream),
169    Unix(UnixStream),
170}
171
172impl AsyncRead for OnionServiceStream {
173    fn poll_read(
174        self: Pin<&mut Self>,
175        cx: &mut Context<'_>,
176        buf: &mut ReadBuf<'_>,
177    ) -> Poll<Result<(), std::io::Error>> {
178        match Pin::into_inner(self) {
179            Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
180            Self::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
181        }
182    }
183}
184
185impl AsyncWrite for OnionServiceStream {
186    fn poll_write(
187        self: Pin<&mut Self>,
188        cx: &mut Context<'_>,
189        buf: &[u8],
190    ) -> Poll<Result<usize, std::io::Error>> {
191        match Pin::into_inner(self) {
192            Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
193            Self::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
194        }
195    }
196
197    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
198        match Pin::into_inner(self) {
199            Self::Tcp(stream) => Pin::new(stream).poll_flush(cx),
200            Self::Unix(stream) => Pin::new(stream).poll_flush(cx),
201        }
202    }
203
204    fn poll_shutdown(
205        self: Pin<&mut Self>,
206        cx: &mut Context<'_>,
207    ) -> Poll<Result<(), std::io::Error>> {
208        match Pin::into_inner(self) {
209            Self::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
210            Self::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
211        }
212    }
213}
214
215/// Mapping from an Onion service virtual port to a local listen address
216#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Deserialize, Serialize)]
217pub struct OnionServiceMapping {
218    virt_port: u16,
219    listen_address: TorSocketAddr,
220}
221
222impl OnionServiceMapping {
223    pub fn new(virt_port: u16, listen_address: Option<TorSocketAddr>) -> Self {
224        Self {
225            virt_port,
226            listen_address: match listen_address {
227                None => {
228                    TorSocketAddr::from_tcp_string(&format!("127.0.0.1:{}", virt_port)).unwrap()
229                }
230                Some(a) => a,
231            },
232        }
233    }
234
235    pub fn virt_port(&self) -> u16 {
236        self.virt_port
237    }
238
239    pub fn listen_address(&self) -> &TorSocketAddr {
240        &self.listen_address
241    }
242}
243
244/// Onion address, containing a [TorServiceId] and a service port
245#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
246pub struct OnionAddress {
247    service_id: TorServiceId,
248    service_port: u16,
249}
250
251impl OnionAddress {
252    pub fn new(service_id: TorServiceId, port: u16) -> Self {
253        Self {
254            service_id,
255            service_port: port,
256        }
257    }
258
259    pub fn service_id(&self) -> &TorServiceId {
260        &self.service_id
261    }
262
263    pub fn hostname(&self) -> String {
264        format!("{}.onion", self.service_id)
265    }
266
267    pub fn service_port(&self) -> u16 {
268        self.service_port
269    }
270}
271
272impl<'a> IntoTargetAddr<'a> for OnionAddress {
273    fn into_target_addr(self) -> tokio_socks::Result<TargetAddr<'a>> {
274        Ok(TargetAddr::Domain(
275            self.hostname().into(),
276            self.service_port,
277        ))
278    }
279}
280
281impl FromStr for OnionAddress {
282    type Err = TorError;
283
284    fn from_str(s: &str) -> Result<Self, Self::Err> {
285        let values = s.split(':').collect::<Vec<&str>>();
286        if values.len() != 2 {
287            return Err(TorError::protocol_error("Bad onion address"));
288        }
289        let host_values = values[0].split('.').collect::<Vec<&str>>();
290        if host_values.len() != 2 || host_values[1] != "onion" {
291            return Err(TorError::protocol_error("Bad onion address"));
292        }
293        let service_id = match TorServiceId::from_str(host_values[0]) {
294            Ok(id) => id,
295            Err(error) => {
296                return Err(TorError::protocol_error(&format!(
297                    "Error parsing host field in onion address: {}",
298                    error
299                )));
300            }
301        };
302        let service_port = match values[1].parse::<u16>() {
303            Ok(port) => port,
304            Err(error) => {
305                return Err(TorError::protocol_error(&format!(
306                    "Error parsing port field in onion address: {}",
307                    error
308                )));
309            }
310        };
311        Ok(Self {
312            service_id,
313            service_port,
314        })
315    }
316}
317
318impl Display for OnionAddress {
319    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
320        write!(f, "{}.onion:{}", self.service_id, self.service_port)
321    }
322}
323
324/// Definition of a Tor Onion service (AKA "hidden service").
325///
326/// An onion service can be thought of as an encrypted load balancer, which presents itself as a
327/// virtual host in the Tor network, and which maps virtual ports on that virtual host to service
328/// ports running on your local machine. While, in practice, most onion services map a single
329/// virtual port to a service port, say, 443 to 443, you can map multiple virtual ports to a single
330/// service port, or a single virtual port to multiple service ports (in which case Tor will load
331/// balance the traffic coming in on the virtual port across the corresponding service ports).
332///
333/// Each onion service has the following:
334/// - The service ID contains all the information for the public key (see [TorServiceId] for
335/// - details).
336/// - The signing, i.e, private, key for the onion service
337/// - The mapping from the virtual port(s) to the service port(s)
338pub struct OnionService {
339    ports: Vec<OnionServiceMapping>,
340    service_id: TorServiceId,
341    signing_key: TorEd25519SigningKey,
342}
343
344impl OnionService {
345    /// Create a new `OnionService` object
346    pub fn new<K>(key: K, ports: &[OnionServiceMapping]) -> Self
347    where
348        TorEd25519SigningKey: From<K>,
349    {
350        let signing_key: TorEd25519SigningKey = key.into();
351        let service_id = signing_key.verifying_key().into();
352        Self {
353            ports: ports.to_vec(),
354            service_id,
355            signing_key,
356        }
357    }
358
359    /// Return all the listen addresses for a given onion address (including virtual port)
360    /// `onion_address` should be formatted as `<onion-address>:<port>`, e.g.
361    /// `joikeok6el5h5sbrojo2h3afw63lmfm7huvwtziacl34wjrx7n62gsad.onion:443`
362    pub fn listen_addresses_for_onion_address(&self, onion_address: &str) -> Vec<TorSocketAddr> {
363        self.ports
364            .iter()
365            .map(|p| (p, format!("{}.onion:{}", self.service_id, p.virt_port)))
366            .filter(|(_p, a)| a == onion_address)
367            .map(|(p, _a)| p.listen_address.clone())
368            .collect()
369    }
370
371    /// Return all the listen addresses for the given local service port
372    pub fn listen_addresses_for_port(&self, service_port: u16) -> Vec<TorSocketAddr> {
373        self.ports
374            .iter()
375            .filter(|p| p.virt_port == service_port)
376            .map(|p| p.listen_address.clone())
377            .collect()
378    }
379
380    /// Return the onion address (i.e., the onion hostname and virtual port) which maps to the
381    /// given local service port
382    pub fn onion_address(&self, service_port: u16) -> Result<OnionAddress, TorError> {
383        if self.ports.iter().any(|p| p.virt_port == service_port) {
384            Ok(OnionAddress {
385                service_id: self.service_id.clone(),
386                service_port,
387            })
388        } else {
389            Err(TorError::protocol_error(&format!(
390                "No Onion Service Port {} found for onion service {}",
391                service_port, self.service_id
392            )))
393        }
394    }
395
396    /// Return a list of all the onion addresses for this onion service
397    pub fn onion_addresses(&self) -> Vec<OnionAddress> {
398        self.ports
399            .iter()
400            .map(|p| OnionAddress::new(self.service_id.clone(), p.virt_port))
401            .collect()
402    }
403
404    /// Return the [TorServiceId] for this onion service
405    pub fn service_id(&self) -> &TorServiceId {
406        &self.service_id
407    }
408
409    /// Return the Tor signing key for this onion service
410    pub fn signing_key(&self) -> &TorEd25519SigningKey {
411        &self.signing_key
412    }
413
414    // Provide a way to destructively retrieve the signing key
415    fn into_signing_key(self) -> TorEd25519SigningKey {
416        self.signing_key
417    }
418
419    /// Return the list of virtual to service port mappings for this onion service
420    pub fn ports(&self) -> Vec<OnionServiceMapping> {
421        self.ports.clone()
422    }
423}
424
425/// Convert an onion service to a signing key
426impl From<OnionService> for TorEd25519SigningKey {
427    fn from(onion_service: OnionService) -> Self {
428        onion_service.into_signing_key()
429    }
430}
431
/// Response sent by the Tor server for a controller command.
#[derive(Debug)]
pub struct ControlResponse {
    /// Numeric status code of the reply (e.g. `250`)
    pub status_code: u16,
    /// Reply text collected from the response lines
    pub reply: String,
}
438
439impl ControlResponse {
440    fn new() -> Self {
441        Self {
442            status_code: 0,
443            reply: String::new(),
444        }
445    }
446}
447
448fn parse_status_code(code_str: &str) -> Result<u16, TorError> {
449    match code_str.parse::<u16>() {
450        Ok(status_code) => Ok(status_code),
451        Err(error) => Err(TorError::protocol_error(&format!(
452            "Error parsing response status code: {}",
453            error
454        ))),
455    }
456}
457
458/// Read a response to a controller command
459async fn read_control_response<S: StreamExt<Item = Result<String, LinesCodecError>> + Unpin>(
460    reader: &mut S,
461) -> Result<ControlResponse, TorError> {
462    lazy_static! {
463        // Mid reply
464        static ref MID_REGEX: Regex = Regex::new(r"^(?P<code>\d{3})-(?P<reply_line>.*)$").unwrap();
465
466        // Data reply
467        static ref DATA_REGEX: Regex =
468            Regex::new(r"^(?P<code>\d{3})\+(?P<reply_line>.*)$").unwrap();
469
470        // End of reply message
471        static ref END_REGEX: Regex = Regex::new(r"^(?P<code>\d{3}) (?P<reply_line>.*)$").unwrap();
472    }
473
474    let mut control_response = ControlResponse::new();
475    loop {
476        let mut line = read_line(reader).await?;
477        info!("<= {}", line);
478        match MID_REGEX.captures(&line) {
479            // Read Mid replies line-by-line, and append their reply lines to the reply
480            Some(captures) => {
481                control_response.status_code = parse_status_code(&captures["code"])?;
482                control_response
483                    .reply
484                    .push_str(&format!("{}\n", &captures["reply_line"]));
485            }
486            None => match DATA_REGEX.captures(&line.clone()) {
487                // For Data replies, append everything between the initial line and the "." to the reply line
488                Some(captures) => {
489                    control_response.status_code = parse_status_code(&captures["code"])?;
490                    let mut reply_line = captures["reply_line"].to_string();
491                    reply_line.push('\n');
492                    loop {
493                        line = read_line(reader).await?;
494                        if line == "." {
495                            break;
496                        }
497                        reply_line.push_str(&line);
498                        reply_line.push('\n');
499                    }
500                    control_response.reply = reply_line;
501                    // Read the final "250 OK"
502                    read_line(reader).await?;
503                    return Ok(control_response);
504                }
505                None => match END_REGEX.captures(&line) {
506                    Some(captures) => {
507                        control_response.status_code = parse_status_code(&captures["code"])?;
508                        // If we haven't gotten any other replies, use this one as the message
509                        if control_response.reply.is_empty() {
510                            control_response.reply.push_str(&captures["reply_line"]);
511                        }
512                        return Ok(control_response);
513                    }
514                    None => {
515                        return Err(TorError::ProtocolError(format!(
516                            "Unknown response: {}",
517                            line
518                        )))
519                    }
520                },
521            },
522        }
523    }
524}
525
526/// Read a response line
527async fn read_line<S: StreamExt<Item = Result<String, LinesCodecError>> + Unpin>(
528    reader: &mut S,
529) -> Result<String, TorError> {
530    match reader.next().await {
531        Some(Ok(line)) => Ok(line),
532        Some(Err(error)) => Err(error.into()),
533        None => Err(TorError::protocol_error("Unexpected EOF on stream")),
534    }
535}
536
537/// Format the ADD_ONION request arguments
538fn format_onion_service_request_string(
539    key_type: &str,
540    key_blob: &str,
541    ports: &[OnionServiceMapping],
542    transient: bool,
543) -> String {
544    let flags = if transient { "" } else { "Flags=Detach" };
545    let port_string = ports
546        .iter()
547        .map(|p| format!("Port={},{}", p.virt_port, p.listen_address))
548        .collect::<Vec<String>>()
549        .join(" ");
550    format!("{}:{} {} {}", key_type, key_blob, flags, port_string)
551}
552
553fn format_key_request_string(
554    ports: &[OnionServiceMapping],
555    transient: bool,
556    signing_key: Option<&TorEd25519SigningKey>,
557) -> String {
558    match signing_key {
559        Some(signing_key) => format_onion_service_request_string(
560            "ED25519-V3",
561            &signing_key.to_blob(),
562            ports,
563            transient,
564        ),
565        None => format_onion_service_request_string("NEW", "BEST", ports, transient),
566    }
567}
568
569/// Parse a response field that is required, i.e., throw an error if it's not there
570fn parse_required_response_field<'a>(
571    captures: &Captures<'a>,
572    field_name: &str,
573    field_arg: &str,
574    response_type: &str,
575) -> Result<&'a str, TorError> {
576    match captures.name(field_name) {
577        Some(field) => Ok(field.as_str()),
578        None => Err(TorError::protocol_error(&format!(
579            "'{}' field not found in {} response",
580            field_arg, response_type,
581        ))),
582    }
583}
584
585fn parse_add_onion_response(
586    captures: &Captures<'_>,
587    ports: &[OnionServiceMapping],
588    signing_key: Option<TorEd25519SigningKey>,
589) -> Result<OnionService, TorError> {
590    let service_id =
591        parse_required_response_field(captures, "service_id", "ServiceID", "ADD_ONION")?;
592
593    // Retrieve the key, either the one passed in or the one
594    // returned from the controller
595    let (returned_signing_key, verifying_key) = match signing_key {
596        Some(signing_key) => {
597            let verifying_key = signing_key.verifying_key();
598            (signing_key, verifying_key)
599        }
600        None => match captures.name("key_type") {
601            Some(_) => {
602                let signing_key =
603                    TorEd25519SigningKey::from_blob(captures.name("key_blob").unwrap().as_str())
604                        .unwrap();
605                let verifying_key = signing_key.verifying_key();
606                (signing_key, verifying_key)
607            }
608            None => {
609                return Err(TorError::protocol_error(
610                    "Expected signing key to be returned by Tor",
611                ));
612            }
613        },
614    };
615
616    let expected_service_id: TorServiceId = verifying_key.into();
617
618    if expected_service_id.as_str() != service_id {
619        return Err(
620            TorError::protocol_error(&format!(
621                    "Service ID for onion service returned by tor ({}) doesn't match the service ID generated from verifying key ({})",
622                    service_id, expected_service_id.as_str())));
623    }
624
625    if let Err(error) = TorServiceId::from_str(service_id) {
626        return Err(TorError::protocol_error(&format!(
627            "Error parsing Tor Service ID: {}",
628            error
629        )));
630    }
631
632    // Return the Onion Service
633    Ok(OnionService::new(returned_signing_key, ports))
634}
635
/// Information parsed from the reply to the `PROTOCOLINFO` command.
#[derive(Clone, Debug)]
pub struct ProtocolInfo {
    /// Authentication methods listed after `AUTH METHODS=`, split on commas
    pub auth_methods: Vec<String>,
    /// Cookie file path from `COOKIEFILE=`, if reported (double quotes removed)
    pub cookie_file: Option<String>,
    /// Version from the `VERSION Tor=` line (double quotes removed)
    pub tor_version: String,
}
644
645/// Control connection, used to send commands to and receive responses from
646/// the Tor server
647#[derive(Debug)]
648pub struct TorControlConnection {
649    reader: FramedRead<ReadHalf<TcpStream>, LinesCodec>,
650    writer: FramedWrite<WriteHalf<TcpStream>, LinesCodec>,
651    protocol_info: Option<ProtocolInfo>,
652}
653
654impl TorControlConnection {
655    /// Connect to the Tor server. This is generally how you create a connection to the server
656    pub async fn connect<A: ToSocketAddrs>(addrs: A) -> Result<Self, TorError> {
657        let this = Self::with_stream(TcpStream::connect(addrs).await?)?;
658        Ok(this)
659    }
660
661    /// Convert an existing TCPStream into a connection object
662    pub(crate) fn with_stream(stream: TcpStream) -> Result<Self, TorError> {
663        let (reader, writer) = tokio::io::split(stream);
664        Ok(Self {
665            reader: FramedRead::new(reader, LinesCodec::new()),
666            writer: FramedWrite::new(writer, LinesCodec::new()),
667            protocol_info: None,
668        })
669    }
670
671    /// Write to the Tor Server
672    async fn write(&mut self, data: &str) -> Result<(), TorError> {
673        self.writer.send(data).await?;
674        Ok(())
675    }
676
677    /// Send the PROTOCOLINFO command and parse the response
678    pub async fn get_protocol_info(&mut self) -> Result<ProtocolInfo, TorError> {
679        if self.protocol_info.is_some() {
680            Ok(self.protocol_info.clone().unwrap())
681        } else {
682            let control_response = self.send_command("PROTOCOLINFO", Some("1")).await?;
683
684            if control_response.status_code != 250 {
685                return Err(TorError::protocol_error(&format!(
686                    "Expected status code 250, got {}",
687                    control_response.status_code
688                )));
689            }
690
691            // Parse the controller response
692            lazy_static! {
693                static ref RE: Regex =
694                    Regex::new(r"^PROTOCOLINFO 1\nAUTH METHODS=(?P<auth_methods>[^ ]*)( COOKIEFILE=(?P<cookie_file>.*))*\nVERSION Tor=(?P<tor_version>.*)\n")
695                        .unwrap();
696            }
697            let captures = match RE.captures(&control_response.reply) {
698                Some(captures) => captures,
699                None => {
700                    return Err(TorError::protocol_error(
701                        "Error parsing PROTOCOLINFO response",
702                    ))
703                }
704            };
705            let auth_methods = parse_required_response_field(
706                &captures,
707                "auth_methods",
708                "AUTH METHODS",
709                "PROTOCOLINFO",
710            )?
711            .split(',')
712            .map(|s| s.to_string())
713            .collect();
714            let tor_version =
715                parse_required_response_field(&captures, "tor_version", "VERSION", "PROTOCOLINFO")?
716                    .replace('"', "");
717            let protocol_info = ProtocolInfo {
718                auth_methods,
719                cookie_file: captures
720                    .name("cookie_file")
721                    .map(|c| c.as_str().replace('"', "").to_string()),
722                tor_version,
723            };
724            self.protocol_info = Some(protocol_info.clone());
725            Ok(protocol_info)
726        }
727    }
728
729    /// Send the GETINFO command and parse the response
730    pub async fn get_info(&mut self, info: &str) -> Result<Vec<String>, TorError> {
731        let control_response = self.send_command("GETINFO", Some(info)).await?;
732        info!(
733            "Send GETINFO command, got control response {:?}",
734            control_response
735        );
736        if control_response.status_code != 250 {
737            return Err(TorError::protocol_error(&format!(
738                "Expected status code 250, got {}",
739                control_response.status_code
740            )));
741        }
742        let split_response = &control_response
743            .reply
744            .trim_end()
745            .split('=')
746            .collect::<Vec<&str>>();
747        if split_response.len() <= 1 {
748            return Err(TorError::protocol_error(&format!(
749                "Got unexpected reply '{}', expected key/value pair",
750                control_response.reply
751            )));
752        }
753
754        let response = split_response[1].split('\n').collect::<Vec<&str>>();
755
756        let mut ret = Vec::new();
757        for value in response.iter() {
758            if !value.is_empty() {
759                ret.push(value.to_string());
760            }
761        }
762
763        Ok(ret)
764    }
765
766    /// Authenticate to the Tor server using the passed-in method
767    pub async fn authenticate(&mut self, method: &TorAuthentication) -> Result<(), TorError> {
768        method.authenticate(self).await?;
769        Ok(())
770    }
771
772    /// Send a general command to the Tor server
773    pub(crate) async fn send_command(
774        &mut self,
775        command: &str,
776        arguments: Option<&str>,
777    ) -> Result<ControlResponse, TorError> {
778        let command_string = match arguments {
779            None => command.to_string(),
780            Some(arguments) => format!("{} {}", command, arguments),
781        };
782        info!("=> {}", command_string);
783        self.write(&command_string).await?;
784        match read_control_response(&mut self.reader).await {
785            Ok(control_response) => match control_response.status_code {
786                250 | 251 => Ok(control_response),
787                _ => Err(TorError::ProtocolError(control_response.reply)),
788            },
789            Err(error) => Err(error),
790        }
791    }
792
793    /// Create an onion service.
794    pub async fn create_onion_service(
795        &mut self,
796        ports: &[OnionServiceMapping],
797        transient: bool,
798        signing_key: Option<TorEd25519SigningKey>,
799    ) -> Result<OnionService, TorError> {
800        // Create the request string from the arguments
801        let request_string = format_key_request_string(ports, transient, signing_key.as_ref());
802
803        // Send command to Tor controller
804        let control_response = self
805            .send_command("ADD_ONION", Some(&request_string))
806            .await?;
807        info!(
808            "Sent ADD_ONION command, got control response {:?}",
809            control_response
810        );
811
812        if control_response.status_code != 250 {
813            return Err(TorError::protocol_error(&format!(
814                "Expected status code 250, got {}",
815                control_response.status_code
816            )));
817        }
818
819        // Parse the controller response
820        lazy_static! {
821            static ref RE: Regex =
822                Regex::new(r"(?m)^ServiceID=(?P<service_id>.*)\n(PrivateKey=(?P<key_type>[^:]*):(?<key_blob>.*)$)?$")
823                    .unwrap();
824        }
825        match RE.captures(&control_response.reply) {
826            Some(captures) => parse_add_onion_response(&captures, ports, signing_key),
827            None => Err(TorError::ProtocolError(format!(
828                "Unexpected response: {} {}",
829                control_response.status_code, control_response.reply,
830            ))),
831        }
832    }
833
834    pub async fn delete_onion_service(&mut self, service_id: &str) -> Result<(), TorError> {
835        // Just in case someone passes in the ".onion" part
836        let service_id_string = service_id.replace(".onion", "");
837
838        // Send command to Tor controller
839        let control_response = self
840            .send_command("DEL_ONION", Some(&service_id_string))
841            .await?;
842        info!(
843            "Sent DEL_ONION command, got control response {:?}",
844            control_response
845        );
846
847        if control_response.status_code != 250 {
848            Err(TorError::protocol_error(&format!(
849                "Expected status code 250, got {}",
850                control_response.status_code
851            )))
852        } else {
853            Ok(())
854        }
855    }
856}
857
#[cfg(test)]
mod tests {
    use super::*;
    use futures::SinkExt;
    use tokio::net::{TcpListener, TcpStream};
    use tokio_util::codec::{Framed, LinesCodec};

    /// Create a connected pair of loopback TCP streams, returned as `(client, server)`.
    ///
    /// The server end is used by the tests to script the replies the code under test reads.
    async fn create_mock() -> Result<(TcpStream, TcpStream), Box<dyn std::error::Error>> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        // Accept in a separate task so that `connect` below can complete; propagate any
        // accept error to the test instead of panicking inside the task.
        let join_handle = tokio::spawn(async move { listener.accept().await });
        let client = TcpStream::connect(addr).await?;
        let (server_stream, _) = join_handle.await??;

        Ok((client, server_stream))
    }

    /// Like [`create_mock`], but with both ends wrapped in a line-oriented [`Framed`] codec.
    async fn create_framed_mock() -> Result<
        (Framed<TcpStream, LinesCodec>, Framed<TcpStream, LinesCodec>),
        Box<dyn std::error::Error>,
    > {
        let (client, server) = create_mock().await?;
        let reader = Framed::new(client, LinesCodec::new());
        let server = Framed::new(server, LinesCodec::new());

        Ok((reader, server))
    }

    #[tokio::test]
    async fn test_read_good_control_response() -> Result<(), Box<dyn std::error::Error>> {
        // Single-line "250 OK" response
        let (mut client, mut server) = create_framed_mock().await?;
        server.send("250 OK").await?;
        let control_response = read_control_response(&mut client).await?;
        assert_eq!(250, control_response.status_code);
        assert_eq!("OK", control_response.reply);

        Ok(())
    }

    #[tokio::test]
    async fn test_read_garbled_control_response() -> Result<(), Box<dyn std::error::Error>> {
        // A line without a status code must be reported as a protocol error
        let (mut client, mut server) = create_framed_mock().await?;
        server.send("idon'tknowwhatthisis").await?;
        let result = read_control_response(&mut client).await;
        assert!(
            matches!(result, Err(TorError::ProtocolError(_))),
            "expected ProtocolError, got {:?}",
            result
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_read_multiline_control_response() -> Result<(), Box<dyn std::error::Error>> {
        // "250-" continuation lines followed by a final "250 OK"
        let (mut client, mut server) = create_framed_mock().await?;
        server
            .send("250-ServiceID=647qjf6w3evdbdpy7oidf5vda6rsjzsl5a6ofsaou2v77hj7dmn2spqd")
            .await?;
        server.send("250-PrivateKey=ED25519-V3:yLSDc8b11PaIHTtNtvi9lNW99IME2mdrO4k381zDkHv//WRUGrkBALBQ9MbHy2SLA/NmfS7YxmcR/FY8ppRfIA==").await?;
        server.send("250 OK").await?;
        let control_response = read_control_response(&mut client).await?;
        assert_eq!(250, control_response.status_code);
        assert_eq!(
            "ServiceID=647qjf6w3evdbdpy7oidf5vda6rsjzsl5a6ofsaou2v77hj7dmn2spqd\nPrivateKey=ED25519-V3:yLSDc8b11PaIHTtNtvi9lNW99IME2mdrO4k381zDkHv//WRUGrkBALBQ9MbHy2SLA/NmfS7YxmcR/FY8ppRfIA==\n",
            control_response.reply);

        Ok(())
    }

    #[tokio::test]
    async fn test_read_data_control_response() -> Result<(), Box<dyn std::error::Error>> {
        // "250+" data response terminated by a lone ".", then a final "250 OK"
        let (mut client, mut server) = create_framed_mock().await?;
        server.send("250+onions/current=").await?;
        server
            .send("647qjf6w3evdbdpy7oidf5vda6rsjzsl5a6ofsaou2v77hj7dmn2spqd")
            .await?;
        server
            .send("yxq7fa63tthq3nd2ul52jjcdpblyai6k3cfmdkyw23ljsoob66z3ywid")
            .await?;
        server.send(".").await?;
        server.send("250 OK").await?;
        let control_response = read_control_response(&mut client).await?;
        assert_eq!(250, control_response.status_code);
        assert_eq!("onions/current=\n647qjf6w3evdbdpy7oidf5vda6rsjzsl5a6ofsaou2v77hj7dmn2spqd\nyxq7fa63tthq3nd2ul52jjcdpblyai6k3cfmdkyw23ljsoob66z3ywid\n",
            control_response.reply,
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_authenticate() -> Result<(), Box<dyn std::error::Error>> {
        // Scripted replies: a PROTOCOLINFO-style response advertising NULL auth, then "250 OK"
        let (client, server) = create_mock().await?;
        let mut server = Framed::new(server, LinesCodec::new());
        server
            .send("250-PROTOCOLINFO 1\n250-AUTH METHODS=NULL\n250-VERSION Tor=1\n250 OK")
            .await?;
        server.send("250 OK").await?;
        let mut tor = TorControlConnection::with_stream(client)?;
        tor.authenticate(&TorAuthentication::Null).await?;

        // An error status from the controller must surface as an error
        let (client, server) = create_mock().await?;
        let mut server = Framed::new(server, LinesCodec::new());
        server.send("551 Oops").await?;
        let mut tor = TorControlConnection::with_stream(client)?;
        assert!(tor.authenticate(&TorAuthentication::Null).await.is_err());

        Ok(())
    }

    #[tokio::test]
    async fn test_create_onion_service() -> Result<(), Box<dyn std::error::Error>> {
        let (client, server) = create_mock().await?;
        let mut server = Framed::new(server, LinesCodec::new());
        server
            .send("250-ServiceID=vvqbbaknxi6w44t6rplzh7nmesfzw3rjujdijpqsu5xl3nhlkdscgqad")
            .await?;
        server
            .send("250-PrivateKey=ED25519-V3:0H/jnBeWzMoU1MGNRQPnmd8JqlpTNS3UeTiDOMyPTGGXXpLd0KinCtQbcgz2fCYjbzfK3ElJ7x3zGCkB1fAtAA==")
            .await?;
        server.send("250 OK").await?;
        let mut tor = TorControlConnection::with_stream(client)?;
        let onion_service = tor
            .create_onion_service(&[OnionServiceMapping::new(8080, None)], true, None)
            .await?;
        assert_eq!(8080, onion_service.ports[0].virt_port);
        assert_eq!(
            TorSocketAddr::from_tcp_string("127.0.0.1:8080"),
            Ok(onion_service.ports[0].clone().listen_address)
        );
        assert_eq!(
            "vvqbbaknxi6w44t6rplzh7nmesfzw3rjujdijpqsu5xl3nhlkdscgqad",
            onion_service.service_id.as_str()
        );
        assert_eq!(
            OnionAddress::from_str(
                "vvqbbaknxi6w44t6rplzh7nmesfzw3rjujdijpqsu5xl3nhlkdscgqad.onion:8080"
            )?,
            onion_service.onion_address(8080).unwrap()
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_get_protocol_info() -> Result<(), Box<dyn std::error::Error>> {
        let (client, server) = create_mock().await?;
        let mut server = Framed::new(server, LinesCodec::new());
        server.send("250-PROTOCOLINFO 1").await?;
        server.send("250-AUTH METHODS=NULL,FOO").await?;
        server.send("250-VERSION Tor=\"0.4.7.13\"").await?;
        server.send("250 OK").await?;
        let mut tor = TorControlConnection::with_stream(client)?;
        tor.get_protocol_info().await?;

        Ok(())
    }

    #[test]
    fn test_parse_onion_address() -> Result<(), Box<dyn std::error::Error>> {
        let address = OnionAddress::from_str(
            "647qjf6w3evdbdpy7oidf5vda6rsjzsl5a6ofsaou2v77hj7dmn2spqd.onion:80",
        )?;
        assert_eq!(
            TorServiceId::from_str("647qjf6w3evdbdpy7oidf5vda6rsjzsl5a6ofsaou2v77hj7dmn2spqd")?,
            address.service_id
        );
        assert_eq!(80, address.service_port);

        // Host is not a "<service-id>.onion" name
        assert!(OnionAddress::from_str("foobar:27").is_err());

        // Port is not numeric
        assert!(OnionAddress::from_str(
            "647qjf6w3evdbdpy7oidf5vda6rsjzsl5a6ofsaou2v77hj7dmn2spqd.onion:abcd",
        )
        .is_err());

        Ok(())
    }
}