whatsapp_rust/handlers/
router.rs

1use super::traits::StanzaHandler;
2use crate::client::Client;
3use std::collections::HashMap;
4use std::sync::Arc;
5use wacore_binary::node::Node;
6
7/// Central router for dispatching XML stanzas to their appropriate handlers.
8///
9/// The router maintains a registry of handlers keyed by XML tag and efficiently
10/// dispatches incoming nodes to the correct handler based on the node's tag.
11pub struct StanzaRouter {
12    /// Map of XML tag -> handler for fast lookups
13    handlers: HashMap<&'static str, Arc<dyn StanzaHandler>>,
14}
15
16impl StanzaRouter {
17    /// Create a new empty router.
18    pub fn new() -> Self {
19        Self {
20            handlers: HashMap::new(),
21        }
22    }
23
24    /// Register a handler for a specific XML tag.
25    ///
26    /// # Arguments
27    /// * `handler` - The handler implementation to register
28    ///
29    /// # Panics
30    /// Panics if a handler is already registered for the same tag to prevent
31    /// accidental overwrites during initialization.
32    pub fn register(&mut self, handler: Arc<dyn StanzaHandler>) {
33        let tag = handler.tag();
34        if self.handlers.insert(tag, handler).is_some() {
35            panic!("Handler for tag '{}' already registered", tag);
36        }
37    }
38
39    /// Dispatch a node to its appropriate handler.
40    ///
41    /// # Arguments
42    /// * `client` - Arc reference to the client instance
43    /// * `node` - The XML node to dispatch
44    ///
45    /// # Returns
46    /// Returns `true` if a handler was found and successfully processed the node,
47    /// `false` if no handler was registered for the node's tag or the handler
48    /// indicated it couldn't process the node.
49    pub async fn dispatch(&self, client: Arc<Client>, node: &Node, cancelled: &mut bool) -> bool {
50        if let Some(handler) = self.handlers.get(node.tag.as_str()) {
51            handler.handle(client, node, cancelled).await
52        } else {
53            false
54        }
55    }
56
57    /// Get the number of registered handlers (useful for testing).
58    pub fn handler_count(&self) -> usize {
59        self.handlers.len()
60    }
61}
62
63impl Default for StanzaRouter {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::persistence_manager::PersistenceManager;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use wacore_binary::node::Node;

    /// Handler double that remembers whether `handle` was ever invoked.
    #[derive(Debug)]
    struct MockHandler {
        tag: &'static str,
        handled: AtomicBool,
    }

    impl MockHandler {
        /// Creates a mock answering for `tag` that has not handled anything yet.
        fn new(tag: &'static str) -> Self {
            Self {
                tag,
                handled: AtomicBool::new(false),
            }
        }

        /// Reports whether `handle` has run at least once.
        fn was_handled(&self) -> bool {
            self.handled.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl StanzaHandler for MockHandler {
        fn tag(&self) -> &'static str {
            self.tag
        }

        /// Marks the mock as handled and always reports success.
        async fn handle(
            &self,
            _client: Arc<crate::client::Client>,
            _node: &Node,
            _cancelled: &mut bool,
        ) -> bool {
            self.handled.store(true, Ordering::SeqCst);
            true
        }
    }

    /// Builds a node that carries only `tag`: no attributes and no content.
    fn bare_node(tag: &str) -> Node {
        Node {
            tag: tag.to_string(),
            attrs: std::collections::HashMap::new(),
            content: None,
        }
    }

    /// Creates a persistence manager on top of an in-memory SQLite store, the
    /// minimum needed to construct a client for the dispatch tests.
    async fn in_memory_persistence() -> Arc<PersistenceManager> {
        let store = crate::store::sqlite_store::SqliteStore::new(":memory:")
            .await
            .unwrap();
        let backend: Arc<dyn crate::store::traits::Backend> = Arc::new(store);
        Arc::new(PersistenceManager::new(backend).await.unwrap())
    }

    #[test]
    fn test_router_registration() {
        let mut router = StanzaRouter::new();

        router.register(Arc::new(MockHandler::new("test")));

        assert_eq!(router.handler_count(), 1);
    }

    #[test]
    #[should_panic(expected = "Handler for tag 'test' already registered")]
    fn test_router_double_registration_panics() {
        let mut router = StanzaRouter::new();

        router.register(Arc::new(MockHandler::new("test")));
        // Second registration for the same tag must panic.
        router.register(Arc::new(MockHandler::new("test")));
    }

    #[tokio::test]
    async fn test_router_dispatch_found() {
        let mut router = StanzaRouter::new();
        // Keep our own reference so the mock can be inspected after dispatch.
        let probe = Arc::new(MockHandler::new("test"));
        router.register(probe.clone());

        let node = bare_node("test");
        let (client, _rx) = crate::client::Client::new(in_memory_persistence().await).await;

        let mut cancelled = false;
        let dispatched = router.dispatch(client, &node, &mut cancelled).await;

        assert!(dispatched);
        assert!(probe.was_handled());
    }

    #[tokio::test]
    async fn test_router_dispatch_not_found() {
        let router = StanzaRouter::new();

        let node = bare_node("unknown");
        let (client, _rx) = crate::client::Client::new(in_memory_persistence().await).await;

        let mut cancelled = false;
        let dispatched = router.dispatch(client, &node, &mut cancelled).await;

        assert!(!dispatched);
    }
}