1#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
8
9use std::{
10 convert::TryInto,
11 fmt, io,
12 net::{IpAddr, Ipv4Addr, SocketAddr},
13 str::FromStr,
14};
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16
17pub use error::{Error, Result};
18
19mod error;
20#[cfg(feature = "sync")]
21pub mod sync;
23
/// SOCKS protocol version.
///
/// Only SOCKS version 5 is supported.
#[derive(PartialEq, Eq, Debug)]
pub enum Version {
    /// SOCKS version 5, encoded on the wire as the single byte `0x05`.
    V5,
}
42impl Version {
43 pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<Version> {
45 let version = &mut [0u8];
46 reader.read_exact(version).await?;
47 match version[0] {
48 5 => Ok(Version::V5),
49 other => Err(Error::InvalidVersion(other)),
50 }
51 }
52 pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
54 let v = match self {
55 Version::V5 => 5u8,
56 };
57 writer.write_all(&[v]).await?;
58 Ok(())
59 }
60}
61
/// Authentication method identifier used in SOCKS5 method negotiation.
///
/// Converts losslessly to and from its wire byte via `From<u8>` and
/// `From<AuthMethod> for u8`.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum AuthMethod {
    /// No authentication required (`0x00`).
    Noauth,
    /// GSSAPI (`0x01`).
    Gssapi,
    /// Username/password (`0x02`).
    UsernamePassword,
    /// No acceptable method (`0xFF`).
    NoAcceptableMethod,
    /// Any other method byte, preserved verbatim.
    Other(u8),
}

impl From<u8> for AuthMethod {
    fn from(n: u8) -> Self {
        match n {
            0x00 => AuthMethod::Noauth,
            0x01 => AuthMethod::Gssapi,
            0x02 => AuthMethod::UsernamePassword,
            0xff => AuthMethod::NoAcceptableMethod,
            other => AuthMethod::Other(other),
        }
    }
}

// Implementing `From` rather than `Into` is the idiomatic direction: the
// standard library's blanket impl still provides `Into<u8> for AuthMethod`,
// so existing `.into()` / `Into::<u8>::into` call sites keep working.
impl From<AuthMethod> for u8 {
    fn from(method: AuthMethod) -> u8 {
        match method {
            AuthMethod::Noauth => 0x00,
            AuthMethod::Gssapi => 0x01,
            AuthMethod::UsernamePassword => 0x02,
            AuthMethod::NoAcceptableMethod => 0xff,
            AuthMethod::Other(other) => other,
        }
    }
}
100
101#[derive(Debug, Eq, PartialEq, Clone)]
111pub struct AuthRequest(pub Vec<AuthMethod>);
112
113impl AuthRequest {
114 pub fn new(methods: impl Into<Vec<AuthMethod>>) -> AuthRequest {
116 AuthRequest(methods.into())
117 }
118 pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<AuthRequest> {
120 let count = &mut [0u8];
121 reader.read_exact(count).await?;
122 let mut methods = vec![0u8; count[0] as usize];
123 reader.read_exact(&mut methods).await?;
124
125 Ok(AuthRequest(methods.into_iter().map(Into::into).collect()))
126 }
127 pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
129 let count = self.0.len();
130 if count > 255 {
131 return Err(Error::TooManyMethods);
132 }
133
134 writer.write_all(&[count as u8]).await?;
135 writer
136 .write_all(
137 &self
138 .0
139 .iter()
140 .map(|i| Into::<u8>::into(*i))
141 .collect::<Vec<_>>(),
142 )
143 .await?;
144
145 Ok(())
146 }
147 pub fn select_from(&self, auth: &[AuthMethod]) -> AuthMethod {
149 self.0
150 .iter()
151 .enumerate()
152 .find(|(_, m)| auth.contains(*m))
153 .map(|(v, _)| AuthMethod::from(v as u8))
154 .unwrap_or(AuthMethod::NoAcceptableMethod)
155 }
156}
157
158#[derive(Debug, Eq, PartialEq, Clone)]
168pub struct AuthResponse(AuthMethod);
169
170impl AuthResponse {
171 pub fn new(method: AuthMethod) -> AuthResponse {
173 AuthResponse(method)
174 }
175 pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<AuthResponse> {
177 let method = &mut [0u8];
178 reader.read_exact(method).await?;
179 Ok(AuthResponse(method[0].into()))
180 }
181 pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
183 writer.write_all(&[self.0.into()]).await?;
184 Ok(())
185 }
186 pub fn method(&self) -> AuthMethod {
188 self.0
189 }
190}
191
/// Command carried in a [`CommandRequest`].
#[derive(Debug)]
pub enum Command {
    /// CONNECT command (`0x01` on the wire).
    Connect,
    /// BIND command (`0x02` on the wire).
    Bind,
    /// UDP ASSOCIATE command (`0x03` on the wire).
    UdpAssociate,
}
204
205#[derive(Debug)]
215pub struct CommandRequest {
216 pub command: Command,
218 pub address: Address,
220}
221
222impl CommandRequest {
223 pub fn connect(address: Address) -> CommandRequest {
225 CommandRequest {
226 command: Command::Connect,
227 address,
228 }
229 }
230 pub fn udp_associate(address: Address) -> CommandRequest {
232 CommandRequest {
233 command: Command::UdpAssociate,
234 address,
235 }
236 }
237 pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<CommandRequest> {
239 let buf = &mut [0u8; 3];
240 reader.read_exact(buf).await?;
241 if buf[0] != 5 {
242 return Err(Error::InvalidVersion(buf[0]));
243 }
244 if buf[2] != 0 {
245 return Err(Error::InvalidHandshake);
246 }
247 let cmd = match buf[1] {
248 1 => Command::Connect,
249 2 => Command::Bind,
250 3 => Command::UdpAssociate,
251 _ => return Err(Error::InvalidCommand(buf[1])),
252 };
253
254 let address = Address::read(reader).await?;
255
256 Ok(CommandRequest {
257 command: cmd,
258 address,
259 })
260 }
261 pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
263 let cmd = match self.command {
264 Command::Connect => 1u8,
265 Command::Bind => 2,
266 Command::UdpAssociate => 3,
267 };
268 writer.write_all(&[0x05, cmd, 0x00]).await?;
269 self.address.write(writer).await?;
270 Ok(())
271 }
272}
273
/// Reply code carried in a [`CommandResponse`].
#[derive(PartialEq, PartialOrd, Debug)]
pub enum CommandReply {
    /// Succeeded (`0x00`).
    Succeeded,
    /// General SOCKS server failure (`0x01`).
    GeneralSocksServerFailure,
    /// Connection not allowed by ruleset (`0x02`).
    ConnectionNotAllowedByRuleset,
    /// Network unreachable (`0x03`).
    NetworkUnreachable,
    /// Host unreachable (`0x04`).
    HostUnreachable,
    /// Connection refused (`0x05`).
    ConnectionRefused,
    /// TTL expired (`0x06`).
    TtlExpired,
    /// Command not supported (`0x07`).
    CommandNotSupported,
    /// Address type not supported (`0x08`).
    AddressTypeNotSupported,
}
296
297impl CommandReply {
298 pub fn from_u8(n: u8) -> Result<CommandReply> {
300 Ok(match n {
301 0 => CommandReply::Succeeded,
302 1 => CommandReply::GeneralSocksServerFailure,
303 2 => CommandReply::ConnectionNotAllowedByRuleset,
304 3 => CommandReply::NetworkUnreachable,
305 4 => CommandReply::HostUnreachable,
306 5 => CommandReply::ConnectionRefused,
307 6 => CommandReply::TtlExpired,
308 7 => CommandReply::CommandNotSupported,
309 8 => CommandReply::AddressTypeNotSupported,
310 _ => return Err(Error::InvalidCommandReply(n)),
311 })
312 }
313 pub fn to_u8(&self) -> u8 {
315 match self {
316 CommandReply::Succeeded => 0,
317 CommandReply::GeneralSocksServerFailure => 1,
318 CommandReply::ConnectionNotAllowedByRuleset => 2,
319 CommandReply::NetworkUnreachable => 3,
320 CommandReply::HostUnreachable => 4,
321 CommandReply::ConnectionRefused => 5,
322 CommandReply::TtlExpired => 6,
323 CommandReply::CommandNotSupported => 7,
324 CommandReply::AddressTypeNotSupported => 8,
325 }
326 }
327}
328
329#[derive(Debug)]
339pub struct CommandResponse {
340 pub reply: CommandReply,
342 pub address: Address,
344}
345
346impl CommandResponse {
347 pub fn success(address: Address) -> CommandResponse {
349 CommandResponse {
350 reply: CommandReply::Succeeded,
351 address,
352 }
353 }
354 pub fn reply_error(reply: CommandReply) -> CommandResponse {
356 CommandResponse {
357 reply,
358 address: Default::default(),
359 }
360 }
361 pub fn error(e: impl TryInto<io::Error>) -> CommandResponse {
363 match e.try_into() {
364 Ok(v) => {
365 use io::ErrorKind;
366 let reply = match v.kind() {
367 ErrorKind::ConnectionRefused => CommandReply::ConnectionRefused,
368 _ => CommandReply::GeneralSocksServerFailure,
369 };
370 CommandResponse {
371 reply,
372 address: Default::default(),
373 }
374 }
375 Err(_) => CommandResponse {
376 reply: CommandReply::GeneralSocksServerFailure,
377 address: Default::default(),
378 },
379 }
380 }
381 pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<CommandResponse> {
383 let buf = &mut [0u8; 3];
384 reader.read_exact(buf).await?;
385 if buf[0] != 5 {
386 return Err(Error::InvalidVersion(buf[0]));
387 }
388 if buf[2] != 0 {
389 return Err(Error::InvalidHandshake);
390 }
391 let reply = CommandReply::from_u8(buf[1])?;
392
393 let address = Address::read(reader).await?;
394
395 if reply != CommandReply::Succeeded {
396 return Err(Error::CommandReply(reply));
397 }
398
399 Ok(CommandResponse { reply, address })
400 }
401 pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
403 writer.write_all(&[0x05, self.reply.to_u8(), 0x00]).await?;
404 self.address.write(writer).await?;
405 Ok(())
406 }
407}
408
/// Address field of a SOCKS5 request or reply.
#[derive(Debug)]
pub enum Address {
    /// An IPv4 or IPv6 socket address.
    SocketAddr(SocketAddr),
    /// A domain name with a port.
    Domain(String, u16),
}

impl fmt::Display for Address {
    /// Formats as `host:port`; IPv6 socket addresses use the `[ip]:port` form
    /// produced by `SocketAddr`'s own `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Address::Domain(domain, port) => write!(f, "{}:{}", domain, port),
            Address::SocketAddr(addr) => fmt::Display::fmt(addr, f),
        }
    }
}

impl Default for Address {
    /// The unspecified IPv4 address with port zero (`0.0.0.0:0`).
    fn default() -> Self {
        Address::from(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)))
    }
}

impl From<SocketAddr> for Address {
    fn from(addr: SocketAddr) -> Self {
        Self::SocketAddr(addr)
    }
}
438
/// Removes one surrounding pair of square brackets (as in `[::1]`).
///
/// The input is returned unchanged unless it both starts with `[` and ends
/// with `]`.
fn strip_brackets(host: &str) -> &str {
    let inner = host
        .strip_suffix(']')
        .and_then(|rest| rest.strip_prefix('['));
    inner.unwrap_or(host)
}
444
445fn host_to_address(host: &str, port: u16) -> Address {
446 match strip_brackets(host).parse::<IpAddr>() {
447 Ok(ip) => {
448 let addr = SocketAddr::new(ip, port);
449 addr.into()
450 }
451 Err(_) => Address::Domain(host.to_string(), port),
452 }
453}
/// Error used when a string is not a valid `host:port` address.
fn no_addr() -> io::Error {
    io::Error::from(io::ErrorKind::AddrNotAvailable)
}
457
458impl FromStr for Address {
459 type Err = Error;
460
461 fn from_str(s: &str) -> Result<Self, Self::Err> {
462 let mut parts = s.rsplitn(2, ':');
463 let port: u16 = parts
464 .next()
465 .ok_or_else(no_addr)?
466 .parse()
467 .map_err(|_| no_addr())?;
468 let host = parts.next().ok_or_else(no_addr)?;
469 Ok(host_to_address(host, port))
470 }
471}
472
473impl Address {
474 pub fn to_socket_addr(self) -> Result<SocketAddr> {
476 match self {
477 Address::SocketAddr(s) => Ok(s),
478 _ => Err(Error::Io(io::ErrorKind::InvalidInput.into())),
479 }
480 }
481 async fn read_port<R>(mut reader: R) -> Result<u16>
482 where
483 R: AsyncRead + Unpin,
484 {
485 let mut buf = [0u8; 2];
486 reader.read_exact(&mut buf).await?;
487 let port = u16::from_be_bytes(buf);
488 Ok(port)
489 }
490 async fn write_port<W>(mut writer: W, port: u16) -> Result<()>
491 where
492 W: AsyncWrite + Unpin,
493 {
494 writer.write_all(&port.to_be_bytes()).await?;
495 Ok(())
496 }
497 pub fn serialized_len(&self) -> Result<usize> {
499 Ok(match self {
500 Address::SocketAddr(SocketAddr::V4(_)) => {
501 1 + 4 + 2
503 }
504 Address::SocketAddr(SocketAddr::V6(_)) => {
505 1 + 16 + 2
507 }
508 Address::Domain(domain, _) => {
509 if domain.len() >= 256 {
510 return Err(Error::DomainTooLong(domain.len()));
511 }
512 1 + 1 + domain.len() + 2
514 }
515 })
516 }
517 pub async fn write<W>(&self, mut writer: W) -> Result<()>
519 where
520 W: AsyncWrite + Unpin,
521 {
522 match self {
523 Address::SocketAddr(SocketAddr::V4(addr)) => {
524 writer.write_all(&[0x01]).await?;
525 writer.write_all(&addr.ip().octets()).await?;
526 Self::write_port(writer, addr.port()).await?;
527 }
528 Address::SocketAddr(SocketAddr::V6(addr)) => {
529 writer.write_all(&[0x04]).await?;
530 writer.write_all(&addr.ip().octets()).await?;
531 Self::write_port(writer, addr.port()).await?;
532 }
533 Address::Domain(domain, port) => {
534 if domain.len() >= 256 {
535 return Err(Error::DomainTooLong(domain.len()));
536 }
537 let header = [0x03, domain.len() as u8];
538 writer.write_all(&header).await?;
539 writer.write_all(domain.as_bytes()).await?;
540 Self::write_port(writer, *port).await?;
541 }
542 };
543 Ok(())
544 }
545 pub async fn read<R>(mut reader: R) -> Result<Self>
547 where
548 R: AsyncRead + Unpin,
549 {
550 let mut atyp = [0u8; 1];
551 reader.read_exact(&mut atyp).await?;
552
553 Ok(match atyp[0] {
554 1 => {
555 let mut ip = [0u8; 4];
556 reader.read_exact(&mut ip).await?;
557 Address::SocketAddr(SocketAddr::new(
558 ip.into(),
559 Self::read_port(&mut reader).await?,
560 ))
561 }
562 3 => {
563 let mut len = [0u8; 1];
564 reader.read_exact(&mut len).await?;
565 let len = len[0] as usize;
566 let mut domain = vec![0u8; len];
567 reader.read_exact(&mut domain).await?;
568
569 let domain =
570 String::from_utf8(domain).map_err(|e| Error::InvalidDomain(e.into_bytes()))?;
571
572 Address::Domain(domain, Self::read_port(&mut reader).await?)
573 }
574 4 => {
575 let mut ip = [0u8; 16];
576 reader.read_exact(&mut ip).await?;
577 Address::SocketAddr(SocketAddr::new(
578 ip.into(),
579 Self::read_port(&mut reader).await?,
580 ))
581 }
582 _ => return Err(Error::InvalidAddressType(atyp[0])),
583 })
584 }
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_address_display() {
        let addr = Address::SocketAddr("1.2.3.4:56789".parse().unwrap());
        assert_eq!(addr.to_string(), "1.2.3.4:56789");

        let addr = Address::Domain("example.com".to_string(), 80);
        assert_eq!(addr.to_string(), "example.com:80");
    }

    #[test]
    fn test_address_from_str() {
        let addr: Address = "1.2.3.4:56789".parse().unwrap();
        assert_eq!(addr.to_string(), "1.2.3.4:56789");

        let addr: Address = "example.com:80".parse().unwrap();
        assert_eq!(addr.to_string(), "example.com:80");

        let addr: Result<Address, _> = "example.com".parse();
        assert!(addr.is_err());

        // Bracketed IPv6 literals become socket addresses.
        let addr: Address = "[::1]:443".parse().unwrap();
        assert_eq!(addr.to_string(), "[::1]:443");

        // A non-numeric port is rejected.
        let addr: Result<Address, _> = "example.com:notaport".parse();
        assert!(addr.is_err());
    }

    #[test]
    fn test_address_serialized_len() {
        let addr: Address = "1.2.3.4:56789".parse().unwrap();
        assert_eq!(addr.serialized_len().unwrap(), 7);

        let addr: Address = "[::1]:56789".parse().unwrap();
        assert_eq!(addr.serialized_len().unwrap(), 19);

        let addr: Address = "example.com:80".parse().unwrap();
        assert_eq!(addr.serialized_len().unwrap(), 15);

        // 255 bytes is the longest encodable domain; 256 must be rejected.
        let addr = Address::Domain("a".repeat(255), 80);
        assert_eq!(addr.serialized_len().unwrap(), 259);
        let addr = Address::Domain("a".repeat(256), 80);
        assert!(addr.serialized_len().is_err());
    }

    #[test]
    fn test_strip_brackets() {
        assert_eq!(strip_brackets("[::1]"), "::1");
        assert_eq!(strip_brackets("[::1"), "[::1");
        assert_eq!(strip_brackets("::1]"), "::1]");
        assert_eq!(strip_brackets("example.com"), "example.com");
        assert_eq!(strip_brackets("[]"), "");
    }

    #[test]
    fn test_auth_method_byte_round_trip() {
        for byte in 0u8..=255 {
            let method = AuthMethod::from(byte);
            assert_eq!(Into::<u8>::into(method), byte);
        }
        assert_eq!(AuthMethod::from(0x02), AuthMethod::UsernamePassword);
        assert_eq!(AuthMethod::from(0xff), AuthMethod::NoAcceptableMethod);
    }

    #[test]
    fn test_command_reply_round_trip() {
        for code in 0u8..=8 {
            assert_eq!(
                CommandReply::from_u8(code).map(|reply| reply.to_u8()).ok(),
                Some(code)
            );
        }
        assert!(CommandReply::from_u8(9).is_err());
    }
}