#![forbid(unsafe_code)]
#[macro_use]
extern crate log;
mod socks;
use futures::future::try_join;
pub use socks::AuthMethod;
use socks::{AddrType, Command, Response, RESERVED, VERSION5};
use std::{
boxed::Box,
error::Error,
io,
net::{Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6},
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
sync::{mpsc, oneshot},
};
/// Message sent from a connection task to the credential-checking task:
/// `(username, password, reply channel)`. The result of the `auth` callback
/// is sent back over the one-shot reply channel.
type AuthCheckMsg = (String, String, oneshot::Sender<bool>);
/// A SOCKS5 proxy server: owns the listening socket and the channel through
/// which client credentials are submitted for checking.
pub struct SocksServer {
/// Listening socket from which client connections are accepted.
listener: TcpListener,
/// Whether clients may select the "no authentication" method.
allow_no_auth: bool,
/// Sending half of the credential-check channel; cloned for every client.
auth_tx: mpsc::Sender<AuthCheckMsg>,
}
impl SocksServer {
pub async fn new(
socket_addr: SocketAddr,
allow_no_auth: bool,
auth: Box<dyn Fn(String, String) -> bool + Send>,
) -> SocksServer {
let (tx, mut rx) = mpsc::channel::<AuthCheckMsg>(100);
tokio::spawn(async move {
while let Some((username, password, sender)) = rx.recv().await {
if let Err(_) = sender.send(auth(username, password)) {
error!("Failed to send back authentication result.");
}
}
});
println!("SOCKS5 server listening on {}", socket_addr);
SocksServer {
listener: TcpListener::bind(socket_addr).await.unwrap(),
allow_no_auth,
auth_tx: tx,
}
}
pub async fn serve(&mut self) {
loop {
let no_auth = self.allow_no_auth.clone();
if let Ok((socket, address)) = self.listener.accept().await {
let tx2 = self.auth_tx.clone();
tokio::spawn(async move {
info!("Client connected: {}", address);
let mut client = SocksServerConnection::new(socket, no_auth, tx2);
match client.serve().await {
Ok(_) => info!("Request was served successfully."),
Err(err) => error!("{}", err.to_string()),
}
});
}
}
}
}
/// Server-side state for a single accepted client connection.
struct SocksServerConnection {
/// The TCP connection to the client.
socket: TcpStream,
/// Whether the "no authentication" method may be selected for this client.
no_auth: bool,
/// Channel on which the client's credentials are submitted for checking.
auth_ch: mpsc::Sender<AuthCheckMsg>,
}
impl SocksServerConnection {
fn new(
socket: TcpStream,
no_auth: bool,
auth_ch: mpsc::Sender<(String, String, oneshot::Sender<bool>)>,
) -> SocksServerConnection {
SocksServerConnection {
socket,
no_auth,
auth_ch,
}
}
fn shutdown(&mut self, msg: &str) -> Result<(), Box<dyn Error>> {
self.socket.shutdown(Shutdown::Both)?;
warn!("{}", msg);
Ok(())
}
async fn serve(&mut self) -> Result<(), Box<dyn Error>> {
let mut header = [0u8; 2];
self.socket.read_exact(&mut header).await?;
if header[0] != VERSION5 {
self.shutdown("Unsupported version")?;
Err(Response::Failure)?;
}
let methods = AuthMethod::get_available_methods(header[1], &mut self.socket).await?;
self.auth(methods).await?;
self.handle_req().await?;
Ok(())
}
async fn auth(&mut self, methods: Vec<AuthMethod>) -> Result<(), Box<dyn Error>> {
if methods.contains(&AuthMethod::UserPass) {
self.socket
.write_all(&[VERSION5, AuthMethod::UserPass as u8])
.await?;
let mut ulen = [0u8; 2];
self.socket.read_exact(&mut ulen).await?;
let ulen = ulen[1];
let mut username: Vec<u8> = Vec::with_capacity(ulen as usize);
for _ in 0..ulen {
username.push(0)
}
self.socket.read_exact(&mut username).await?;
let username = String::from_utf8(username).unwrap();
let mut plen = [0u8; 1];
self.socket.read_exact(&mut plen).await?;
let plen = plen[0];
let mut password: Vec<u8> = Vec::with_capacity(plen as usize);
for _ in 0..plen {
password.push(0)
}
self.socket.read_exact(&mut password).await?;
let password = String::from_utf8(password).unwrap();
let (tx, rx) = oneshot::channel::<bool>();
self.auth_ch.send((username.clone(), password, tx)).await?;
if rx.await? {
info!("User authenticated: {}", username);
self.socket.write_all(&[1, Response::Success as u8]).await?;
} else {
self.socket
.write_all(&[VERSION5, Response::Failure as u8])
.await?;
self.shutdown("Authentication failed.")?;
}
} else if self.no_auth && methods.contains(&AuthMethod::NoAuth) {
warn!("Client connected with no authentication");
self.socket
.write_all(&[VERSION5, AuthMethod::NoAuth as u8])
.await?
} else {
self.socket
.write_all(&[VERSION5, Response::Failure as u8])
.await?;
self.shutdown("No acceptable method found.")?;
}
Ok(())
}
async fn handle_req(&mut self) -> Result<(), Box<dyn Error>> {
let mut data = [0u8; 3];
self.socket.read(&mut data).await?;
let addresses = AddrType::get_socket_addrs(&mut self.socket).await?;
match Command::from(data[1] as usize) {
Some(Command::Connect) => self.cmd_connect(addresses).await?,
_ => {
self.shutdown("Command not supported.")?;
Err(Response::CommandNotSupported)?;
}
};
Ok(())
}
async fn cmd_connect(&mut self, addrs: Vec<SocketAddr>) -> Result<(), Box<dyn Error>> {
let mut dest = TcpStream::connect(&addrs[..]).await?;
self.socket
.write_all(&[
VERSION5,
Response::Success as u8,
RESERVED,
1,
127,
0,
0,
1,
0,
0,
])
.await
.unwrap();
let (mut ro, mut wo) = dest.split();
let (mut ri, mut wi) = self.socket.split();
let client_to_server = async {
tokio::io::copy(&mut ri, &mut wo).await?;
wo.shutdown().await
};
let server_to_client = async {
tokio::io::copy(&mut ro, &mut wi).await?;
wi.shutdown().await
};
try_join(client_to_server, server_to_client).await?;
Ok(())
}
}
/// Client-side entry point for connecting to a target through a SOCKS5 proxy.
pub struct SocksStream {
/// The TCP connection to the proxy.
stream: TcpStream,
}
impl SocksStream {
    /// Connects to the proxy at `proxy_addr` and asks it to connect to `target_addr`.
    ///
    /// `user_pass` holds the optional `(username, password)` offered to the
    /// proxy. On success the returned stream is the TCP connection to the
    /// proxy, on which the handshake and connect request have completed.
    ///
    /// # Errors
    ///
    /// Fails if the proxy cannot be reached, or if the handshake or the
    /// connect request fails.
    pub async fn connect(
        proxy_addr: SocketAddr,
        target_addr: impl ToTargetAddr,
        user_pass: Option<(String, String)>,
    ) -> Result<TcpStream, Box<dyn Error>> {
        let stream = TcpStream::connect(proxy_addr).await?;
        let mut proxied = SocksStream { stream };
        connect_with_stream(&mut proxied.stream, target_addr, user_pass).await?;
        let SocksStream { stream } = proxied;
        Ok(stream)
    }
}
/// Performs the client side of the SOCKS5 method negotiation on `stream`.
///
/// "No authentication" is always offered; username/password is offered as
/// well when `user_pass` is `Some((username, password))`. If the proxy
/// selects username/password, the credentials are sent and the verdict is
/// awaited.
///
/// # Errors
///
/// Returns an error on I/O failure, when the proxy answers with a different
/// SOCKS version, when it selects a method that was not offered or cannot be
/// satisfied, when a credential is longer than 255 bytes, or when the proxy
/// rejects the credentials.
pub async fn socks_handshake(
    stream: &mut TcpStream,
    user_pass: Option<(String, String)>,
) -> Result<(), Box<dyn Error>> {
    // Version byte of the username/password sub-negotiation (RFC 1929). It is
    // 0x01, not the SOCKS protocol version.
    const USER_PASS_VERSION: u8 = 0x01;
    // Each credential is prefixed by a single length byte on the wire.
    const MAX_FIELD_LEN: usize = 255;

    // Greeting: VER, NMETHODS, METHODS...
    let greeting: Vec<u8> = if user_pass.is_some() {
        vec![
            VERSION5,
            2,
            AuthMethod::UserPass as u8,
            AuthMethod::NoAuth as u8,
        ]
    } else {
        vec![VERSION5, 1, AuthMethod::NoAuth as u8]
    };
    stream.write_all(&greeting).await?;

    let mut choice = [0u8; 2];
    stream.read_exact(&mut choice).await?;
    if choice[0] != VERSION5 {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "Invalid SOCKS version").into());
    }

    let selected = choice[1];
    if selected == AuthMethod::UserPass as u8 {
        let (username, password) = match user_pass {
            Some(credentials) => credentials,
            None => {
                return Err(
                    io::Error::new(io::ErrorKind::Other, "Username & password requried").into(),
                );
            }
        };
        // Checked here rather than up front, so over-long credentials only
        // matter when they are actually sent. Without the check the `as u8`
        // casts below would truncate the length and corrupt the frame.
        if username.len() > MAX_FIELD_LEN || password.len() > MAX_FIELD_LEN {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Username and password must be at most 255 bytes long",
            )
            .into());
        }

        // Request layout: VER, ULEN, UNAME, PLEN, PASSWD.
        let mut request = Vec::with_capacity(3 + username.len() + password.len());
        request.push(USER_PASS_VERSION);
        request.push(username.len() as u8);
        request.extend_from_slice(username.as_bytes());
        request.push(password.len() as u8);
        request.extend_from_slice(password.as_bytes());
        stream.write_all(&request).await?;

        // Reply layout: VER, STATUS.
        let mut verdict = [0u8; 2];
        stream.read_exact(&mut verdict).await?;
        if verdict[1] != Response::Success as u8 {
            return Err(io::Error::new(io::ErrorKind::Other, "Wrong username/password").into());
        }
    } else if selected != AuthMethod::NoAuth as u8 {
        return Err(
            io::Error::new(io::ErrorKind::Other, "Invalid authentication method").into(),
        );
    }
    Ok(())
}
pub async fn cmd_connect(
stream: &mut TcpStream,
target_addr: impl ToTargetAddr,
) -> Result<(), Box<dyn Error>> {
let target_addr = target_addr.target_addr();
let mut data = vec![0; 6 + target_addr.len()];
data[0] = VERSION5;
data[1] = Command::Connect as u8;
data[2] = RESERVED;
data[3] = target_addr.addr_type() as u8;
target_addr.write_to(&mut data[4..]);
stream.write_all(&data).await?;
let mut response = [0u8; 3];
stream.read(&mut response).await?;
AddrType::get_socket_addrs(stream).await?;
Ok(())
}
/// Runs the full client sequence on an already-connected `stream`: the method
/// negotiation (with optional `user_pass` credentials) followed by a `CONNECT`
/// request for `target_addr`.
///
/// # Errors
///
/// Returns the first error produced by either step.
pub async fn connect_with_stream(
    stream: &mut TcpStream,
    target_addr: impl ToTargetAddr,
    user_pass: Option<(String, String)>,
) -> Result<(), Box<dyn Error>> {
    socks_handshake(stream, user_pass).await?;
    cmd_connect(stream, target_addr).await
}
/// Destination of a proxied connection, in one of the three address forms
/// that can be written into a request.
#[derive(Debug, Clone)]
pub enum TargetAddr {
/// An IPv4 address and port.
V4(SocketAddrV4),
/// An IPv6 address and port.
V6(SocketAddrV6),
/// A domain name and port.
Domain((String, u16)),
}
impl TargetAddr {
    /// Number of bytes the address occupies on the wire, excluding the
    /// trailing 2-byte port. For a domain this includes the length prefix.
    fn len(&self) -> usize {
        match self {
            TargetAddr::V4(_) => 4,
            TargetAddr::V6(_) => 16,
            TargetAddr::Domain((domain, _)) => domain.len() + 1,
        }
    }

    /// Address-type tag that matches this address form.
    fn addr_type(&self) -> AddrType {
        match self {
            TargetAddr::V4(_) => AddrType::V4,
            // Previously reported as V4, which made the receiver parse a
            // 16-byte address as a 4-byte one.
            TargetAddr::V6(_) => AddrType::V6,
            TargetAddr::Domain(_) => AddrType::Domain,
        }
    }

    /// Writes the address followed by the big-endian port to the start of `buf`.
    ///
    /// Exactly `self.len() + 2` bytes are written; any bytes after that are
    /// left untouched.
    ///
    /// Domain names must be at most 255 bytes long, because the length prefix
    /// is a single byte; callers are expected to check this beforehand.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than `self.len() + 2` bytes.
    fn write_to(&self, buf: &mut [u8]) {
        let (host, rest) = buf.split_at_mut(self.len());
        let port = match self {
            TargetAddr::V4(addr) => {
                host.copy_from_slice(&addr.ip().octets());
                addr.port()
            }
            TargetAddr::V6(addr) => {
                host.copy_from_slice(&addr.ip().octets());
                addr.port()
            }
            TargetAddr::Domain((domain, port)) => {
                host[0] = domain.len() as u8;
                host[1..].copy_from_slice(domain.as_bytes());
                *port
            }
        };
        rest[..2].copy_from_slice(&port.to_be_bytes());
    }
}
/// Conversion into a [`TargetAddr`], so that functions taking a target can
/// accept any implementing address type.
pub trait ToTargetAddr {
/// Consumes `self` and returns the equivalent [`TargetAddr`].
fn target_addr(self) -> TargetAddr;
}
/// A `TargetAddr` is already in its final form, so the conversion is the identity.
impl ToTargetAddr for TargetAddr {
    fn target_addr(self) -> Self {
        self
    }
}
/// An IPv4 socket address becomes the `V4` form of [`TargetAddr`].
impl ToTargetAddr for SocketAddrV4 {
    fn target_addr(self) -> TargetAddr {
        let v4 = self;
        TargetAddr::V4(v4)
    }
}
/// An IPv6 socket address becomes the `V6` form of [`TargetAddr`].
impl ToTargetAddr for SocketAddrV6 {
    fn target_addr(self) -> TargetAddr {
        let v6 = self;
        TargetAddr::V6(v6)
    }
}
/// A socket address of either family becomes the matching form of
/// [`TargetAddr`], keeping the inner address unchanged.
impl ToTargetAddr for SocketAddr {
    fn target_addr(self) -> TargetAddr {
        match self {
            Self::V4(inner) => TargetAddr::V4(inner),
            Self::V6(inner) => TargetAddr::V6(inner),
        }
    }
}