use self::{associate::UdpAssociate, bind::Bind, connect::Connect};
use crate::{
protocol::{Address, AuthMethod, Command, HandshakeRequest, HandshakeResponse, Reply, Request, Response},
server::AuthExecutor,
};
use std::{net::SocketAddr, sync::Arc};
use tokio::{io::AsyncWriteExt, net::TcpStream};
pub mod associate;
pub mod bind;
pub mod connect;
/// An accepted TCP connection that has not yet completed its handshake.
///
/// Holds the raw stream together with the authentication executor that will be
/// run against it; call `handshake` to consume this value and obtain a
/// `Connection`.
pub struct IncomingConnection {
// The client's TCP stream; all handshake traffic is read from / written to it.
stream: TcpStream,
// Shared authentication executor; supplies the accepted method and runs the auth exchange.
auth: Arc<dyn AuthExecutor + Send + Sync>,
}
impl IncomingConnection {
    /// Bundles an accepted `stream` with the `auth` executor that will be used
    /// to authenticate it during [`handshake`](Self::handshake).
    #[inline]
    pub(crate) fn new(stream: TcpStream, auth: Arc<dyn AuthExecutor + Send + Sync>) -> Self {
        IncomingConnection { stream, auth }
    }

    /// Runs the authentication exchange, then reads the client's request and
    /// turns this connection into the matching [`Connection`] variant.
    ///
    /// # Errors
    ///
    /// * If authentication fails, the stream is shut down (best effort) and the
    ///   authentication error is returned.
    /// * If the request cannot be read, a `Reply::GeneralFailure` response is
    ///   sent and the stream is shut down — both best effort — and the error
    ///   from reading the request is returned.
    pub async fn handshake(mut self) -> std::io::Result<Connection> {
        if let Err(err) = self.auth().await {
            let _ = self.stream.shutdown().await;
            return Err(err);
        }

        let req = match Request::rebuild_from_stream(&mut self.stream).await {
            Ok(req) => req,
            Err(err) => {
                let resp = Response::new(Reply::GeneralFailure, Address::unspecified());
                // Best effort: a failure to deliver the reply must neither hide the
                // error that actually caused the handshake to fail nor skip the
                // shutdown below.
                let _ = resp.write_to(&mut self.stream).await;
                let _ = self.stream.shutdown().await;
                return Err(err);
            }
        };

        // Hand the stream over to the handler for the requested command; each
        // handler starts in its "reply still owed" state.
        match req.command {
            Command::UdpAssociate => Ok(Connection::UdpAssociate(
                UdpAssociate::<associate::NeedReply>::new(self.stream),
                req.address,
            )),
            Command::Bind => Ok(Connection::Bind(
                Bind::<bind::NeedFirstReply>::new(self.stream),
                req.address,
            )),
            Command::Connect => Ok(Connection::Connect(
                Connect::<connect::NeedReply>::new(self.stream),
                req.address,
            )),
        }
    }

    /// Returns the local socket address of the underlying stream.
    #[inline]
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        self.stream.local_addr()
    }

    /// Returns the remote peer's socket address of the underlying stream.
    #[inline]
    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
        self.stream.peer_addr()
    }

    /// Shuts down the underlying stream without performing a handshake.
    #[inline]
    pub async fn shutdown(&mut self) -> std::io::Result<()> {
        self.stream.shutdown().await
    }

    /// Negotiates the authentication method and runs the authentication executor.
    ///
    /// Reads the client's `HandshakeRequest`. If it offers the method this
    /// server's executor uses, that method is acknowledged and the executor is
    /// run on the stream. Otherwise `AuthMethod::NoAcceptableMethods` is sent
    /// and an `ErrorKind::Unsupported` error is returned.
    async fn auth(&mut self) -> std::io::Result<()> {
        let request = HandshakeRequest::rebuild_from_stream(&mut self.stream).await?;

        if let Some(method) = self.evaluate_request(&request) {
            let response = HandshakeResponse::new(method);
            response.write_to_stream(&mut self.stream).await?;
            self.auth.execute(&mut self.stream).await
        } else {
            let response = HandshakeResponse::new(AuthMethod::NoAcceptableMethods);
            response.write_to_stream(&mut self.stream).await?;
            let err = "No available handshake method provided by client";
            Err(std::io::Error::new(std::io::ErrorKind::Unsupported, err))
        }
    }

    /// Returns the executor's authentication method if the client listed it in
    /// `req.methods`, or `None` when the client did not offer it.
    fn evaluate_request(&self, req: &HandshakeRequest) -> Option<AuthMethod> {
        let method = self.auth.auth_method();
        req.methods.iter().find(|&&m| m == method).copied()
    }
}
impl std::fmt::Debug for IncomingConnection {
    /// Formats the connection as an `IncomingConnection` struct showing only
    /// its `stream` field; the `auth` executor is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("IncomingConnection");
        builder.field("stream", &self.stream);
        builder.finish()
    }
}
/// The result of a successful handshake: the handler for the command the
/// client requested, paired with the address carried in that request.
#[derive(Debug)]
pub enum Connection {
// Produced for `Command::UdpAssociate`; the handler still owes the client a reply.
UdpAssociate(UdpAssociate<associate::NeedReply>, Address),
// Produced for `Command::Bind`; the handler still owes the client its first reply.
Bind(Bind<bind::NeedFirstReply>, Address),
// Produced for `Command::Connect`; the handler still owes the client a reply.
Connect(Connect<connect::NeedReply>, Address),
}