1use std::{
2 fmt::{self, Display, Formatter},
3 io::{Error, ErrorKind, Result},
4};
5
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7
8use crate::{Address, AuthMethod};
9
/// Version octet (VER) written and checked in the greeting, method-selection,
/// request and reply messages.
const SOCKS5_VER: u8 = 5;
/// Version octet of the username/password sub-negotiation (RFC 1929).
const SOCKS5_AUTH_VER: u8 = 1;
12
/// Command code (CMD octet) carried in a SOCKS5 request.
///
/// The discriminant of each variant is its on-the-wire value.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum Socks5Command {
    /// `CONNECT` (0x01): open a TCP connection to the destination.
    Connect = 0x01,
    /// `BIND` (0x02): accept an inbound TCP connection on the proxy.
    Bind = 0x02,
    /// `UDP ASSOCIATE` (0x03): set up a UDP relay.
    UdpAssociate = 0x03,
}
30
31impl TryFrom<u8> for Socks5Command {
32 type Error = Error;
33
34 fn try_from(value: u8) -> Result<Self> {
35 match value {
36 0x01 => Ok(Socks5Command::Connect),
37 0x02 => Ok(Socks5Command::Bind),
38 0x03 => Ok(Socks5Command::UdpAssociate),
39 _ => Err(Socks5Error::InvalidCommand.into()),
40 }
41 }
42}
43
/// Reply code (REP octet) sent by the server in answer to a request.
///
/// The discriminant of each variant is its on-the-wire value.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum Socks5Reply {
    /// 0x00: request succeeded.
    Succeeded = 0x00,
    /// 0x01: general SOCKS server failure.
    GeneralFailure = 0x01,
    /// 0x02: connection not allowed by ruleset.
    ConnectionNotAllowed = 0x02,
    /// 0x03: network unreachable.
    NetworkUnreachable = 0x03,
    /// 0x04: host unreachable.
    HostUnreachable = 0x04,
    /// 0x05: connection refused.
    ConnectionRefused = 0x05,
    /// 0x06: TTL expired.
    TTLExpired = 0x06,
    /// 0x07: command not supported.
    CommandNotSupported = 0x07,
    /// 0x08: address type not supported.
    AddressTypeNotSupported = 0x08,
}
72
73impl TryFrom<u8> for Socks5Reply {
74 type Error = Error;
75
76 fn try_from(value: u8) -> Result<Self> {
77 match value {
78 0x00 => Ok(Socks5Reply::Succeeded),
79 0x01 => Ok(Socks5Reply::GeneralFailure),
80 0x02 => Ok(Socks5Reply::ConnectionNotAllowed),
81 0x03 => Ok(Socks5Reply::NetworkUnreachable),
82 0x04 => Ok(Socks5Reply::HostUnreachable),
83 0x05 => Ok(Socks5Reply::ConnectionRefused),
84 0x06 => Ok(Socks5Reply::TTLExpired),
85 0x07 => Ok(Socks5Reply::CommandNotSupported),
86 0x08 => Ok(Socks5Reply::AddressTypeNotSupported),
87 _ => Err(Socks5Error::InvalidReply.into()),
88 }
89 }
90}
91
92pub async fn socks5_accept<T>(
105 stream: &mut T,
106 auth_method: &AuthMethod,
107) -> Result<(Socks5Command, Address)>
108where
109 T: AsyncRead + AsyncWrite + Unpin,
110{
111 let client_auth_method = read_client_hello(stream).await?;
113
114 if !client_auth_method.contains(&auth_method.into()) {
115 write_server_hello(stream, &Socks5AuthOption::NoAcceptable).await?;
117 return Err(Socks5Error::NoAcceptableAuthMethod.into());
118 }
119
120 write_server_hello(stream, &auth_method.into()).await?;
122
123 match auth_method {
125 AuthMethod::NoAuth => (), AuthMethod::UserPass { username, password } => {
127 let auth = read_auth_request(stream).await?;
128 if &auth.username != username || &auth.password != password {
129 write_auth_response(stream, false).await?;
130 return Err(Socks5Error::AuthenticationFailed.into());
131 } else {
132 write_auth_response(stream, true).await?;
133 }
134 }
135 }
136
137 let (command, address) = read_connection_request(stream).await?;
139 Ok((command, address))
140}
141
142pub async fn socks5_finalize_accept<T>(
155 stream: &mut T,
156 reply: &Socks5Reply,
157 address: &Address,
158) -> Result<()>
159where
160 T: AsyncWrite + Unpin,
161{
162 write_connection_response(stream, reply, address).await?;
164
165 Ok(())
166}
167
168pub async fn socks5_connect<T>(
183 stream: &mut T,
184 command: &Socks5Command,
185 address: &Address,
186 auth: &[AuthMethod],
187) -> Result<Address>
188where
189 T: AsyncRead + AsyncWrite + Unpin,
190{
191 let client_auth_methods = auth.iter().map(|a| a.into()).collect::<Vec<_>>();
192 if client_auth_methods.len() > 255 {
193 return Err(Socks5Error::TooManyAuthMethods.into());
194 }
195
196 write_client_hello(stream, &client_auth_methods).await?;
198
199 let server_auth_method = read_server_hello(stream).await?;
201
202 let auth_method = match client_auth_methods
203 .iter()
204 .position(|c| c == &server_auth_method)
205 {
206 Some(i) => auth[i].clone(),
207 None => {
208 return Err(Socks5Error::NoAcceptableAuthMethod.into());
209 }
210 };
211
212 match auth_method {
214 AuthMethod::NoAuth => (), AuthMethod::UserPass { username, password } => {
216 write_auth_request(stream, &UserPassAuth { username, password }).await?;
217 read_auth_response(stream).await?;
218 }
219 }
220
221 write_connection_request(stream, command, address).await?;
223
224 let (reply, address) = read_connection_response(stream).await?;
226
227 if reply != Socks5Reply::Succeeded {
229 return Err(Socks5Error::ConnectionFailed.into());
230 }
231
232 Ok(address)
233}
234
235pub fn socks5_read_udp_header(buf: &[u8]) -> Result<(Address, usize)> {
257 let first = buf
258 .first_chunk::<3>()
259 .ok_or(Error::new(ErrorKind::UnexpectedEof, "buffer too short"))?;
260 if first != &[0, 0, 0] {
261 return Err(Error::new(ErrorKind::InvalidData, "invalid UDP header"));
262 }
263 let (address, len) = Address::decode_from_buf(&buf[3..])?;
264 Ok((address, 2 + 1 + len))
265}
266
267pub fn socks5_write_udp_header(address: &Address, buf: &mut [u8]) -> Result<usize> {
278 let first = buf
279 .first_chunk_mut::<3>()
280 .ok_or(Error::new(ErrorKind::UnexpectedEof, "buffer too short"))?;
281 *first = [0, 0, 0];
282 let len = address.encode_to_buf(&mut buf[3..])?;
283 Ok(2 + 1 + len)
284}
285
/// Authentication method code (METHOD octet) used in the greeting and method selection.
///
/// The discriminant of each variant is its on-the-wire value.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Socks5AuthOption {
    /// 0x00: no authentication required.
    NoAuth = 0x00,
    /// 0x01: GSSAPI (recognised when parsing; no sub-negotiation is implemented in this file).
    GssApi = 0x01,
    /// 0x02: username/password (RFC 1929).
    UserPass = 0x02,
    /// 0xFF: server reply meaning none of the offered methods is acceptable.
    NoAcceptable = 0xFF,
}
293
294impl TryFrom<u8> for Socks5AuthOption {
295 type Error = Error;
296
297 fn try_from(value: u8) -> Result<Self> {
298 match value {
299 0x00 => Ok(Socks5AuthOption::NoAuth),
300 0x01 => Ok(Socks5AuthOption::GssApi),
301 0x02 => Ok(Socks5AuthOption::UserPass),
302 0xFF => Ok(Socks5AuthOption::NoAcceptable),
303 _ => Err(Socks5Error::InvalidAuthMethod.into()),
304 }
305 }
306}
307
308impl From<&AuthMethod> for Socks5AuthOption {
309 fn from(value: &AuthMethod) -> Self {
310 match value {
311 AuthMethod::NoAuth => Socks5AuthOption::NoAuth,
312 AuthMethod::UserPass { .. } => Socks5AuthOption::UserPass,
313 }
314 }
315}
316
/// Username/password pair exchanged during RFC 1929 sub-negotiation.
struct UserPassAuth {
    username: String,
    password: String,
}

// Hand-written instead of derived so that `{:?}` (logs, panics, `dbg!`) never prints the
// plaintext password.
impl fmt::Debug for UserPassAuth {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("UserPassAuth")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
322
323async fn read_client_hello<T>(reader: &mut T) -> Result<Vec<Socks5AuthOption>>
335where
336 T: AsyncRead + Unpin,
337{
338 let ver = reader.read_u8().await?;
340 if ver != SOCKS5_VER {
341 return Err(Socks5Error::InvalidSocksVersion.into());
342 }
343
344 let nmethods = reader.read_u8().await?;
346 if nmethods == 0 {
347 return Err(Socks5Error::NoAuthMethods.into());
348 }
349
350 let mut methods = Vec::with_capacity(nmethods as usize);
352 for _ in 0..nmethods {
353 let method_byte = reader.read_u8().await?;
354 match Socks5AuthOption::try_from(method_byte) {
355 Ok(method) => methods.push(method),
356 Err(_) => continue, }
358 }
359
360 if methods.is_empty() {
361 return Err(Socks5Error::NoSupportedAuthMethods.into());
362 }
363
364 Ok(methods)
365}
366
367async fn write_client_hello<T>(writer: &mut T, auth_method: &[Socks5AuthOption]) -> Result<()>
368where
369 T: AsyncWrite + Unpin,
370{
371 if auth_method.is_empty() {
372 return Err(Socks5Error::NoAuthMethods.into());
373 }
374
375 writer.write_u8(SOCKS5_VER).await?;
377
378 writer.write_u8(auth_method.len() as u8).await?;
380
381 for method in auth_method {
383 writer.write_u8(*method as u8).await?;
384 }
385
386 writer.flush().await?;
387 Ok(())
388}
389
390async fn read_server_hello<T>(reader: &mut T) -> Result<Socks5AuthOption>
401where
402 T: AsyncRead + Unpin,
403{
404 let ver = reader.read_u8().await?;
406 if ver != SOCKS5_VER {
407 return Err(Socks5Error::InvalidSocksVersion.into());
408 }
409
410 let method_byte = reader.read_u8().await?;
412 Socks5AuthOption::try_from(method_byte)
413}
414
415async fn write_server_hello<T>(writer: &mut T, auth_method: &Socks5AuthOption) -> Result<()>
416where
417 T: AsyncWrite + Unpin,
418{
419 writer.write_u8(SOCKS5_VER).await?;
421
422 writer.write_u8(*auth_method as u8).await?;
424
425 writer.flush().await?;
426 Ok(())
427}
428
429async fn read_auth_request<T>(reader: &mut T) -> Result<UserPassAuth>
443where
444 T: AsyncRead + Unpin,
445{
446 let ver = reader.read_u8().await?;
448 if ver != SOCKS5_AUTH_VER {
449 return Err(Socks5Error::InvalidAuthVersion.into());
450 }
451
452 let ulen = reader.read_u8().await? as usize;
454 let mut uname = vec![0u8; ulen];
455 reader.read_exact(&mut uname).await?;
456 let username = String::from_utf8(uname).map_err(|_| Socks5Error::InvalidUsernameEncoding)?;
457
458 let plen = reader.read_u8().await? as usize;
460 let mut passwd = vec![0u8; plen];
461 reader.read_exact(&mut passwd).await?;
462 let password = String::from_utf8(passwd).map_err(|_| Socks5Error::InvalidPasswordEncoding)?;
463
464 Ok(UserPassAuth { username, password })
465}
466
467async fn write_auth_request<T>(writer: &mut T, auth: &UserPassAuth) -> Result<()>
468where
469 T: AsyncWrite + Unpin,
470{
471 writer.write_u8(SOCKS5_AUTH_VER).await?;
473
474 let username_bytes = auth.username.as_bytes();
476 if username_bytes.len() > 255 {
477 return Err(Socks5Error::UsernameTooLong.into());
478 }
479 writer.write_u8(username_bytes.len() as u8).await?;
480 writer.write_all(username_bytes).await?;
481
482 let password_bytes = auth.password.as_bytes();
484 if password_bytes.len() > 255 {
485 return Err(Socks5Error::PasswordTooLong.into());
486 }
487 writer.write_u8(password_bytes.len() as u8).await?;
488 writer.write_all(password_bytes).await?;
489
490 writer.flush().await?;
491 Ok(())
492}
493
494async fn read_auth_response<T>(reader: &mut T) -> Result<()>
505where
506 T: AsyncRead + Unpin,
507{
508 let ver = reader.read_u8().await?;
510 if ver != SOCKS5_AUTH_VER {
511 return Err(Socks5Error::InvalidAuthVersion.into());
512 }
513
514 let status = reader.read_u8().await?;
516 if status != 0 {
517 return Err(Socks5Error::AuthenticationFailed.into());
518 }
519
520 Ok(())
521}
522
523async fn write_auth_response<T>(writer: &mut T, is_ok: bool) -> Result<()>
524where
525 T: AsyncWrite + Unpin,
526{
527 writer.write_u8(SOCKS5_AUTH_VER).await?;
529
530 writer.write_u8(if is_ok { 0 } else { 1 }).await?;
532
533 writer.flush().await?;
534 Ok(())
535}
536
537async fn read_connection_request<T>(reader: &mut T) -> Result<(Socks5Command, Address)>
552where
553 T: AsyncRead + Unpin,
554{
555 let ver = reader.read_u8().await?;
557 if ver != SOCKS5_VER {
558 return Err(Socks5Error::InvalidSocksVersion.into());
559 }
560
561 let cmd = Socks5Command::try_from(reader.read_u8().await?)?;
563
564 let rsv = reader.read_u8().await?;
566 if rsv != 0 {
567 return Err(Socks5Error::InvalidRsvValue.into());
568 }
569
570 let (address, _) = Address::decode_from_reader(reader).await?;
572
573 Ok((cmd, address))
574}
575
576async fn write_connection_request<T>(
577 writer: &mut T,
578 command: &Socks5Command,
579 address: &Address,
580) -> Result<()>
581where
582 T: AsyncWrite + Unpin,
583{
584 writer.write_u8(SOCKS5_VER).await?;
586
587 writer.write_u8(*command as u8).await?;
589
590 writer.write_u8(0).await?;
592
593 address.encode_to_writer(writer).await?;
595
596 writer.flush().await?;
597 Ok(())
598}
599
600async fn read_connection_response<T>(reader: &mut T) -> Result<(Socks5Reply, Address)>
615where
616 T: AsyncRead + Unpin,
617{
618 let ver = reader.read_u8().await?;
620 if ver != SOCKS5_VER {
621 return Err(Socks5Error::InvalidSocksVersion.into());
622 }
623
624 let reply = Socks5Reply::try_from(reader.read_u8().await?)?;
626
627 let rsv = reader.read_u8().await?;
629 if rsv != 0 {
630 return Err(Socks5Error::InvalidRsvValue.into());
631 }
632
633 let (address, _) = Address::decode_from_reader(reader).await?;
635
636 Ok((reply, address))
637}
638
639async fn write_connection_response<T>(
640 writer: &mut T,
641 reply: &Socks5Reply,
642 address: &Address,
643) -> Result<()>
644where
645 T: AsyncWrite + Unpin,
646{
647 writer.write_u8(SOCKS5_VER).await?;
649
650 writer.write_u8(*reply as u8).await?;
652
653 writer.write_u8(0).await?;
655
656 address.encode_to_writer(writer).await?;
658
659 writer.flush().await?;
660 Ok(())
661}
662
/// Protocol-level failures raised by the SOCKS5 handshake helpers.
///
/// Values are converted into `std::io::Error` (see the `From` impl below) with an
/// `ErrorKind` chosen per variant. Callers can recover the original variant via
/// `io::Error::downcast`.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Socks5Error {
    /// None of the client's offered methods matches the server's configured method.
    NoAcceptableAuthMethod,
    /// Username/password sub-negotiation failed.
    AuthenticationFailed,
    /// The server answered a request with a reply code other than `Succeeded`.
    ConnectionFailed,
    /// A VER octet other than 5 was received.
    InvalidSocksVersion,
    /// A sub-negotiation VER octet other than 1 was received.
    InvalidAuthVersion,
    /// A greeting listed zero methods, or an empty method list was supplied for writing.
    NoAuthMethods,
    /// Every method code in a received greeting was unrecognised.
    NoSupportedAuthMethods,
    /// An unrecognised METHOD code was received.
    InvalidAuthMethod,
    /// An unrecognised CMD code was received.
    InvalidCommand,
    /// An unrecognised REP code was received.
    InvalidReply,
    /// A non-zero RSV octet was received.
    InvalidRsvValue,
    /// A received username was not valid UTF-8.
    InvalidUsernameEncoding,
    /// A received password was not valid UTF-8.
    InvalidPasswordEncoding,
    /// A username longer than 255 bytes was supplied for writing.
    UsernameTooLong,
    /// A password longer than 255 bytes was supplied for writing.
    PasswordTooLong,
    /// More than 255 authentication methods were supplied.
    TooManyAuthMethods,
}

impl Display for Socks5Error {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::NoAcceptableAuthMethod => "No acceptable authentication method",
            Self::AuthenticationFailed => "Authentication failed",
            Self::ConnectionFailed => "Connection failed",
            Self::InvalidSocksVersion => "Invalid SOCKS version",
            Self::InvalidAuthVersion => "Invalid auth version",
            Self::NoAuthMethods => "No authentication methods provided",
            Self::NoSupportedAuthMethods => "No supported authentication methods",
            Self::InvalidAuthMethod => "Invalid AuthMethod",
            Self::InvalidCommand => "Invalid Command",
            Self::InvalidReply => "Invalid Reply",
            Self::InvalidRsvValue => "Invalid RSV value",
            Self::InvalidUsernameEncoding => "Invalid username encoding",
            Self::InvalidPasswordEncoding => "Invalid password encoding",
            Self::UsernameTooLong => "Username too long",
            Self::PasswordTooLong => "Password too long",
            Self::TooManyAuthMethods => "Too many authentication methods",
        };
        f.write_str(message)
    }
}

impl std::error::Error for Socks5Error {}

impl From<Socks5Error> for Error {
    /// Wraps the error in an `io::Error` whose kind is chosen as follows:
    ///
    /// - `PermissionDenied` for negotiation and authentication failures.
    /// - `ConnectionRefused` for a failed request.
    /// - `InvalidInput` for bad caller-supplied values.
    /// - `InvalidData` for malformed peer data.
    fn from(e: Socks5Error) -> Self {
        let kind = match &e {
            Socks5Error::NoAcceptableAuthMethod | Socks5Error::AuthenticationFailed => {
                ErrorKind::PermissionDenied
            }
            Socks5Error::ConnectionFailed => ErrorKind::ConnectionRefused,
            Socks5Error::NoAuthMethods
            | Socks5Error::UsernameTooLong
            | Socks5Error::PasswordTooLong
            | Socks5Error::TooManyAuthMethods => ErrorKind::InvalidInput,
            Socks5Error::InvalidSocksVersion
            | Socks5Error::InvalidAuthVersion
            | Socks5Error::NoSupportedAuthMethods
            | Socks5Error::InvalidAuthMethod
            | Socks5Error::InvalidCommand
            | Socks5Error::InvalidReply
            | Socks5Error::InvalidRsvValue
            | Socks5Error::InvalidUsernameEncoding
            | Socks5Error::InvalidPasswordEncoding => ErrorKind::InvalidData,
        };
        Error::new(kind, e)
    }
}
752
// Unit tests for the private wire-format helpers.
//
// Most tests use `create_mock_stream()` from `crate::test_utils`. From its use here it
// appears to return two connected in-memory ends: bytes written to one, or injected via
// `write_immediate`, are readable from the other. NOTE(review): confirm against test_utils.
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, Ipv6Addr};

    use super::*;
    use crate::test_utils::create_mock_stream;

    // Round-trip: each method list the client writes is read back unchanged, in order.
    #[tokio::test]
    async fn test_client_hello_write_read() {
        let all_methods = [
            vec![Socks5AuthOption::NoAuth],
            vec![
                Socks5AuthOption::NoAuth,
                Socks5AuthOption::UserPass,
                Socks5AuthOption::GssApi,
            ],
        ];
        for methods in all_methods {
            let (mut stream1, mut stream2) = create_mock_stream();
            write_client_hello(&mut stream1, &methods).await.unwrap();
            let recevied_methods = read_client_hello(&mut stream2).await.unwrap();
            assert_eq!(methods.as_slice(), recevied_methods.as_slice());
        }
    }

    // Round-trip of the server's method-selection message.
    #[tokio::test]
    async fn test_server_hello_write_read() {
        let (mut stream1, mut stream2) = create_mock_stream();
        write_server_hello(&mut stream1, &Socks5AuthOption::NoAuth)
            .await
            .unwrap();
        let method = read_server_hello(&mut stream2).await.unwrap();
        assert_eq!(Socks5AuthOption::NoAuth, method);
    }

    // Round-trip of an RFC 1929 username/password request.
    #[tokio::test]
    async fn test_auth_request_write_read() {
        let (mut stream1, mut stream2) = create_mock_stream();
        let auth = UserPassAuth {
            username: "test_user".to_string(),
            password: "test_pass".to_string(),
        };
        write_auth_request(&mut stream1, &auth).await.unwrap();
        let received_auth = read_auth_request(&mut stream2).await.unwrap();
        assert_eq!(auth.username, received_auth.username);
        assert_eq!(auth.password, received_auth.password);
    }

    // Status replies: success reads back as Ok, failure as AuthenticationFailed.
    #[tokio::test]
    async fn test_auth_response_write_read() {
        // Success status.
        let (mut stream1, mut stream2) = create_mock_stream();
        write_auth_response(&mut stream1, true).await.unwrap();
        read_auth_response(&mut stream2).await.unwrap();

        // Failure status.
        let (mut stream1, mut stream2) = create_mock_stream();
        write_auth_response(&mut stream1, false).await.unwrap();
        let err = read_auth_response(&mut stream2).await.unwrap_err();
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::AuthenticationFailed
        );
    }

    // Round-trip of every command against IPv4, domain-name and IPv6 destinations.
    #[tokio::test]
    async fn test_connection_request_write_read() {
        let all_commands = [
            Socks5Command::Connect,
            Socks5Command::Bind,
            Socks5Command::UdpAssociate,
        ];
        let all_addresses = [
            Address::IPv4((Ipv4Addr::new(192, 168, 1, 1), 8080)),
            Address::DomainName(("example.com".to_string(), 443)),
            Address::IPv6((
                Ipv6Addr::new(0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x01),
                8080,
            )),
        ];
        for command in all_commands {
            for address in all_addresses.iter() {
                let (mut stream1, mut stream2) = create_mock_stream();
                write_connection_request(&mut stream1, &command, address)
                    .await
                    .unwrap();
                let (received_command, received_address) =
                    read_connection_request(&mut stream2).await.unwrap();
                assert_eq!(command, received_command);
                assert_eq!(address, &received_address);
            }
        }
    }

    // Round-trip of every reply code against the same three address kinds.
    #[tokio::test]
    async fn test_connection_response_write_read() {
        let all_replies = [
            Socks5Reply::Succeeded,
            Socks5Reply::GeneralFailure,
            Socks5Reply::ConnectionNotAllowed,
            Socks5Reply::NetworkUnreachable,
            Socks5Reply::HostUnreachable,
            Socks5Reply::ConnectionRefused,
            Socks5Reply::TTLExpired,
            Socks5Reply::CommandNotSupported,
            Socks5Reply::AddressTypeNotSupported,
        ];
        let all_addresses = [
            Address::IPv4((Ipv4Addr::new(192, 168, 1, 1), 8080)),
            Address::DomainName(("example.com".to_string(), 443)),
            Address::IPv6((
                Ipv6Addr::new(0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x01),
                8080,
            )),
        ];
        for reply in all_replies {
            for address in all_addresses.iter() {
                let (mut stream1, mut stream2) = create_mock_stream();
                write_connection_response(&mut stream1, &reply, address)
                    .await
                    .unwrap();
                let (received_reply, received_address) =
                    read_connection_response(&mut stream2).await.unwrap();
                assert_eq!(reply, received_reply);
                assert_eq!(address, &received_address);
            }
        }
    }

    // A greeting with VER = 4 is rejected as InvalidData / InvalidSocksVersion.
    #[tokio::test]
    async fn test_read_client_hello_invalid_version() {
        let (mut client, server) = create_mock_stream();

        // VER = 0x04, NMETHODS = 1, METHOD = NoAuth.
        server.write_immediate(&[0x04, 0x01, 0x00]).unwrap();

        let result = read_client_hello(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::InvalidSocksVersion
        );
    }

    // NMETHODS = 0 is rejected as NoAuthMethods (mapped to InvalidInput).
    #[tokio::test]
    async fn test_read_client_hello_no_auth_method() {
        let (mut client, server) = create_mock_stream();

        // VER = 0x05, NMETHODS = 0.
        server.write_immediate(&[0x05, 0x00]).unwrap();

        let result = read_client_hello(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::NoAuthMethods
        );
    }

    // A greeting whose only method code (0x80) is unknown yields NoSupportedAuthMethods.
    #[tokio::test]
    async fn test_read_client_hello_unsupported_auth_methods() {
        let (mut client, server) = create_mock_stream();

        // VER = 0x05, NMETHODS = 1, METHOD = 0x80 (not recognised by Socks5AuthOption).
        server.write_immediate(&[0x05, 0x01, 0x80]).unwrap();

        let result = read_client_hello(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::NoSupportedAuthMethods
        );
    }

    // Writing a greeting with an empty method list is refused.
    #[tokio::test]
    async fn test_write_client_hello_no_auth_method() {
        let (mut client, _server) = create_mock_stream();

        let result = write_client_hello(&mut client, &[]).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::NoAuthMethods
        );
    }

    // A method-selection message with VER = 4 is rejected.
    #[tokio::test]
    async fn test_read_server_hello_invalid_version() {
        let (mut client, server) = create_mock_stream();

        // VER = 0x04, METHOD = NoAuth.
        server.write_immediate(&[0x04, 0x00]).unwrap();

        let result = read_server_hello(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::InvalidSocksVersion
        );
    }

    // A sub-negotiation request with VER = 2 is rejected as InvalidAuthVersion.
    #[tokio::test]
    async fn test_read_auth_request_invalid_version() {
        let (mut client, server) = create_mock_stream();

        // VER = 0x02 (should be 0x01), then ULEN/"user", PLEN/"pass".
        server
            .write_immediate(&[
                0x02, 0x04, b'u', b's', b'e', b'r', 0x04, b'p', b'a', b's', b's',
            ])
            .unwrap();

        let result = read_auth_request(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::InvalidAuthVersion
        );
    }

    // A username that is not valid UTF-8 is rejected.
    #[tokio::test]
    async fn test_read_auth_request_invalid_username_encoding() {
        let (mut client, server) = create_mock_stream();

        // VER = 0x01, ULEN = 4 with 0xFF bytes (invalid UTF-8), then PLEN/"pass".
        server
            .write_immediate(&[
                0x01, 0x04, 0xFF, 0xFF, 0xFF, 0xFF, 0x04, b'p', b'a', b's', b's',
            ])
            .unwrap();

        let result = read_auth_request(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::InvalidUsernameEncoding
        );
    }

    // A password that is not valid UTF-8 is rejected.
    #[tokio::test]
    async fn test_read_auth_request_invalid_password_encoding() {
        let (mut client, server) = create_mock_stream();

        // VER = 0x01, ULEN/"user", then PLEN = 4 with 0xFF bytes (invalid UTF-8).
        server
            .write_immediate(&[
                0x01, 0x04, b'u', b's', b'e', b'r', 0x04, 0xFF, 0xFF, 0xFF,
                0xFF,
            ])
            .unwrap();

        let result = read_auth_request(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::InvalidPasswordEncoding
        );
    }

    // A 256-byte username does not fit the one-octet ULEN field.
    #[tokio::test]
    async fn test_write_auth_request_username_too_long() {
        let (mut client, _server) = create_mock_stream();

        let long_username = "a".repeat(256);
        let auth = UserPassAuth {
            username: long_username,
            password: "password".to_string(),
        };

        let result = write_auth_request(&mut client, &auth).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::UsernameTooLong
        );
    }

    // A 256-byte password does not fit the one-octet PLEN field.
    #[tokio::test]
    async fn test_write_auth_request_password_too_long() {
        let (mut client, _server) = create_mock_stream();

        let long_password = "a".repeat(256);
        let auth = UserPassAuth {
            username: "username".to_string(),
            password: long_password,
        };

        let result = write_auth_request(&mut client, &auth).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::PasswordTooLong
        );
    }

    // A status reply with VER = 2 is rejected as InvalidAuthVersion.
    #[tokio::test]
    async fn test_read_auth_response_invalid_auth_version() {
        let (mut client, server) = create_mock_stream();

        // VER = 0x02, STATUS = 0.
        server.write_immediate(&[0x02, 0x00]).unwrap();

        let result = read_auth_response(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::InvalidAuthVersion
        );
    }

    // A non-zero STATUS maps to AuthenticationFailed / PermissionDenied.
    #[tokio::test]
    async fn test_read_auth_response_auth_failed() {
        let (mut client, server) = create_mock_stream();

        // VER = 0x01, STATUS = 0x01 (failure).
        server.write_immediate(&[0x01, 0x01]).unwrap();

        let result = read_auth_response(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::PermissionDenied);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::AuthenticationFailed
        );
    }

    // A request with VER = 4 is rejected before the rest is parsed.
    #[tokio::test]
    async fn test_read_connection_request_invalid_version() {
        let (mut client, server) = create_mock_stream();

        server
            .write_immediate(&[
                0x04, // VER: not SOCKS5
                0x01, // CMD
                0x00, // RSV
                0x01, 0x7F, 0x00, 0x00, 0x01, 0x00, 0x50, // address bytes (never reached)
            ])
            .unwrap();

        let result = read_connection_request(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::InvalidSocksVersion
        );
    }

    // A request with a non-zero RSV octet is rejected.
    #[tokio::test]
    async fn test_read_connection_request_invalid_rsv() {
        let (mut client, server) = create_mock_stream();

        server
            .write_immediate(&[
                0x05, // VER
                0x01, // CMD
                0x01, // RSV: must be zero
                0x01, 0x7F, 0x00, 0x00, 0x01, 0x00, 0x50, // address bytes (never reached)
            ])
            .unwrap();

        let result = read_connection_request(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::InvalidRsvValue
        );
    }

    // A reply with VER = 4 is rejected.
    #[tokio::test]
    async fn test_read_connection_response_invalid_version() {
        let (mut client, server) = create_mock_stream();

        server
            .write_immediate(&[
                0x04, // VER: not SOCKS5
                0x00, // REP
                0x00, // RSV
                0x01, 0x7F, 0x00, 0x00, 0x01, 0x00, 0x50, // address bytes (never reached)
            ])
            .unwrap();

        let result = read_connection_response(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::InvalidSocksVersion
        );
    }

    // A reply with a non-zero RSV octet is rejected.
    #[tokio::test]
    async fn test_read_connection_response_invalid_rsv() {
        let (mut client, server) = create_mock_stream();

        server
            .write_immediate(&[
                0x05, // VER
                0x00, // REP
                0x01, // RSV: must be zero
                0x01, 0x7F, 0x00, 0x00, 0x01, 0x00, 0x50, // address bytes (never reached)
            ])
            .unwrap();

        let result = read_connection_response(&mut client).await;

        let err = result.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        assert_eq!(
            err.downcast::<Socks5Error>().unwrap(),
            Socks5Error::InvalidRsvValue
        );
    }
}