1use std::io;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6use crate::net::TcpListener;
7use crate::v5::{Method, Request, Stream};
8
/// Handshake timeout installed by [`Config::new`] unless overridden with
/// [`Config::with_timeout`].
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// Server-wide settings shared (behind an `Arc`) by every connection task.
///
/// `A` is the authentication strategy and `H` the request handler.
pub struct Config<A, H> {
    /// Negotiates the authentication method with each client.
    auth: A,
    /// Serves the parsed request once the handshake has succeeded.
    handler: H,
    /// Upper bound on method negotiation + authentication + request parsing.
    timeout: Duration,
}

impl<A, H> Config<A, H> {
    /// Builds a configuration from `auth` and `handler`, using the default
    /// 30-second handshake timeout.
    pub fn new(auth: A, handler: H) -> Self {
        Config {
            timeout: DEFAULT_TIMEOUT,
            handler,
            auth,
        }
    }

    /// Consumes the configuration and returns it with the handshake timeout
    /// replaced by `timeout`; all other settings are carried over unchanged.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Config { timeout, ..self }
    }
}
31
32pub struct Server;
34
35impl Server {
36 pub async fn run<H, A>(
37 listener: TcpListener,
38 config: Arc<Config<A, H>>,
39 shutdown_signal: impl Future<Output = ()>,
40 ) -> io::Result<()>
41 where
42 H: Handler + 'static,
43 A: Authenticator + 'static,
44 {
45 tokio::pin!(shutdown_signal);
46
47 loop {
48 tokio::select! {
49 biased;
51
52 _ = &mut shutdown_signal => return Ok(()),
53
54 result = listener.accept() => {
55 let (inner, addr) = match result {
56 Ok(res) => res,
57 Err(_err) => {
58 #[cfg(feature = "tracing")]
59 tracing::error!("Failed to accept connection: {}", _err);
60 continue;
61 }
62 };
63
64 let config = config.clone();
65 tokio::spawn(async move {
66 let mut stream = Stream::with(inner, addr);
67
68 if let Err(_err) = Self::handle_connection(&mut stream, &config).await {
69 #[cfg(feature = "tracing")]
70 tracing::warn!("Connection {} error: {}", addr, _err);
71 }
72 });
73 }
74 }
75 }
76 }
77
78 async fn handle_connection<H, A, S>(
79 stream: &mut Stream<S>,
80 config: &Config<A, H>,
81 ) -> io::Result<()>
82 where
83 H: Handler + 'static,
84 A: Authenticator + 'static,
85 S: AsyncRead + AsyncWrite + Unpin + Send + Sync,
86 {
87 let request = tokio::time::timeout(config.timeout, async {
89 let methods = stream.read_methods().await?;
90 config.auth.auth(stream, methods).await?;
91 stream.read_request().await
92 })
93 .await
94 .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "Timeout during authentication"))??;
95
96 config.handler.handle(stream, request).await
97 }
98}
99
100pub trait Authenticator: Send + Sync {
102 fn auth<T>(
103 &self,
104 stream: &mut Stream<T>,
105 methods: Vec<Method>,
106 ) -> impl Future<Output = io::Result<()>> + Send
107 where
108 T: AsyncRead + AsyncWrite + Unpin + Send + Sync;
109}
110
111pub trait Handler: Send + Sync {
113 fn handle<T>(
114 &self,
115 stream: &mut Stream<T>,
116 request: Request,
117 ) -> impl Future<Output = io::Result<()>> + Send
118 where
119 T: AsyncRead + AsyncWrite + Unpin + Send + Sync;
120}
121
122pub mod auth {
123 use super::*;
124
125 pub struct NoAuthentication;
126
127 impl Authenticator for NoAuthentication {
128 async fn auth<T>(&self, stream: &mut Stream<T>, _methods: Vec<Method>) -> io::Result<()>
129 where
130 T: AsyncRead + AsyncWrite + Unpin + Send + Sync,
131 {
132 stream.write_auth_method(Method::NoAuthentication).await?;
133 Ok(())
134 }
135 }
136
137 pub struct UserPassword {
138 username: String,
139 password: String,
140 }
141
142 impl UserPassword {
143 pub fn new(username: String, password: String) -> Self {
144 Self { username, password }
145 }
146 }
147
148 impl Authenticator for UserPassword {
149 async fn auth<T>(&self, stream: &mut Stream<T>, methods: Vec<Method>) -> io::Result<()>
150 where
151 T: AsyncRead + AsyncWrite + Unpin + Send + Sync,
152 {
153 if !methods.contains(&Method::UsernamePassword) {
154 return Err(io::Error::new(
155 io::ErrorKind::PermissionDenied,
156 "Username/Password authentication required",
157 ));
158 }
159
160 stream.write_auth_method(Method::UsernamePassword).await?;
161
162 let version = stream.read_u8().await?;
164 if version != 0x01 {
165 return Err(io::Error::new(
166 io::ErrorKind::InvalidData,
167 "Invalid subnegotiation version",
168 ));
169 }
170
171 let ulen = stream.read_u8().await?;
172 let mut username = vec![0; ulen as usize];
173 stream.read_exact(&mut username).await?;
174
175 let plen = stream.read_u8().await?;
176 let mut password = vec![0; plen as usize];
177 stream.read_exact(&mut password).await?;
178
179 if username != self.username.as_bytes() || password != self.password.as_bytes() {
181 stream.write_all(&[0x01, 0x01]).await?;
182 return Err(io::Error::new(
183 io::ErrorKind::PermissionDenied,
184 "Invalid username or password",
185 ));
186 }
187
188 stream.write_all(&[0x01, 0x00]).await?;
189
190 Ok(())
191 }
192 }
193}