1use std::{
2 any::{Any, TypeId}, default, future::Future, pin::Pin, sync::Arc
3};
4use tokio::io::{
5 AsyncBufReadExt,
6 AsyncWriteExt,
7 BufReader,
8 BufWriter,
9 ReadHalf,
10 WriteHalf,
11};
12use crate::{app::{middleware::{AsyncMiddleware, AsyncMiddlewareChain}, urls::{PathPattern, Url}}, connection::{Connection, Rx}, extensions::ParamsClone};
13use super::application::App;
14
/// Function-pointer type that takes a byte slice and returns a `bool`.
///
/// NOTE(review): not referenced in the visible part of this file. It has the
/// same shape as `ProtocolHandlerTrait::test`, so it is presumably meant as a
/// protocol-detection predicate — confirm before relying on that.
type TestFn = fn(&[u8]) -> bool;
16
/// Function-pointer type for a connection-handling entry point.
///
/// The function receives the shared `App`, an `Arc<Url<R>>`, and the buffered
/// read and write halves of a `Connection`, and returns a boxed, pinned,
/// `Send` future that resolves to `()`.
///
/// NOTE(review): not referenced in the visible part of this file; the
/// argument list mirrors the call to `R::process` in `ProtocolHandler::handle`.
type HandlerFn<R: Rx> =
    fn(Arc<App>, Arc<Url<R>>, BufReader<ReadHalf<Connection>>, BufWriter<WriteHalf<Connection>>)
        -> Pin<Box<dyn Future<Output = ()> + Send>>;
20
/// Concrete handler for one protocol `R`: its root URL node together with the
/// middleware chain that was registered with it.
struct ProtocolHandler<R: Rx> {
    /// Root URL node; cloned and passed to `R::process` for each connection.
    root_handler: Arc<Url<R>>,
    /// Middleware chain stored with this protocol. It is read back through
    /// `ProtocolRegistryKind::middlewares`; `handle` itself does not use it.
    middlewares: AsyncMiddlewareChain<R>,
}
28
29impl<R: Rx> ProtocolHandler<R> {
30 pub fn new(
31 root_handler: Arc<Url<R>>,
32 middlewares: AsyncMiddlewareChain<R>,
33 ) -> Self {
34 Self {
35 root_handler,
36 middlewares,
37 }
38 }
39}
40
/// Type-erased interface over `ProtocolHandler<R>`, so handlers for different
/// protocols can be stored together as `Arc<dyn ProtocolHandlerTrait>`.
pub trait ProtocolHandlerTrait: Send + Sync {
    /// Returns `true` when `buf` (the bytes currently buffered for a new
    /// connection) is recognised as this handler's protocol.
    fn test(&self, buf: &[u8]) -> bool;

    /// Serves one connection.
    ///
    /// Takes ownership of the buffered read and write halves and returns a
    /// boxed `Send` future that resolves to `()` once handling is finished.
    fn handle(
        &self,
        app: Arc<App>,
        reader: BufReader<ReadHalf<Connection>>,
        writer: BufWriter<WriteHalf<Connection>>,
    ) -> Pin<Box<dyn Future<Output = ()> + Send>>;

    /// Exposes the handler as `&dyn Any` so callers can downcast it back to
    /// the concrete `ProtocolHandler<R>`.
    fn as_any(&self) -> &dyn Any;

    /// Mutable counterpart of `as_any`.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
63
64impl<R: Rx + 'static> ProtocolHandlerTrait for ProtocolHandler<R> {
65 fn test(&self, buf: &[u8]) -> bool {
66 R::test_protocol(buf)
67 }
68
69 fn handle(
70 &self,
71 app: Arc<App>,
72 reader: BufReader<ReadHalf<Connection>>,
73 writer: BufWriter<WriteHalf<Connection>>,
74 ) -> Pin<Box<dyn Future<Output = ()> + Send>> {
75 let root_handler = self.root_handler.clone();
76 Box::pin(async move {
77 R::process(app, root_handler, reader, writer).await;
78 })
79 }
80
81 fn as_any(&self) -> &dyn Any {
82 self
83 }
84
85 fn as_any_mut(&mut self) -> &mut dyn Any {
86 self
87 }
88}
89
/// Ordered collection of type-erased protocol handlers.
///
/// `run_multi` offers each new connection to the handlers in the order in
/// which they were registered and dispatches to the first one whose `test`
/// accepts the buffered bytes.
pub struct ProtocolRegistry {
    /// Registered handlers, in registration order.
    handlers: Vec<Arc<dyn ProtocolHandlerTrait>>,
}
98
99impl ProtocolRegistry {
100 pub fn new() -> Self {
102 Self {
103 handlers: Vec::new(),
104 }
105 }
106
107 pub fn register<R: Rx + 'static>(&mut self, root_handler: Arc<Url<R>>, middleware_chain: AsyncMiddlewareChain<R>) {
111 self.handlers.push(Arc::new(ProtocolHandler::new(root_handler, middleware_chain)));
112 }
113
114 pub async fn run_multi(&self, app: Arc<App>, conn: Connection) {
122 let (read_half, write_half) = conn.split();
124 let mut reader = BufReader::new(read_half);
125 let mut writer = BufWriter::new(write_half);
126
127 let buf = reader.fill_buf().await.unwrap_or(&[]);
129 let n = buf.len();
130
131 for handler in &self.handlers {
133 if handler.test(&buf[..n]) {
134 handler.handle(app.clone(), reader, writer).await;
136 return;
137 }
138 }
139
140 let _ = writer.shutdown().await;
142 }
143}
144
/// Either a single protocol handler or a registry of several.
pub enum ProtocolRegistryKind {
    /// Exactly one handler; `run` passes every connection straight to it
    /// without calling `test`.
    Single(Arc<dyn ProtocolHandlerTrait>),
    /// A registry; `run` delegates to `ProtocolRegistry::run_multi`, which
    /// selects a handler per connection.
    Multi(ProtocolRegistry),
}
154
155
/// Builder that assembles a type-erased handler for protocol `R`.
pub struct ProtocolHandlerBuilder<R: Rx + 'static> {
    /// Root URL node; `new` initialises it to `Url::default()`.
    url: Arc<Url<R>>,
    /// Middlewares collected so far, in the order they will be stored.
    middlewares: Vec<Arc<dyn AsyncMiddleware<R>>>,
}
160
161impl<R: Rx> ProtocolHandlerBuilder<R> {
162 pub fn new() -> Self {
163 Self {
164 url: Arc::new(Url::default()),
165 middlewares: Vec::new(),
166 }
167 }
168
169 pub fn with_default_middlewares(mut self) -> Self {
170 self.middlewares = Self::default_middlewares();
171 self
172 }
173
174 pub fn default_middlewares() -> Vec<Arc<dyn AsyncMiddleware<R>>> {
175 vec![
176 ]
178 }
179
180 pub fn set_url(mut self, url: Arc<Url<R>>) -> Self {
181 self.url = url;
182 self
183 }
184
185 pub fn append_middleware<M>(mut self) -> Self
187 where
188 M: AsyncMiddleware<R> + Default + 'static,
189 {
190 self.middlewares.push(Arc::new(M::default()));
191 self
192 }
193
194 pub fn prepend_middleware<M>(mut self) -> Self
196 where
197 M: AsyncMiddleware<R> + Default + 'static,
198 {
199 self.middlewares.insert(0, Arc::new(M::default()));
200 self
201 }
202
203 pub fn remove_middleware<M>(mut self) -> Self
204 where
205 M: 'static,
206 {
207 self.middlewares.retain(|m| {
208 m.as_any().type_id() != TypeId::of::<M>()
209 });
210 self
211 }
212
213 pub fn build(self) -> Arc<dyn ProtocolHandlerTrait> {
214 Arc::new(ProtocolHandler::new(self.url, self.middlewares))
215 }
216}
217
/// Builder that collects protocol handlers and turns them into a
/// `ProtocolRegistryKind`.
pub struct ProtocolRegistryBuilder {
    /// Handlers added so far, in the order `protocol` was called.
    handlers: Vec<Arc<dyn ProtocolHandlerTrait>>,
}
221
222impl ProtocolRegistryBuilder {
223 pub fn new() -> Self {
224 Self { handlers: Vec::new() }
225 }
226
227 pub fn protocol<R: Rx>(mut self, builder: ProtocolHandlerBuilder<R>) -> Self {
228 self.handlers.push(builder.build());
229 self
230 }
231
232 pub fn build(self) -> ProtocolRegistryKind {
233 match self.handlers.len() {
234 1 => ProtocolRegistryKind::Single(self.handlers.into_iter().next().unwrap()) ,
236 _ => ProtocolRegistryKind::Multi(ProtocolRegistry{handlers: self.handlers}),
237 }
238 }
239}
240
241impl ProtocolRegistryKind {
242 pub fn single<R: Rx + 'static>(root_handler: Arc<Url<R>>, middlewares: AsyncMiddlewareChain<R>) -> Self {
246 ProtocolRegistryKind::Single(Arc::new(ProtocolHandler::new(root_handler, middlewares)))
247 }
248
249 pub fn multi(registry: ProtocolRegistry) -> Self {
251 ProtocolRegistryKind::Multi(registry)
252 }
253
254 pub async fn run(&self, app: Arc<App>, conn: Connection) {
259 match self {
260 ProtocolRegistryKind::Single(handler) => {
261 let (read_half, write_half) = conn.split();
262 let reader = BufReader::new(read_half);
263 let writer = BufWriter::new(write_half);
264 handler.handle(app, reader, writer).await;
265 }
266 ProtocolRegistryKind::Multi(registry) => {
267 registry.run_multi(app, conn).await;
269 }
270 }
271 }
272
273 pub fn url<R: Rx + 'static>(&self) -> Option<Arc<Url<R>>> {
276 match self {
277 ProtocolRegistryKind::Single(handler) => {
278 handler
279 .as_any()
280 .downcast_ref::<ProtocolHandler<R>>()
281 .map(|ph| ph.root_handler.clone())
282 }
283 ProtocolRegistryKind::Multi(registry) => {
284 for handler in ®istry.handlers {
285 if let Some(ph) = handler.as_any().downcast_ref::<ProtocolHandler<R>>() {
286 return Some(ph.root_handler.clone());
287 }
288 }
289 None
290 }
291 }
292 }
293
294 pub fn middlewares<R: Rx + 'static>(&self) -> Option<AsyncMiddlewareChain<R>> {
297 match self {
298 ProtocolRegistryKind::Single(handler) => {
299 handler
300 .as_any()
301 .downcast_ref::<ProtocolHandler<R>>()
302 .map(|ph| ph.middlewares.clone())
303 }
304 ProtocolRegistryKind::Multi(registry) => {
305 for handler in ®istry.handlers {
306 if let Some(ph) = handler.as_any().downcast_ref::<ProtocolHandler<R>>() {
307 return Some(ph.middlewares.clone());
308 }
309 }
310 None
311 }
312 }
313 }
314
315 pub fn lit_url<R: Rx + 'static, T: Into<String>>(
319 &self,
320 url: T,
321 ) -> Result<Arc<Url<R>>, String> {
322 let url = url.into();
323 println!("Adding url: {}", url);
324 match self.url::<R>()
325 .map(|root| {
326 root.clone()
327 .literal_url(
328 &url,
329 None,
330 self.middlewares::<R>().unwrap_or(vec![]),
331 ParamsClone::default()
332 )
333 })
334 {
335 Some(Ok(url)) => Ok(url),
336 Some(Err(e)) => Err(e),
337 None => Err("Protocol Not Found".to_string()),
338 }
339 }
340
341 pub fn reg_from<R: Rx + 'static>(
342 &self,
343 segments: &[PathPattern]
344 ) -> Result<Arc<Url<R>>, String> {
345 match self.url::<R>()
346 .map(|root| {
347 let mut current = root.clone();
348 for seg in segments {
349 current = current.get_child_or_create(seg.clone())?;
350 current.set_middlewares(self.middlewares::<R>().unwrap_or(vec![]));
351 }
352 Ok::<Arc<Url<R>>, String>(current)
353 }) {
354 Some(Ok(url)) => Ok(url),
355 Some(Err(e)) => Err(e),
356 None => Err("Protocol Not Found".to_string())
357
358 }
359 }
365}