1use crate::{Error, Result};
4
5use caret::caret_int;
6use std::fmt;
7use std::net::IpAddr;
8
9#[cfg(feature = "arbitrary")]
10use std::net::Ipv6Addr;
11
12use tor_error::bad_api_usage;
13
14#[cfg(feature = "arbitrary")]
15use arbitrary::{Arbitrary, Result as ArbitraryResult, Unstructured};
16
/// A version of the SOCKS protocol that this module can represent.
///
/// On the wire a version is a single byte; `TryFrom<u8>` accepts 4 and 5.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
#[non_exhaustive]
pub enum SocksVersion {
    /// SOCKS version 4 (wire byte `4`; displayed as `socks4`).
    V4,
    /// SOCKS version 5 (wire byte `5`; displayed as `socks5`).
    V5,
}
27
28impl TryFrom<u8> for SocksVersion {
29 type Error = Error;
30 fn try_from(v: u8) -> Result<SocksVersion> {
31 match v {
32 4 => Ok(SocksVersion::V4),
33 5 => Ok(SocksVersion::V5),
34 _ => Err(Error::BadProtocol(v)),
35 }
36 }
37}
38
39impl fmt::Display for SocksVersion {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 SocksVersion::V4 => write!(f, "socks4"),
43 SocksVersion::V5 => write!(f, "socks5"),
44 }
45 }
46}
47
/// A SOCKS request: protocol version, command, target address and port, and
/// authentication information.
///
/// [`SocksRequest::new`] checks these fields for consistency before building
/// a request; the fields are read back through accessor methods.
#[derive(Clone, Debug)]
// Equality is only needed by the unit tests, so it is derived only there.
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct SocksRequest {
    /// The SOCKS protocol version used for this request.
    version: SocksVersion,
    /// The command the client asked for.
    cmd: SocksCmd,
    /// The target address (hostname or IP address).
    addr: SocksAddr,
    /// The target port. `SocksRequest::new` rejects 0 for commands that require a port.
    port: u16,
    /// Authentication information sent with the request.
    auth: SocksAuth,
}
70
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for SocksRequest {
    /// Build a request from fuzzer input.
    ///
    /// Combinations that [`SocksRequest::new`] refuses are reported as
    /// [`arbitrary::Error::IncorrectFormat`].
    fn arbitrary(u: &mut Unstructured<'a>) -> ArbitraryResult<Self> {
        // Rust evaluates call arguments left to right, so the fields are drawn
        // from `u` in version, cmd, addr, port, auth order.
        SocksRequest::new(
            SocksVersion::arbitrary(u)?,
            SocksCmd::arbitrary(u)?,
            SocksAddr::arbitrary(u)?,
            u16::arbitrary(u)?,
            SocksAuth::arbitrary(u)?,
        )
        .map_err(|_| arbitrary::Error::IncorrectFormat)
    }
}
84
/// An address carried in a SOCKS message: either a hostname or an IP address.
///
/// Displayed as the bare hostname or the standard IP address text.
#[derive(Clone, Debug, PartialEq, Eq)]
// NOTE(review): presumably kept exhaustive so callers can match both
// variants without a wildcard arm — confirm before adding variants.
#[allow(clippy::exhaustive_enums)]
pub enum SocksAddr {
    /// A hostname (see `SocksHostname` for the validation applied).
    Hostname(SocksHostname),
    /// An IPv4 or IPv6 address.
    Ip(IpAddr),
}
96
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for SocksAddr {
    /// Choose a variant from one selector byte, then build that variant.
    ///
    /// The selector byte modulo 3 means:
    /// - 0 → hostname
    /// - 1 → IPv4
    /// - otherwise → IPv6
    fn arbitrary(u: &mut Unstructured<'a>) -> ArbitraryResult<Self> {
        use std::net::Ipv4Addr;
        let selector = u8::arbitrary(u)? % 3;
        let addr = if selector == 0 {
            SocksAddr::Hostname(SocksHostname::arbitrary(u)?)
        } else if selector == 1 {
            SocksAddr::Ip(IpAddr::V4(Ipv4Addr::arbitrary(u)?))
        } else {
            SocksAddr::Ip(IpAddr::V6(Ipv6Addr::arbitrary(u)?))
        };
        Ok(addr)
    }

    fn size_hint(_depth: usize) -> (usize, Option<usize>) {
        (1, Some(256))
    }
}
112
/// A hostname for use in a SOCKS message.
///
/// `TryFrom<String>` rejects strings longer than 255 bytes or containing a
/// NUL byte. `AsRef<str>` exposes the text.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SocksHostname(String);
116
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for SocksHostname {
    /// Draw an arbitrary string and keep it only if it is a valid hostname.
    ///
    /// Invalid strings are reported as `IncorrectFormat`.
    fn arbitrary(u: &mut Unstructured<'a>) -> ArbitraryResult<Self> {
        let candidate = String::arbitrary(u)?;
        SocksHostname::try_from(candidate).map_err(|_| arbitrary::Error::IncorrectFormat)
    }

    fn size_hint(_depth: usize) -> (usize, Option<usize>) {
        (0, Some(255))
    }
}
128
/// Authentication information carried by a SOCKS request.
///
/// Which variants are acceptable depends on the SOCKS version; see
/// `SocksAuth::validate`.
// NOTE(review): the derived `Debug` prints the credential bytes verbatim,
// so avoid logging values of this type with `{:?}`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
#[non_exhaustive]
pub enum SocksAuth {
    /// No authentication information was supplied.
    NoAuth,
    /// SOCKS4 authentication bytes (the SOCKS4 user-ID field — TODO confirm).
    ///
    /// `validate` accepts these only with SOCKS version 4 and only when they
    /// contain no zero byte.
    Socks4(Vec<u8>),
    /// A username and password, as raw bytes.
    ///
    /// `validate` accepts these only with SOCKS version 5 and only when each
    /// field is at most 255 bytes long.
    Username(Vec<u8>, Vec<u8>),
}
141
caret_int! {
    /// A command sent from a SOCKS client to a SOCKS server.
    ///
    /// Values 1-3 are the RFC 1928 SOCKS5 commands; 0xF0 and 0xF1 are Tor
    /// extensions. `SocksRequest::new` accepts only `CONNECT`, `RESOLVE` and
    /// `RESOLVE_PTR`.
    #[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
    pub struct SocksCmd(u8) {
        /// Connect to a remote host (RFC 1928 `CONNECT`).
        CONNECT = 1,
        /// Bind a listening port (RFC 1928 `BIND`). Rejected by `SocksRequest::new`.
        BIND = 2,
        /// Set up a UDP association (RFC 1928). Rejected by `SocksRequest::new`.
        UDP_ASSOCIATE = 3,

        /// Tor extension: look up the address of a hostname.
        RESOLVE = 0xF0,
        /// Tor extension: reverse lookup of an IP address.
        RESOLVE_PTR = 0xF1,
    }
}
159
caret_int! {
    /// A status code sent in a SOCKS reply.
    ///
    /// Values 0x00-0x08 are the RFC 1928 SOCKS5 reply codes. Values 0xF0-0xF7
    /// are Tor extensions reporting onion-service failures.
    #[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
    pub struct SocksStatus(u8) {
        /// RFC 1928: succeeded.
        SUCCEEDED = 0x00,
        /// RFC 1928: general SOCKS server failure.
        GENERAL_FAILURE = 0x01,
        /// RFC 1928: connection not allowed by ruleset.
        NOT_ALLOWED = 0x02,
        /// RFC 1928: network unreachable.
        NETWORK_UNREACHABLE = 0x03,
        /// RFC 1928: host unreachable.
        HOST_UNREACHABLE = 0x04,
        /// RFC 1928: connection refused.
        CONNECTION_REFUSED = 0x05,
        /// RFC 1928: TTL expired.
        TTL_EXPIRED = 0x06,
        /// RFC 1928: command not supported.
        COMMAND_NOT_SUPPORTED = 0x07,
        /// RFC 1928: address type not supported.
        ADDRTYPE_NOT_SUPPORTED = 0x08,
        /// Tor extension: onion service descriptor could not be found.
        HS_DESC_NOT_FOUND = 0xF0,
        /// Tor extension: onion service descriptor is invalid.
        HS_DESC_INVALID = 0xF1,
        /// Tor extension: onion service introduction failed.
        HS_INTRO_FAILED = 0xF2,
        /// Tor extension: onion service rendezvous failed.
        HS_REND_FAILED = 0xF3,
        /// Tor extension: onion service client authorization is missing.
        HS_MISSING_CLIENT_AUTH = 0xF4,
        /// Tor extension: onion service client authorization is wrong.
        HS_WRONG_CLIENT_AUTH = 0xF5,
        /// Tor extension: the onion service address is invalid.
        HS_BAD_ADDRESS = 0xF6,
        /// Tor extension: onion service introduction timed out.
        HS_INTRO_TIMEOUT = 0xF7
    }
}
213
214impl SocksCmd {
215 fn recognized(self) -> bool {
217 matches!(
218 self,
219 SocksCmd::CONNECT | SocksCmd::RESOLVE | SocksCmd::RESOLVE_PTR
220 )
221 }
222
223 fn requires_port(self) -> bool {
225 matches!(
226 self,
227 SocksCmd::CONNECT | SocksCmd::BIND | SocksCmd::UDP_ASSOCIATE
228 )
229 }
230}
231
232impl SocksStatus {
233 #[cfg(feature = "proxy-handshake")]
235 pub(crate) fn into_socks4_status(self) -> u8 {
236 match self {
237 SocksStatus::SUCCEEDED => 0x5A,
238 _ => 0x5B,
239 }
240 }
241 #[cfg(feature = "client-handshake")]
243 pub(crate) fn from_socks4_status(status: u8) -> Self {
244 match status {
245 0x5A => SocksStatus::SUCCEEDED,
246 0x5B => SocksStatus::GENERAL_FAILURE,
247 0x5C | 0x5D => SocksStatus::NOT_ALLOWED,
248 _ => SocksStatus::GENERAL_FAILURE,
249 }
250 }
251}
252
253impl TryFrom<String> for SocksHostname {
254 type Error = Error;
255 fn try_from(s: String) -> Result<SocksHostname> {
256 if s.len() > 255 {
257 Err(bad_api_usage!("hostname too long").into())
260 } else if contains_zeros(s.as_bytes()) {
261 Err(Error::Syntax)
264 } else {
265 Ok(SocksHostname(s))
266 }
267 }
268}
269
270impl AsRef<str> for SocksHostname {
271 fn as_ref(&self) -> &str {
272 self.0.as_ref()
273 }
274}
275
276impl SocksAuth {
277 fn validate(&self, version: SocksVersion) -> Result<()> {
282 match self {
283 SocksAuth::NoAuth => {}
284 SocksAuth::Socks4(data) => {
285 if version != SocksVersion::V4 || contains_zeros(data) {
286 return Err(Error::Syntax);
287 }
288 }
289 SocksAuth::Username(user, pass) => {
290 if version != SocksVersion::V5
291 || user.len() > u8::MAX as usize
292 || pass.len() > u8::MAX as usize
293 {
294 return Err(Error::Syntax);
295 }
296 }
297 }
298 Ok(())
299 }
300}
301
302fn contains_zeros(b: &[u8]) -> bool {
306 use subtle::{Choice, ConstantTimeEq};
307 let c: Choice = b
308 .iter()
309 .fold(Choice::from(0), |seen_any, byte| seen_any | byte.ct_eq(&0));
310 c.unwrap_u8() != 0
311}
312
313impl SocksRequest {
314 pub fn new(
318 version: SocksVersion,
319 cmd: SocksCmd,
320 addr: SocksAddr,
321 port: u16,
322 auth: SocksAuth,
323 ) -> Result<Self> {
324 if !cmd.recognized() {
325 return Err(Error::NotImplemented(
326 format!("SOCKS command {}", cmd).into(),
327 ));
328 }
329 if port == 0 && cmd.requires_port() {
330 return Err(Error::Syntax);
331 }
332 auth.validate(version)?;
333
334 Ok(SocksRequest {
335 version,
336 cmd,
337 addr,
338 port,
339 auth,
340 })
341 }
342
343 pub fn version(&self) -> SocksVersion {
345 self.version
346 }
347
348 pub fn command(&self) -> SocksCmd {
350 self.cmd
351 }
352
353 pub fn auth(&self) -> &SocksAuth {
355 &self.auth
356 }
357
358 pub fn port(&self) -> u16 {
360 self.port
361 }
362
363 pub fn addr(&self) -> &SocksAddr {
365 &self.addr
366 }
367}
368
369impl fmt::Display for SocksAddr {
370 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
373 match self {
374 SocksAddr::Ip(a) => write!(f, "{}", a),
375 SocksAddr::Hostname(h) => write!(f, "{}", h.0),
376 }
377 }
378}
379
/// A reply from a SOCKS server: a status code plus an address and a port.
#[derive(Debug, Clone)]
pub struct SocksReply {
    /// The status code of the reply.
    status: SocksStatus,
    /// The address included in the reply.
    // NOTE(review): what this address means (bound address, resolved answer,
    // ...) likely depends on the request's command — confirm with callers.
    addr: SocksAddr,
    /// The port included in the reply.
    port: u16,
}
390
391impl SocksReply {
392 #[cfg(feature = "client-handshake")]
394 pub(crate) fn new(status: SocksStatus, addr: SocksAddr, port: u16) -> Self {
395 Self { status, addr, port }
396 }
397
398 pub fn status(&self) -> SocksStatus {
400 self.status
401 }
402
403 pub fn addr(&self) -> &SocksAddr {
411 &self.addr
412 }
413
414 pub fn port(&self) -> u16 {
419 self.port
420 }
421}
422
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! Unit tests for SOCKS message types and their validation rules.
    use super::*;

    #[test]
    fn display_sa() {
        let a = SocksAddr::Ip(IpAddr::V4("127.0.0.1".parse().unwrap()));
        assert_eq!(a.to_string(), "127.0.0.1");

        let a = SocksAddr::Ip(IpAddr::V6("f00::9999".parse().unwrap()));
        assert_eq!(a.to_string(), "f00::9999");

        let a = SocksAddr::Hostname("www.torproject.org".to_string().try_into().unwrap());
        assert_eq!(a.to_string(), "www.torproject.org");
    }

    #[test]
    fn version_conversions() {
        assert_eq!(SocksVersion::try_from(4_u8).unwrap(), SocksVersion::V4);
        assert_eq!(SocksVersion::try_from(5_u8).unwrap(), SocksVersion::V5);
        assert!(matches!(
            SocksVersion::try_from(6_u8),
            Err(Error::BadProtocol(6))
        ));
        assert_eq!(SocksVersion::V4.to_string(), "socks4");
        assert_eq!(SocksVersion::V5.to_string(), "socks5");
    }

    #[test]
    fn hostname_validation() {
        // Exactly 255 bytes is the longest accepted hostname.
        let longest: SocksHostname = "a".repeat(255).try_into().unwrap();
        let text: &str = longest.as_ref();
        assert_eq!(text.len(), 255);

        let too_long: Result<SocksHostname> = "a".repeat(256).try_into();
        assert!(too_long.is_err());

        let with_nul: Result<SocksHostname> = "bad\0host".to_string().try_into();
        assert!(matches!(with_nul, Err(Error::Syntax)));
    }

    #[test]
    fn ok_request() {
        let localhost_v4 = SocksAddr::Ip(IpAddr::V4("127.0.0.1".parse().unwrap()));
        let r = SocksRequest::new(
            SocksVersion::V4,
            SocksCmd::CONNECT,
            localhost_v4.clone(),
            1024,
            SocksAuth::NoAuth,
        )
        .unwrap();
        assert_eq!(r.version(), SocksVersion::V4);
        assert_eq!(r.command(), SocksCmd::CONNECT);
        assert_eq!(r.addr(), &localhost_v4);
        assert_eq!(r.port(), 1024);
        assert_eq!(r.auth(), &SocksAuth::NoAuth);
    }

    #[test]
    fn resolve_allows_zero_port() {
        // RESOLVE is recognized but does not require a port.
        let host = SocksAddr::Hostname("www.torproject.org".to_string().try_into().unwrap());
        let r = SocksRequest::new(
            SocksVersion::V5,
            SocksCmd::RESOLVE,
            host.clone(),
            0,
            SocksAuth::NoAuth,
        )
        .unwrap();
        assert_eq!(r.command(), SocksCmd::RESOLVE);
        assert_eq!(r.port(), 0);
        assert_eq!(r.addr(), &host);
    }

    #[test]
    fn bad_request() {
        let localhost_v4 = SocksAddr::Ip(IpAddr::V4("127.0.0.1".parse().unwrap()));

        let e = SocksRequest::new(
            SocksVersion::V4,
            SocksCmd::BIND,
            localhost_v4.clone(),
            1024,
            SocksAuth::NoAuth,
        );
        assert!(matches!(e, Err(Error::NotImplemented(_))));

        let e = SocksRequest::new(
            SocksVersion::V5,
            SocksCmd::UDP_ASSOCIATE,
            localhost_v4.clone(),
            1024,
            SocksAuth::NoAuth,
        );
        assert!(matches!(e, Err(Error::NotImplemented(_))));

        let e = SocksRequest::new(
            SocksVersion::V4,
            SocksCmd::CONNECT,
            localhost_v4,
            0,
            SocksAuth::NoAuth,
        );
        assert!(matches!(e, Err(Error::Syntax)));
    }

    #[test]
    fn auth_validation() {
        let localhost_v4 = SocksAddr::Ip(IpAddr::V4("127.0.0.1".parse().unwrap()));
        let request = |version: SocksVersion, auth: SocksAuth| {
            SocksRequest::new(version, SocksCmd::CONNECT, localhost_v4.clone(), 80, auth)
        };

        // SOCKS4 auth: only with version 4, and no zero bytes.
        assert!(request(SocksVersion::V4, SocksAuth::Socks4(b"user".to_vec())).is_ok());
        assert!(matches!(
            request(SocksVersion::V4, SocksAuth::Socks4(b"us\0er".to_vec())),
            Err(Error::Syntax)
        ));
        assert!(matches!(
            request(SocksVersion::V5, SocksAuth::Socks4(b"user".to_vec())),
            Err(Error::Syntax)
        ));

        // Username/password auth: only with version 5, each field <= 255 bytes.
        assert!(request(
            SocksVersion::V5,
            SocksAuth::Username(b"u".to_vec(), b"p".to_vec())
        )
        .is_ok());
        assert!(request(
            SocksVersion::V5,
            SocksAuth::Username(vec![b'u'; 255], vec![b'p'; 255])
        )
        .is_ok());
        assert!(matches!(
            request(
                SocksVersion::V4,
                SocksAuth::Username(b"u".to_vec(), b"p".to_vec())
            ),
            Err(Error::Syntax)
        ));
        assert!(matches!(
            request(
                SocksVersion::V5,
                SocksAuth::Username(vec![b'u'; 256], b"p".to_vec())
            ),
            Err(Error::Syntax)
        ));
        assert!(matches!(
            request(
                SocksVersion::V5,
                SocksAuth::Username(b"u".to_vec(), vec![b'p'; 256])
            ),
            Err(Error::Syntax)
        ));
    }

    #[test]
    fn test_contains_zeros() {
        assert!(contains_zeros(b"Hello\0world"));
        assert!(!contains_zeros(b"Hello world"));
        assert!(contains_zeros(b"\0"));
        assert!(!contains_zeros(b""));
    }
}