Skip to main content

topiq_core/
subscription.rs

1use std::fmt;
2
3use tokio::sync::mpsc;
4
5use crate::message::Message;
6
/// Identifier of a single subscription, unique within its owning session.
///
/// Its `Display` form is `sub:<n>`; for example `SubscriptionId(7)` renders as `sub:7`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SubscriptionId(pub u64);

impl fmt::Display for SubscriptionId {
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Going through `write!` formats the number with default settings, so any
        // width/fill flags on the outer format spec are not applied.
        let Self(raw) = *self;
        write!(out, "sub:{}", raw)
    }
}
16
17/// A tagged message destined for a specific subscription.
18///
19/// When multiple subscriptions share a session channel, this tag
20/// lets the session map messages back to the correct subscription id.
21pub type TaggedMessage = (SubscriptionId, Message);
22
23/// Sender half that tags messages with a subscription id before sending.
24#[derive(Debug, Clone)]
25pub struct SubscriptionSender {
26    sid: SubscriptionId,
27    tx: mpsc::Sender<TaggedMessage>,
28}
29
30impl SubscriptionSender {
31    pub fn new(sid: SubscriptionId, tx: mpsc::Sender<TaggedMessage>) -> Self {
32        Self { sid, tx }
33    }
34
35    pub fn sid(&self) -> SubscriptionId {
36        self.sid
37    }
38
39    /// Try to send a message without blocking. Returns false if the channel is full.
40    pub fn try_send(&self, msg: Message) -> bool {
41        self.tx.try_send((self.sid, msg)).is_ok()
42    }
43
44    pub async fn send(&self, msg: Message) -> bool {
45        self.tx.send((self.sid, msg)).await.is_ok()
46    }
47
48    /// Check if the receiving end has been dropped.
49    pub fn is_closed(&self) -> bool {
50        self.tx.is_closed()
51    }
52}
53
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::topic::Subject;

    /// Builds a message on the `test` subject carrying `payload`.
    fn make_message(payload: &'static str) -> Message {
        Message::new(Subject::new("test").unwrap(), Bytes::from(payload))
    }

    #[tokio::test]
    async fn subscription_sender_tags_messages() {
        let (tx, mut rx) = mpsc::channel(8);
        let sender = SubscriptionSender::new(SubscriptionId(42), tx);

        assert!(sender.try_send(make_message("data")));

        let (sid, received) = rx.recv().await.unwrap();
        assert_eq!(sid, SubscriptionId(42));
        assert_eq!(received.payload, Bytes::from("data"));
    }

    #[tokio::test]
    async fn async_send_tags_messages() {
        let (tx, mut rx) = mpsc::channel(8);
        let sender = SubscriptionSender::new(SubscriptionId(3), tx);

        assert!(sender.send(make_message("hello")).await);

        let (sid, received) = rx.recv().await.unwrap();
        assert_eq!(sid, SubscriptionId(3));
        assert_eq!(received.payload, Bytes::from("hello"));
    }

    #[tokio::test]
    async fn try_send_fails_when_channel_full() {
        // Keep the receiver alive so the failure is due to capacity, not closure.
        let (tx, _rx) = mpsc::channel::<TaggedMessage>(1);
        let sender = SubscriptionSender::new(SubscriptionId(5), tx);

        assert!(sender.try_send(make_message("first")));
        // Capacity is 1 and nothing has been received yet.
        assert!(!sender.try_send(make_message("second")));
        assert!(!sender.is_closed());
    }

    #[tokio::test]
    async fn sender_detects_closed_channel() {
        let (tx, rx) = mpsc::channel::<TaggedMessage>(1);
        let sender = SubscriptionSender::new(SubscriptionId(1), tx);
        drop(rx);
        assert!(sender.is_closed());
    }

    #[tokio::test]
    async fn sends_fail_after_receiver_dropped() {
        let (tx, rx) = mpsc::channel::<TaggedMessage>(1);
        let sender = SubscriptionSender::new(SubscriptionId(9), tx);
        drop(rx);

        assert!(!sender.try_send(make_message("late")));
        assert!(!sender.send(make_message("late")).await);
    }

    #[test]
    fn subscription_id_display() {
        assert_eq!(SubscriptionId(7).to_string(), "sub:7");
    }
}