// tagged_channels/lib.rs

1#![warn(clippy::all, missing_docs, nonstandard_style, future_incompatible)]
2#![doc = include_str!("../README.md")]
3
4use std::{
5    collections::{HashMap, HashSet},
6    hash::Hash,
7    sync::Arc,
8};
9
10use parking_lot::Mutex;
11use tokio::sync::mpsc::{self, Receiver, Sender};
12
/// Identifier assigned to each registered channel.
type ChannelId = u64;
14
15/// Channels manager
16pub struct TaggedChannels<M, T>(Arc<Mutex<ChannelsInner<M, T>>>);
17
18impl<M, T> Clone for TaggedChannels<M, T> {
19    fn clone(&self) -> Self {
20        Self(Arc::clone(&self.0))
21    }
22}
23
24/// Inner part of the manager
25pub struct ChannelsInner<M, T> {
26    last_id: u64,
27    channels: HashMap<ChannelId, Channel<M, T>>,
28    tags: HashMap<T, HashSet<ChannelId>>,
29}
30
31struct Channel<M, T> {
32    tx: Sender<Arc<M>>,
33    tags: Box<[T]>,
34}
35
36/// A guard to clean up channel resources when the receiver is dropped
37pub struct ChannelGuard<M, T>
38where
39    T: Clone + Eq + Hash + PartialEq,
40{
41    channel_id: ChannelId,
42    manager: TaggedChannels<M, T>,
43}
44
45/// A wrapper around [`Receiver`] to clean up resources on `Drop`
46pub struct GuardedReceiver<M, T>
47where
48    T: Clone + Eq + Hash + PartialEq,
49{
50    rx: Receiver<Arc<M>>,
51    #[allow(dead_code)]
52    guard: ChannelGuard<M, T>,
53}
54
55impl<M, T> TaggedChannels<M, T>
56where
57    T: Clone + Eq + Hash + PartialEq,
58{
59    /// Creates a new channels manager
60    pub fn new() -> Self {
61        Default::default()
62    }
63
64    /// Creates a new channel and returns its events receiver
65    pub fn create_channel(&self, tags: impl Into<Vec<T>>) -> GuardedReceiver<M, T> {
66        let tags = tags.into();
67        let (tx, rx) = mpsc::channel::<Arc<M>>(1);
68        let channel = Channel {
69            tx,
70            tags: tags.clone().into_boxed_slice(),
71        };
72
73        let mut inner = self.0.lock();
74        let channel_id = inner.last_id.overflowing_add(1).0;
75        inner.channels.insert(channel_id, channel);
76        for tag in tags {
77            inner
78                .tags
79                .entry(tag)
80                .and_modify(|set| {
81                    set.insert(channel_id);
82                })
83                .or_insert(HashSet::from([channel_id]));
84        }
85        inner.last_id = channel_id;
86
87        let guard = ChannelGuard::new(channel_id, self.clone());
88        GuardedReceiver { rx, guard }
89    }
90
91    /// Returns number of active channels
92    pub fn num_connections(&self) -> usize {
93        self.0.lock().channels.len()
94    }
95
96    /// Sends the `message` to all channels tagged by the `tag`
97    pub async fn send_by_tag(&self, tag: &T, message: M) {
98        let msg = Arc::new(message);
99        for rx in self.tagged_senders(tag) {
100            rx.send(Arc::clone(&msg)).await.ok();
101        }
102    }
103
104    /// Send the `message` to everyone
105    pub async fn broadcast(&self, message: M) {
106        let msg = Arc::new(message);
107        for rx in self.all_senders() {
108            rx.send(Arc::clone(&msg)).await.ok();
109        }
110    }
111
112    /// Returns tags of all currently connected channels
113    pub fn connected_tags(&self) -> Vec<T> {
114        self.0.lock().tags.keys().cloned().collect()
115    }
116
117    /// Removes the channel from the manager
118    fn remove_channel(&self, channel_id: &ChannelId) {
119        let mut inner = self.0.lock();
120        if let Some(channel) = inner.channels.remove(channel_id) {
121            for tag in channel.tags.iter() {
122                inner.remove_channel_tag(channel_id, tag);
123            }
124        }
125    }
126
127    /// Returns senders by tag
128    fn tagged_senders(&self, tag: &T) -> Vec<Sender<Arc<M>>> {
129        let inner = self.0.lock();
130        inner
131            .tags
132            .get(tag)
133            .map(|ids| ids.iter().filter_map(|id| inner.clone_tx(id)).collect())
134            .unwrap_or_default()
135    }
136
137    /// Returns all senders
138    fn all_senders(&self) -> Vec<Sender<Arc<M>>> {
139        self.0
140            .lock()
141            .channels
142            .values()
143            .map(|c| c.tx.clone())
144            .collect()
145    }
146}
147
148impl<M, T> Default for TaggedChannels<M, T> {
149    fn default() -> Self {
150        let inner = ChannelsInner {
151            last_id: 0,
152            channels: HashMap::new(),
153            tags: HashMap::new(),
154        };
155        Self(Arc::new(Mutex::new(inner)))
156    }
157}
158
159impl<M, T> ChannelGuard<M, T>
160where
161    T: Clone + Eq + Hash + PartialEq,
162{
163    fn new(channel_id: ChannelId, manager: TaggedChannels<M, T>) -> Self {
164        Self {
165            channel_id,
166            manager,
167        }
168    }
169}
170
171impl<M, T> Drop for ChannelGuard<M, T>
172where
173    T: Clone + Eq + Hash + PartialEq,
174{
175    fn drop(&mut self) {
176        self.manager.remove_channel(&self.channel_id);
177    }
178}
179
180impl<M, T> ChannelsInner<M, T>
181where
182    T: Eq + Hash + PartialEq,
183{
184    fn clone_tx(&self, channel_id: &ChannelId) -> Option<Sender<Arc<M>>> {
185        self.channels.get(channel_id).map(|c| c.tx.clone())
186    }
187
188    fn remove_channel_tag(&mut self, channel_id: &ChannelId, tag: &T) {
189        let empty = if let Some(ids) = self.tags.get_mut(tag) {
190            ids.remove(channel_id);
191            ids.is_empty()
192        } else {
193            false
194        };
195        if empty {
196            self.tags.remove(tag);
197        }
198    }
199}
200
201impl<M, T> GuardedReceiver<M, T>
202where
203    T: Clone + Eq + Hash + PartialEq,
204{
205    /// Receives the next event from the channel
206    pub async fn recv(&mut self) -> Option<Arc<M>> {
207        self.rx.recv().await
208    }
209}