starberry_core/app/
protocol.rs

1use std::{
2    any::{Any, TypeId}, default, future::Future, pin::Pin, sync::Arc
3};
4use tokio::io::{
5    AsyncBufReadExt,
6    AsyncWriteExt,
7    BufReader,
8    BufWriter,
9    ReadHalf,
10    WriteHalf,
11};
12use crate::{app::{middleware::{AsyncMiddleware, AsyncMiddlewareChain}, urls::{PathPattern, Url}}, connection::{Connection, Rx}, extensions::ParamsClone};
13use super::application::App; 
14
/// Signature of a protocol-detection predicate: given a byte prefix, report
/// whether it matches a protocol. Mirrors the shape of
/// `ProtocolHandlerTrait::test`; not referenced elsewhere in the visible file.
type TestFn = fn(buf: &[u8]) -> bool;
16
17type HandlerFn<R: Rx> =
18    fn(Arc<App>, Arc<Url<R>>, BufReader<ReadHalf<Connection>>, BufWriter<WriteHalf<Connection>>)
19        -> Pin<Box<dyn Future<Output = ()> + Send>>;
20
21/// Internal struct tying a single protocol's detection function (`test`)
22/// to its processing function (`handle`).
23/// Concrete handler for a specific protocol
24struct ProtocolHandler<R: Rx> {
25    root_handler: Arc<Url<R>>, 
26    middlewares: AsyncMiddlewareChain<R>, 
27} 
28
29impl<R: Rx> ProtocolHandler<R> { 
30    pub fn new(
31        root_handler: Arc<Url<R>>,
32        middlewares: AsyncMiddlewareChain<R>,
33    ) -> Self {
34        Self { 
35            root_handler,
36            middlewares,    
37        }
38    }
39}
40
41pub trait ProtocolHandlerTrait: Send + Sync {
42    /// A function pointer to inspect the first bytes of a connection
43    /// and decide whether a protocol should handle it.
44    /// Returns `true` if the given buffer matches the protocol signature.
45    fn test(&self, buf: &[u8]) -> bool; 
46
47    /// A function pointer that, given the `App` and split I/O halves wrapped
48    /// in buffered reader/writer, returns a boxed `Future` that drives the
49    /// protocol handler to completion.
50    fn handle(
51        &self,
52        app: Arc<App>,
53        reader: BufReader<ReadHalf<Connection>>,
54        writer: BufWriter<WriteHalf<Connection>>,
55    ) -> Pin<Box<dyn Future<Output = ()> + Send>>; 
56
57    /// Allows downcasting to the concrete `ProtocolHandler<R>` type.
58    fn as_any(&self) -> &dyn Any; 
59
60    /// Like `as_any`, but for mutable downcasting.
61    fn as_any_mut(&mut self) -> &mut dyn Any;
62} 
63
64impl<R: Rx + 'static> ProtocolHandlerTrait for ProtocolHandler<R> {
65    fn test(&self, buf: &[u8]) -> bool {
66        R::test_protocol(buf)
67    }
68
69    fn handle(
70        &self,
71        app: Arc<App>,
72        reader: BufReader<ReadHalf<Connection>>,
73        writer: BufWriter<WriteHalf<Connection>>,
74    ) -> Pin<Box<dyn Future<Output = ()> + Send>> {
75        let root_handler = self.root_handler.clone();
76        Box::pin(async move {
77            R::process(app, root_handler, reader, writer).await;
78        })
79    } 
80
81    fn as_any(&self) -> &dyn Any {
82        self
83    } 
84
85    fn as_any_mut(&mut self) -> &mut dyn Any {
86        self
87    } 
88} 
89
90/// Registry for multiple protocol handlers
91/// using a simple `Vec<ProtocolHandler>` for O(n) dispatch.
92/// This avoids hash lookups and TypeId overhead, trading for a small
93/// linear scan over handlers in registration order.
94pub struct ProtocolRegistry {
95    /// Ordered list of protocol handlers (test + handle).
96    handlers: Vec<Arc<dyn ProtocolHandlerTrait>>,
97}
98
99impl ProtocolRegistry {
100    /// Construct an empty registry with no protocols registered.
101    pub fn new() -> Self {
102        Self {
103            handlers: Vec::new(),
104        }
105    }
106
107    /// Register a protocol `P` that implements `Rx + 'static`.
108    /// This pushes its `test_protocol` and `process` functions
109    /// onto the `handlers` vector, preserving registration order.
110    pub fn register<R: Rx + 'static>(&mut self, root_handler: Arc<Url<R>>, middleware_chain: AsyncMiddlewareChain<R>) {
111        self.handlers.push(Arc::new(ProtocolHandler::new(root_handler, middleware_chain)));
112    } 
113
114    /// Attempt to detect and run one of the registered protocols.
115    ///
116    /// Steps:
117    /// 1. Split the `Connection` into read/write halves.
118    /// 2. Peek at the initial bytes without consuming them.
119    /// 3. Iterate in registration order and run the first matching protocol.
120    /// 4. If no match is found, cleanly shutdown the write half.
121    pub async fn run_multi(&self, app: Arc<App>, conn: Connection) {
122        // 1) split into raw halves
123        let (read_half, write_half) = conn.split();
124        let mut reader = BufReader::new(read_half);
125        let mut writer = BufWriter::new(write_half);
126
127        // 2) peek at buffered data without consuming
128        let buf = reader.fill_buf().await.unwrap_or(&[]);
129        let n = buf.len();
130
131        // 3) test each registered protocol in order
132        for handler in &self.handlers {
133            if handler.test(&buf[..n]) {
134                // 4) if test passes, dispatch to this protocol's handler
135                handler.handle(app.clone(), reader, writer).await;
136                return;
137            }
138        }
139
140        // 5) no protocol matched → close the connection gracefully
141        let _ = writer.shutdown().await;
142    }
143}
144
145/// Enum used in `App` to select between single‐protocol mode
146/// (direct dispatch to one protocol P) and multi‐protocol mode
147/// (detection loop over a `ProtocolRegistry`).
148pub enum ProtocolRegistryKind {
149    /// Single‐protocol mode. Stores only the handler function for zero‐overhead dispatch.
150    Single(Arc<dyn ProtocolHandlerTrait>), 
151    /// Multi‐protocol mode. Contains a full `ProtocolRegistry`.
152    Multi(ProtocolRegistry),
153} 
154
155
156pub struct ProtocolHandlerBuilder<R: Rx + 'static> {
157    url: Arc<Url<R>>,
158    middlewares: Vec<Arc<dyn AsyncMiddleware<R>>>,
159}
160
161impl<R: Rx> ProtocolHandlerBuilder<R> {
162    pub fn new() -> Self {
163        Self {
164            url: Arc::new(Url::default()),
165            middlewares: Vec::new(), 
166        }
167    }
168
169    pub fn with_default_middlewares(mut self) -> Self {
170        self.middlewares = Self::default_middlewares();
171        self
172    }
173
174    pub fn default_middlewares() -> Vec<Arc<dyn AsyncMiddleware<R>>> {
175        vec![
176            // Add your default middleware implementations here
177        ]
178    } 
179
180    pub fn set_url(mut self, url: Arc<Url<R>>) -> Self { 
181        self.url = url; 
182        self 
183    }
184
185    // Append a middleware instance created by T to the end of the vector.
186    pub fn append_middleware<M>(mut self) -> Self
187    where
188        M: AsyncMiddleware<R> + Default + 'static,
189    {
190        self.middlewares.push(Arc::new(M::default()));
191        self
192    }
193
194    // Insert a middleware instance created by T at the beginning of the vector.
195    pub fn prepend_middleware<M>(mut self) -> Self
196    where
197        M: AsyncMiddleware<R> + Default + 'static,
198    {
199        self.middlewares.insert(0, Arc::new(M::default()));
200        self
201    }
202
203    pub fn remove_middleware<M>(mut self) -> Self
204    where
205        M: 'static,
206    {
207        self.middlewares.retain(|m| {
208            m.as_any().type_id() != TypeId::of::<M>()
209        });
210        self
211    }
212
213    pub fn build(self) -> Arc<dyn ProtocolHandlerTrait> {
214        Arc::new(ProtocolHandler::new(self.url, self.middlewares))
215    }
216}
217
218pub struct ProtocolRegistryBuilder {
219    handlers: Vec<Arc<dyn ProtocolHandlerTrait>>,
220}
221
222impl ProtocolRegistryBuilder {
223    pub fn new() -> Self {
224        Self { handlers: Vec::new() }
225    }
226
227    pub fn protocol<R: Rx>(mut self, builder: ProtocolHandlerBuilder<R>) -> Self {
228        self.handlers.push(builder.build());
229        self
230    }
231
232    pub fn build(self) -> ProtocolRegistryKind {
233        match self.handlers.len() {
234            // 0 => ProtocolRegistryKind::empty(), 
235            1 => ProtocolRegistryKind::Single(self.handlers.into_iter().next().unwrap()) ,
236            _ => ProtocolRegistryKind::Multi(ProtocolRegistry{handlers: self.handlers}),
237        }
238    }
239} 
240
241impl ProtocolRegistryKind {
242    /// Construct a `Single` variant for protocol `P`, avoiding any
243    /// loops or lookups. This is the fastest path when you know at
244    /// compile time which protocol to run.
245    pub fn single<R: Rx + 'static>(root_handler: Arc<Url<R>>, middlewares: AsyncMiddlewareChain<R>) -> Self {
246        ProtocolRegistryKind::Single(Arc::new(ProtocolHandler::new(root_handler, middlewares)))
247    } 
248
249    /// Construct a `Multi` variant from an existing registry.
250    pub fn multi(registry: ProtocolRegistry) -> Self {
251        ProtocolRegistryKind::Multi(registry)
252    } 
253
254    /// Entry point: dispatch the connection according to the selected mode.
255    ///
256    /// - `Single` mode directly invokes the stored `handler`.
257    /// - `Multi` mode calls `run_multi` on the inner registry.
258    pub async fn run(&self, app: Arc<App>, conn: Connection) {
259        match self {
260            ProtocolRegistryKind::Single(handler) => {
261                let (read_half, write_half) = conn.split();
262                let reader = BufReader::new(read_half);
263                let writer = BufWriter::new(write_half);
264                handler.handle(app, reader, writer).await;
265            } 
266            ProtocolRegistryKind::Multi(registry) => {
267                // Use detection logic for multiple protocols.
268                registry.run_multi(app, conn).await;
269            }
270        }
271    } 
272
273    /// Retrieve the root Url<R> for a given protocol type `R`.
274    /// Returns `Some(Arc<Url<R>>)` if a handler of type `R` is present.
275    pub fn url<R: Rx + 'static>(&self) -> Option<Arc<Url<R>>> {
276        match self {
277            ProtocolRegistryKind::Single(handler) => {
278                handler
279                    .as_any()
280                    .downcast_ref::<ProtocolHandler<R>>()
281                    .map(|ph| ph.root_handler.clone())
282            }
283            ProtocolRegistryKind::Multi(registry) => {
284                for handler in &registry.handlers {
285                    if let Some(ph) = handler.as_any().downcast_ref::<ProtocolHandler<R>>() {
286                        return Some(ph.root_handler.clone());
287                    }
288                }
289                None
290            }
291        }
292    } 
293
294    /// Retrieve the Middleware<R> for a given protocol type `R`.
295    /// Returns `Some(AsymcMiddlewareChain<R>)` if a handler of type `R` is present.
296    pub fn middlewares<R: Rx + 'static>(&self) -> Option<AsyncMiddlewareChain<R>> {
297        match self {
298            ProtocolRegistryKind::Single(handler) => {
299                handler
300                    .as_any()
301                    .downcast_ref::<ProtocolHandler<R>>()
302                    .map(|ph| ph.middlewares.clone())
303            }
304            ProtocolRegistryKind::Multi(registry) => {
305                for handler in &registry.handlers {
306                    if let Some(ph) = handler.as_any().downcast_ref::<ProtocolHandler<R>>() {
307                        return Some(ph.middlewares.clone());
308                    }
309                }
310                None
311            }
312        }
313    } 
314
315    /// This function add a new url to the app. It will be added to the root url 
316    /// # Arguments 
317    /// * `url` - The url to add. It should be a string.
318    pub fn lit_url<R: Rx + 'static, T: Into<String>>(
319        &self, 
320        url: T, 
321    ) -> Result<Arc<Url<R>>, String> { 
322        let url = url.into(); 
323        println!("Adding url: {}", url); 
324        match self.url::<R>() 
325            .map(|root| {  
326                root.clone()
327                .literal_url(
328                    &url, 
329                    None, 
330                    self.middlewares::<R>().unwrap_or(vec![]), 
331                    ParamsClone::default()
332                )
333            }) 
334        {
335            Some(Ok(url)) => Ok(url),
336            Some(Err(e)) => Err(e),
337            None => Err("Protocol Not Found".to_string()), 
338        }
339    } 
340
341    pub fn reg_from<R: Rx + 'static>(
342        &self,
343        segments: &[PathPattern]
344    ) -> Result<Arc<Url<R>>, String> { 
345        match self.url::<R>()
346            .map(|root| { 
347                let mut current = root.clone(); 
348                for seg in segments { 
349                    current = current.get_child_or_create(seg.clone())?; 
350                    current.set_middlewares(self.middlewares::<R>().unwrap_or(vec![])); 
351                } 
352                Ok::<Arc<Url<R>>, String>(current) 
353            }) {  
354                Some(Ok(url)) => Ok(url), 
355                Some(Err(e)) => Err(e), 
356                None => Err("Protocol Not Found".to_string()) 
357
358        }
359        // for seg in segments { 
360        //     current = current.get_child_or_create(seg.clone())?; 
361        //     current.set_middlewares((*self.middlewares).clone()); 
362        // }
363        // Ok(current)
364    }
365}