//! WebSocket channel definitions and client/server message formats (`shaperail_core/channel.rs`).

1use serde::{Deserialize, Serialize};
2
3use crate::AuthRule;
4
5/// Definition of a WebSocket channel, parsed from a `.channel.yaml` file.
6///
7/// ```yaml
8/// channel: notifications
9/// auth: [member, admin]
10/// rooms: true
11/// hooks:
12///   on_connect: [log_connect]
13///   on_disconnect: [log_disconnect]
14///   on_message: [validate_message]
15/// ```
16#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
17#[serde(deny_unknown_fields)]
18pub struct ChannelDefinition {
19    /// Channel name (e.g., "notifications").
20    pub channel: String,
21
22    /// Authentication rule for connecting to this channel.
23    #[serde(default, skip_serializing_if = "Option::is_none")]
24    pub auth: Option<AuthRule>,
25
26    /// Whether this channel supports room subscriptions.
27    #[serde(default)]
28    pub rooms: bool,
29
30    /// Lifecycle hook configuration.
31    #[serde(default, skip_serializing_if = "Option::is_none")]
32    pub hooks: Option<ChannelHooks>,
33}
34
35/// Lifecycle hooks for a WebSocket channel.
36#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
37#[serde(deny_unknown_fields)]
38pub struct ChannelHooks {
39    /// Hooks executed when a client connects.
40    #[serde(default, skip_serializing_if = "Option::is_none")]
41    pub on_connect: Option<Vec<String>>,
42
43    /// Hooks executed when a client disconnects.
44    #[serde(default, skip_serializing_if = "Option::is_none")]
45    pub on_disconnect: Option<Vec<String>>,
46
47    /// Hooks executed when a client sends a message.
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub on_message: Option<Vec<String>>,
50}
51
52/// Client-to-server WebSocket message format.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54#[serde(tag = "action", rename_all = "lowercase")]
55pub enum WsClientMessage {
56    /// Subscribe to a room within the channel.
57    Subscribe { room: String },
58    /// Unsubscribe from a room.
59    Unsubscribe { room: String },
60    /// Send a message to a room.
61    Message {
62        room: String,
63        data: serde_json::Value,
64    },
65    /// Respond to server ping (client pong).
66    Pong,
67}
68
69/// Server-to-client WebSocket message format.
70#[derive(Debug, Clone, Serialize, Deserialize)]
71#[serde(tag = "type", rename_all = "lowercase")]
72pub enum WsServerMessage {
73    /// Broadcast data to subscribed clients.
74    Broadcast {
75        room: String,
76        event: String,
77        data: serde_json::Value,
78    },
79    /// Acknowledgement of subscription.
80    Subscribed { room: String },
81    /// Acknowledgement of unsubscription.
82    Unsubscribed { room: String },
83    /// Error message.
84    Error { message: String },
85    /// Server ping for heartbeat.
86    Ping,
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn channel_definition_minimal() {
        // The literal is JSON (parsed with serde_json), so name it accordingly.
        let json = r#"{"channel": "notifications"}"#;
        let cd: ChannelDefinition = serde_json::from_str(json).unwrap();
        assert_eq!(cd.channel, "notifications");
        assert!(cd.auth.is_none());
        assert!(!cd.rooms);
        assert!(cd.hooks.is_none());
    }

    #[test]
    fn channel_definition_full() {
        let json = r#"{
            "channel": "updates",
            "auth": ["member", "admin"],
            "rooms": true,
            "hooks": {
                "on_connect": ["log_connect"],
                "on_disconnect": ["log_disconnect"],
                "on_message": ["validate_message"]
            }
        }"#;
        let cd: ChannelDefinition = serde_json::from_str(json).unwrap();
        assert_eq!(cd.channel, "updates");
        // The role list must survive parsing; previously `auth` was never checked.
        assert_eq!(
            cd.auth,
            Some(AuthRule::Roles(vec![
                "member".to_string(),
                "admin".to_string()
            ]))
        );
        assert!(cd.rooms);
        let hooks = cd.hooks.as_ref().unwrap();
        assert_eq!(hooks.on_connect.as_ref().unwrap(), &["log_connect"]);
        assert_eq!(hooks.on_disconnect.as_ref().unwrap(), &["log_disconnect"]);
        assert_eq!(hooks.on_message.as_ref().unwrap(), &["validate_message"]);
    }

    #[test]
    fn channel_definition_rejects_unknown_fields() {
        // `deny_unknown_fields` should catch typos such as `room` for `rooms`...
        let top_level = r#"{"channel": "chat", "room": true}"#;
        assert!(serde_json::from_str::<ChannelDefinition>(top_level).is_err());

        // ...and unknown hook stages nested under `hooks`.
        let in_hooks = r#"{"channel": "chat", "hooks": {"on_join": ["greet"]}}"#;
        assert!(serde_json::from_str::<ChannelDefinition>(in_hooks).is_err());
    }

    #[test]
    fn channel_definition_serde_roundtrip() {
        let cd = ChannelDefinition {
            channel: "chat".to_string(),
            auth: Some(AuthRule::Roles(vec!["member".to_string()])),
            rooms: true,
            hooks: Some(ChannelHooks {
                on_connect: Some(vec!["log_connect".to_string()]),
                on_disconnect: None,
                on_message: None,
            }),
        };
        let json = serde_json::to_string(&cd).unwrap();
        let back: ChannelDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(cd, back);
    }

    #[test]
    fn ws_client_message_subscribe() {
        let json = r#"{"action": "subscribe", "room": "org:123"}"#;
        let msg: WsClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            WsClientMessage::Subscribe { room } => assert_eq!(room, "org:123"),
            other => panic!("expected Subscribe, got {:?}", other),
        }
    }

    #[test]
    fn ws_client_message_unsubscribe() {
        let json = r#"{"action": "unsubscribe", "room": "org:123"}"#;
        let msg: WsClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            WsClientMessage::Unsubscribe { room } => assert_eq!(room, "org:123"),
            other => panic!("expected Unsubscribe, got {:?}", other),
        }
    }

    #[test]
    fn ws_client_message_message() {
        let json = r#"{"action": "message", "room": "org:123", "data": {"text": "hello"}}"#;
        let msg: WsClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            WsClientMessage::Message { room, data } => {
                assert_eq!(room, "org:123");
                assert_eq!(data, json!({"text": "hello"}));
            }
            other => panic!("expected Message, got {:?}", other),
        }
    }

    #[test]
    fn ws_client_message_pong() {
        let json = r#"{"action": "pong"}"#;
        let msg: WsClientMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, WsClientMessage::Pong));
    }

    #[test]
    fn ws_client_message_rejects_unknown_action() {
        let json = r#"{"action": "shout", "room": "org:123"}"#;
        assert!(serde_json::from_str::<WsClientMessage>(json).is_err());
    }

    #[test]
    fn ws_client_message_requires_room() {
        let json = r#"{"action": "subscribe"}"#;
        assert!(serde_json::from_str::<WsClientMessage>(json).is_err());
    }

    // The server-message tests compare whole JSON values rather than substrings,
    // so a wrong tag, missing field, or extra field all cause a failure.

    #[test]
    fn ws_server_message_broadcast() {
        let msg = WsServerMessage::Broadcast {
            room: "org:123".to_string(),
            event: "user.created".to_string(),
            data: json!({"id": "abc"}),
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            json!({
                "type": "broadcast",
                "room": "org:123",
                "event": "user.created",
                "data": {"id": "abc"}
            })
        );
    }

    #[test]
    fn ws_server_message_subscribed() {
        let msg = WsServerMessage::Subscribed {
            room: "org:123".to_string(),
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(value, json!({"type": "subscribed", "room": "org:123"}));
    }

    #[test]
    fn ws_server_message_unsubscribed() {
        let msg = WsServerMessage::Unsubscribed {
            room: "org:123".to_string(),
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(value, json!({"type": "unsubscribed", "room": "org:123"}));
    }

    #[test]
    fn ws_server_message_error() {
        let msg = WsServerMessage::Error {
            message: "bad request".to_string(),
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(value, json!({"type": "error", "message": "bad request"}));
    }

    #[test]
    fn ws_server_message_ping() {
        let msg = WsServerMessage::Ping;
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(value, json!({"type": "ping"}));

        // Unit variants must also deserialize from the bare tag object.
        let back: WsServerMessage = serde_json::from_value(value).unwrap();
        assert!(matches!(back, WsServerMessage::Ping));
    }
}