use std::io::ErrorKind;
use std::net::{Shutdown, SocketAddr};
use std::sync::Arc;
use log::{debug, error, info};
use tokio::net::{TcpListener, TcpStream};
use tokio::stream::StreamExt;
use crate::async_io::{AsyncReadTrait, AsyncWriteTrait};
use crate::encrypted_stream::EncryptedStream;
use crate::socks5_addr::{Socks5Addr, Socks5AddrType};
use crate::{Error, GlobalConfig, Result};
/// SOCKS5 front end: accepts SOCKS5 clients on `tcp_listener` and relays
/// `CONNECT` requests over an encrypted connection to `remote_addr`.
pub struct SocksServer {
/// Upstream server that every accepted `CONNECT` request is relayed to.
remote_addr: SocketAddr,
/// Listener on which SOCKS5 client connections are accepted.
tcp_listener: TcpListener,
/// Shared settings (master key, cipher type, compatibility mode, timeout)
/// used when establishing the encrypted upstream stream. The timeout is
/// currently only logged and otherwise ignored.
global_config: GlobalConfig,
}
/// SOCKS5 authentication method codes offered by a client in its greeting
/// and echoed back by the server to select one.
///
/// Unassigned codes are collapsed by `From<u8>`: `0x03..=0x7F` map to
/// `IanaAssigned` and `0x80..=0xFE` map to `PrivateMethods`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum Method {
NoAuthenticationRequired = 0x00,
Gssapi = 0x01,
UsernamePassword = 0x02,
// Representative value for the whole 0x03..=0x7F range.
IanaAssigned = 0x03,
// Representative value for the whole 0x80..=0xFE range.
PrivateMethods = 0x80,
// Sent by the server when none of the offered methods is acceptable.
NoAcceptableMethods = 0xFF,
}
impl From<u8> for Method {
fn from(method: u8) -> Self {
match method {
0x00 => Method::NoAuthenticationRequired,
0x01 => Method::Gssapi,
0x02 => Method::UsernamePassword,
0x03..=0x7F => Method::IanaAssigned,
0x80..=0xFE => Method::PrivateMethods,
0xFF => Method::NoAcceptableMethods,
}
}
}
/// Commands (the `CMD` byte) a client may send in a SOCKS5 request.
///
/// Only `Connect` is served; the others are answered with
/// `CommandNotSupported`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum Command {
Connect = 0x01,
Bind = 0x02,
UdpAssociate = 0x03,
}
/// Reply codes (the `REP` byte, sent right after the version byte) that the
/// server returns to the client after a request.
#[repr(u8)]
#[derive(Debug)]
enum ReplyStatus {
Succeeded = 0x00,
GeneralFailure = 0x01,
ConnectionNotAllowed = 0x02,
NetworkUnreachable = 0x03,
HostUnreachable = 0x04,
ConnectionRefused = 0x05,
// Used here for connection time-outs to the upstream server.
TtlExpired = 0x06,
CommandNotSupported = 0x07,
AddressTypeNotSupported = 0x08,
}
impl SocksServer {
const SOCKET_VERSION: u8 = 0x05u8;
const RSV: u8 = 0x00u8;
pub async fn create(
addr: SocketAddr,
remote: SocketAddr,
global_config: GlobalConfig,
) -> Result<Self> {
info!("Creating SOCKS5 server ...");
info!("Starting socks server at address {} ...", addr);
Ok(Self {
remote_addr: remote,
tcp_listener: TcpListener::bind(addr).await?,
global_config,
})
}
pub fn create_from_std(
tcp_listener: std::net::TcpListener,
remote: SocketAddr,
global_config: GlobalConfig,
) -> Result<Self> {
info!("Creating SOCKS5 server ...");
let tcp_listener = TcpListener::from_std(tcp_listener)?;
info!(
"Starting socks server at address {} ...",
tcp_listener.local_addr()?
);
Ok(Self {
remote_addr: remote,
tcp_listener,
global_config,
})
}
fn check_socks_version(version: u8) -> Result<()> {
if version != Self::SOCKET_VERSION {
error!("Failed: socks version does not match {:#02X?}", version);
Err(Error::UnsupportedSocksVersion(version))
} else {
Ok(())
}
}
fn check_rsv(rsv: u8) -> Result<()> {
if rsv != Self::RSV {
error!("Failed: reserved bit does not match {:#02X?}", rsv);
Err(Error::UnexpectedReservedBit(rsv))
} else {
Ok(())
}
}
async fn read_and_parse_first_request(
stream: &mut (impl AsyncReadTrait + Unpin),
) -> Result<Vec<Method>> {
info!("SOCKS5 handshaking ...");
let mut buf = [0u8; 2];
stream.read_exact(&mut buf).await?;
Self::check_socks_version(buf[0])?;
let nmethods = buf[1] as usize;
let mut methods = vec![0u8; nmethods];
debug!("Expecting {} following bytes", nmethods);
info!("Reading acceptable auth methods ...");
stream.read_exact(&mut methods.as_mut_slice()).await?;
let mut ret = Vec::with_capacity(nmethods);
for method in methods {
ret.push(Method::from(method));
}
info!("Acceptable auth methods processed.");
Ok(ret)
}
async fn read_and_parse_command_request(
stream: &mut (impl AsyncReadTrait + Unpin),
) -> Result<Option<Command>> {
info!("Reading command request and rsv ...");
let mut buf = [0u8; 3];
stream.read_exact(&mut buf).await?;
Self::check_socks_version(buf[0])?;
let cmd_byte = buf[1];
let cmd = match cmd_byte {
0x01 => Some(Command::Connect),
0x02 => Some(Command::Bind),
0x03 => Some(Command::UdpAssociate),
_ => {
error!("Unrecognized socks command {}", cmd_byte);
None
}
};
if cmd.is_some() {
debug_assert_eq!(cmd_byte, cmd.expect("cmd should be some") as u8);
}
Self::check_rsv(buf[2])?;
Ok(cmd)
}
async fn serve_socks5_stream(
mut stream: TcpStream,
remote_addr: SocketAddr,
global_config: Arc<GlobalConfig>,
) -> Result<()> {
let available_methods =
Self::read_and_parse_first_request(&mut stream).await?;
let method =
if available_methods.contains(&Method::NoAuthenticationRequired) {
Method::NoAuthenticationRequired
} else {
Method::NoAcceptableMethods
};
info!("Agreed on auth method {:#?}", method);
stream
.write_all(&[Self::SOCKET_VERSION, method as u8])
.await?;
if method == Method::NoAcceptableMethods {
info!("No auth methods available, shutting down connection.");
stream.shutdown(Shutdown::Both)?;
return Ok(());
}
let cmd_option =
Self::read_and_parse_command_request(&mut stream).await?;
let cmd = match cmd_option {
Some(cmd) => cmd,
None => {
stream
.write_all(&[
Self::SOCKET_VERSION,
ReplyStatus::CommandNotSupported as u8,
Self::RSV,
])
.await?;
return Ok(());
}
};
let target_addr_result =
Socks5Addr::read_and_parse_address(&mut stream).await;
let target_addr = match target_addr_result {
Ok(target_addr) => target_addr,
Err(Error::UnsupportedAddressType(_cmd)) => {
stream
.write_all(&[
Self::SOCKET_VERSION,
ReplyStatus::AddressTypeNotSupported as u8,
Self::RSV,
])
.await?;
return Ok(());
}
Err(e) => return Err(e),
};
debug!("Executing command {:#?} to target {:?}", cmd, target_addr);
match cmd {
Command::Connect => {
info!("Connecting to remote ...");
let remote_stream =
TcpStream::connect(remote_addr).await.map_err(|e| {
let socks5_error = match e.kind() {
ErrorKind::PermissionDenied => {
ReplyStatus::ConnectionNotAllowed
}
ErrorKind::NotConnected => {
ReplyStatus::NetworkUnreachable
}
ErrorKind::NotFound => ReplyStatus::HostUnreachable,
ErrorKind::ConnectionRefused => {
ReplyStatus::ConnectionRefused
}
ErrorKind::TimedOut => ReplyStatus::TtlExpired,
_ => ReplyStatus::GeneralFailure,
};
error!("Error connecting to remote: {}", e);
socks5_error
});
let remote_stream = match remote_stream {
Ok(remote_stream) => remote_stream,
Err(reply_status) => {
#[rustfmt::skip]
let error_reply: [u8; 10] = [
Self::SOCKET_VERSION,
reply_status as u8,
Self::RSV,
Socks5AddrType::V4 as u8,
0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8,
];
stream.write_all(&error_reply).await?;
return Ok(());
}
};
let local_to_remote_port = remote_stream.local_addr()?.port();
let mut remote_encrypted_stream = EncryptedStream::establish(
remote_stream,
global_config.master_key.as_slice(),
global_config.cipher_type,
global_config.compatible_mode,
)
.await?;
info!("Setting shadow address on remote ...");
remote_encrypted_stream
.write_all(&target_addr.bytes())
.await?;
#[rustfmt::skip]
stream
.write_all(&[
Self::SOCKET_VERSION,
ReplyStatus::Succeeded as u8,
Self::RSV,
Socks5AddrType::V4 as u8,
0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8,
])
.await?;
info!("Creating connection relay ...");
crate::async_io::proxy(
stream,
remote_encrypted_stream,
target_addr,
);
info!("Relay created on port {}", local_to_remote_port);
}
_ => {
#[rustfmt::skip]
let unsupported_reply: [u8; 10] = [
Self::SOCKET_VERSION,
ReplyStatus::CommandNotSupported as u8,
Self::RSV,
Socks5AddrType::V4 as u8,
0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8, 0x00u8,
];
stream.write_all(&unsupported_reply).await?;
info!("Closing connection");
stream.shutdown(Shutdown::Both)?;
info!("Connection closed.");
}
}
Ok(())
}
pub async fn run(mut self) {
info!("Running socks server loop ...");
info!("Timeout of {:?} is ignored.", self.global_config.timeout);
info!("Connection will be kept alive until there is an error.");
let base_global_config = Arc::new(self.global_config);
while let Some(stream) = self.tcp_listener.next().await {
match stream {
Ok(stream) => {
let remote_addr = self.remote_addr;
let global_config = base_global_config.clone();
tokio::spawn(async move {
info!("New connection");
let response = Self::serve_socks5_stream(
stream,
remote_addr,
global_config,
)
.await;
if let Err(e) = response {
error!("Error serving client: {}", e);
}
});
}
Err(e) => {
error!("Error accepting connection: {}", e);
}
}
}
}
}
#[cfg(test)]
mod test {
    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::time::Duration;

    use crate::crypto::CipherType;
    use crate::test_utils::local_tcp_server::run_local_tcp_server;
    use crate::test_utils::ready_buf::ReadyBuf;

    use super::*;

    const DEFAULT_REMOTE_ADDR: &str = "127.0.0.1:80";
    const SOCKS_SERVER_ADDR: &str = "127.0.0.1:0";

    /// Starts a server whose upstream is `DEFAULT_REMOTE_ADDR` and connects
    /// a client to it.
    fn start_and_connect_to_server() -> Result<TcpStream> {
        start_and_connect_to_server_remote(
            DEFAULT_REMOTE_ADDR
                .parse()
                .expect("Parsing should not fail"),
        )
    }

    /// Returns a loopback address that nothing is listening on.
    ///
    /// An ephemeral port is bound and immediately released, so connecting
    /// to it is expected to be refused. This avoids relying on a fixed port
    /// (such as 80) being closed on the test machine.
    fn unused_local_addr() -> Result<SocketAddr> {
        let listener = TcpListener::bind(SOCKS_SERVER_ADDR)?;
        let addr = listener.local_addr()?;
        drop(listener);
        Ok(addr)
    }

    /// Starts a `SocksServer` relaying to `remote_addr` on a background
    /// thread with its own runtime, and returns a blocking client connection.
    fn start_and_connect_to_server_remote(
        remote_addr: SocketAddr,
    ) -> Result<TcpStream> {
        let local_socket_addr: SocketAddr =
            SOCKS_SERVER_ADDR.parse().expect("Parsing should not fail.");
        let tcp_listener = TcpListener::bind(local_socket_addr)?;
        let server_addr = tcp_listener.local_addr()?;
        std::thread::spawn(move || {
            let mut rt = tokio::runtime::Runtime::new()
                .expect("Should not error when creating a runtime.");
            rt.block_on(async {
                let server = SocksServer {
                    remote_addr,
                    tcp_listener: tokio::net::TcpListener::from_std(
                        tcp_listener,
                    )
                    .expect("Creating tcp listener should not fail"),
                    global_config: GlobalConfig {
                        master_key: vec![],
                        cipher_type: CipherType::None,
                        timeout: Duration::from_secs(1),
                        fast_open: false,
                        compatible_mode: false,
                    },
                };
                server.run().await
            });
        });
        Ok(TcpStream::connect(server_addr)?)
    }

    #[tokio::test]
    async fn test_socks5_handshake_async() -> Result<()> {
        let mut ready_buf = ReadyBuf::make(&[&[0x05, 0x02, 0x00, 0x80]]);
        let methods =
            SocksServer::read_and_parse_first_request(&mut ready_buf).await?;
        assert_eq!(
            methods,
            vec![Method::NoAuthenticationRequired, Method::PrivateMethods]
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_socks5_handshake_version_mismatch_async() -> Result<()> {
        let mut ready_buf = ReadyBuf::make(&[&[0x04, 0x00]]);
        let result =
            SocksServer::read_and_parse_first_request(&mut ready_buf).await;
        if let Err(Error::UnsupportedSocksVersion(v)) = result {
            assert_eq!(v, 0x04);
        } else {
            panic!("Should return error UnsupportedSocksVersion = 0x04");
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_socks5_handshake_no_methods_async() -> Result<()> {
        let mut ready_buf = ReadyBuf::make(&[&[0x05, 0x00]]);
        let methods =
            SocksServer::read_and_parse_first_request(&mut ready_buf).await?;
        assert!(methods.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_socks5_command_async() -> Result<()> {
        let mut ready_buf = ReadyBuf::make(&[&[0x05, 0x02, 0x00]]);
        let command =
            SocksServer::read_and_parse_command_request(&mut ready_buf).await?;
        assert_eq!(command, Some(Command::Bind));
        Ok(())
    }

    #[tokio::test]
    async fn test_socks5_command_version_mismatch_async() -> Result<()> {
        let mut ready_buf = ReadyBuf::make(&[&[0x04, 0x02, 0x00]]);
        let result =
            SocksServer::read_and_parse_command_request(&mut ready_buf).await;
        if let Err(Error::UnsupportedSocksVersion(v)) = result {
            assert_eq!(v, 0x04);
        } else {
            panic!("Should return error UnsupportedSocksVersion = 0x04");
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_socks5_command_none_async() -> Result<()> {
        let mut ready_buf = ReadyBuf::make(&[&[0x05, 0x04, 0x00]]);
        let cmd =
            SocksServer::read_and_parse_command_request(&mut ready_buf).await?;
        assert!(cmd.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_socks5_command_rsv_async() -> Result<()> {
        let mut ready_buf = ReadyBuf::make(&[&[0x05, 0x03, 0x01]]);
        let result =
            SocksServer::read_and_parse_command_request(&mut ready_buf).await;
        if let Err(Error::UnexpectedReservedBit(v)) = result {
            assert_eq!(v, 0x01);
        } else {
            panic!("Should return error UnexpectedReservedBit = 0x01");
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_socks5_command_rsv_and_none_async() -> Result<()> {
        let mut ready_buf = ReadyBuf::make(&[&[0x05, 0x04, 0x01]]);
        let result =
            SocksServer::read_and_parse_command_request(&mut ready_buf).await;
        if let Err(Error::UnexpectedReservedBit(v)) = result {
            assert_eq!(v, 0x01);
        } else {
            panic!("Should return error UnexpectedReservedBit = 0x01");
        }
        Ok(())
    }

    #[test]
    fn test_socks5_no_auth_methods() -> Result<()> {
        let mut stream = start_and_connect_to_server()?;
        stream.write_all(&[0x05, 0x01, 0x08])?;
        let mut buf = [0u8; 2];
        stream.read_exact(&mut buf)?;
        assert_eq!(buf, [0x05, 0xFF]);
        if let Err(e) = stream.read_exact(&mut buf) {
            assert_eq!(e.kind(), ErrorKind::UnexpectedEof);
        } else {
            panic!("The connection should have shutdown.");
        }
        Ok(())
    }

    #[test]
    fn test_socks5_agreed_auth_methods() -> Result<()> {
        let mut stream = start_and_connect_to_server()?;
        stream.write_all(&[0x05, 0x02, 0x08, 0x00])?;
        let mut buf = [0u8; 2];
        stream.read_exact(&mut buf)?;
        assert_eq!(buf, [0x05, 0x00]);
        Ok(())
    }

    #[test]
    fn test_socks5_command_not_supported() -> Result<()> {
        let mut stream = start_and_connect_to_server()?;
        stream.write_all(&[0x05, 0x01, 0x00])?;
        let mut buf = [0u8; 2];
        stream.read_exact(&mut buf)?;
        assert_eq!(buf, [0x05, 0x00]);
        stream.write_all(&[0x05, 0x04, 0x00])?;
        let mut buf = [0u8; 3];
        stream.read_exact(&mut buf)?;
        assert_eq!(buf, [0x05, 0x07, 0x00]);
        Ok(())
    }

    #[test]
    fn test_socks5_command_address_not_supported() -> Result<()> {
        let mut stream = start_and_connect_to_server()?;
        stream.write_all(&[0x05, 0x01, 0x00])?;
        let mut buf = [0u8; 2];
        stream.read_exact(&mut buf)?;
        assert_eq!(buf, [0x05, 0x00]);
        stream.write_all(&[0x05, 0x01, 0x00, 0x02])?;
        let mut buf = [0u8; 3];
        stream.read_exact(&mut buf)?;
        assert_eq!(buf, [0x05, 0x08, 0x00]);
        Ok(())
    }

    #[test]
    fn test_socks5_command_connect() -> Result<()> {
        let (local_tcp_server_addr, _tcp_server_running) =
            run_local_tcp_server()?;
        let mut stream =
            start_and_connect_to_server_remote(local_tcp_server_addr)?;
        stream.write_all(&[0x05, 0x01, 0x00])?;
        let mut buf = [0u8; 2];
        stream.read_exact(&mut buf)?;
        assert_eq!(buf, [0x05, 0x00]);
        stream.write_all(&[0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])?;
        let mut buf = [0u8; 10];
        stream.read_exact(&mut buf)?;
        assert_eq!(
            buf,
            [0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
        );
        Ok(())
    }

    #[test]
    fn test_socks5_command_other() -> Result<()> {
        let mut stream = start_and_connect_to_server()?;
        stream.write_all(&[0x05, 0x01, 0x00])?;
        let mut buf = [0u8; 2];
        stream.read_exact(&mut buf)?;
        assert_eq!(buf, [0x05, 0x00]);
        stream.write_all(&[0x05, 0x02, 0x00, 0x03, 0x01, 0x40, 0x00, 0x00])?;
        let mut buf = [0u8; 10];
        stream.read_exact(&mut buf)?;
        assert_eq!(
            buf,
            [0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
        );
        Ok(())
    }

    #[test]
    fn test_socks5_command_connect_failure() -> Result<()> {
        // Use a just-released ephemeral port so the upstream connect is
        // refused regardless of what runs on well-known ports locally.
        let mut stream =
            start_and_connect_to_server_remote(unused_local_addr()?)?;
        stream.write_all(&[0x05, 0x01, 0x00])?;
        let mut buf = [0u8; 2];
        stream.read_exact(&mut buf)?;
        assert_eq!(buf, [0x05, 0x00]);
        stream.write_all(&[0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])?;
        let mut buf = [0u8; 10];
        stream.read_exact(&mut buf)?;
        assert_eq!(
            buf,
            [0x05, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
        );
        Ok(())
    }

    #[test]
    fn test_socks5_command_connect_proxy() -> Result<()> {
        let (local_tcp_server_addr, _tcp_server_running) =
            run_local_tcp_server()?;
        let socks5_addr = match local_tcp_server_addr {
            SocketAddr::V4(socket_addr_v4) => Socks5Addr::V4(socket_addr_v4),
            SocketAddr::V6(socket_addr_v6) => Socks5Addr::V6(socket_addr_v6),
        };
        let mut stream =
            start_and_connect_to_server_remote(local_tcp_server_addr)?;
        stream.write_all(&[0x05, 0x01, 0x00])?;
        let mut buf = [0u8; 2];
        stream.read_exact(&mut buf)?;
        assert_eq!(buf, [0x05, 0x00]);
        stream.write_all(&[0x05, 0x01, 0x00])?;
        stream.write_all(&socks5_addr.bytes())?;
        let mut buf = [0u8; 10];
        stream.read_exact(&mut buf)?;
        assert_eq!(
            buf,
            [0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
        );
        let mut buf = [0u8; 2];
        stream.read_exact(&mut buf)?;
        assert_eq!(buf, [0x00, 0x01]);
        Ok(())
    }
}