1#![warn(clippy::all, missing_docs, nonstandard_style, future_incompatible)]
2#![doc = include_str!("../README.md")]
3
4use std::{
5 collections::{HashMap, HashSet},
6 hash::Hash,
7 sync::Arc,
8};
9
10use parking_lot::Mutex;
11use tokio::sync::mpsc::{self, Receiver, Sender};
12
/// Identifier assigned to each registered channel.
type ChannelId = u64;
14
/// A cloneable handle to a shared registry of tagged channels.
///
/// The registry state lives behind an `Arc<Mutex<..>>`, so every clone of a
/// handle observes and modifies the same set of channels and tags.
pub struct TaggedChannels<M, T>(Arc<Mutex<ChannelsInner<M, T>>>);
17
18impl<M, T> Clone for TaggedChannels<M, T> {
19 fn clone(&self) -> Self {
20 Self(Arc::clone(&self.0))
21 }
22}
23
/// Registry state shared by every [`TaggedChannels`] handle.
pub struct ChannelsInner<M, T> {
    /// Most recently assigned channel ID; the next ID is derived from it.
    last_id: u64,
    /// Every registered channel, keyed by its ID.
    channels: HashMap<ChannelId, Channel<M, T>>,
    /// Index from a tag to the IDs of the channels registered under it.
    tags: HashMap<T, HashSet<ChannelId>>,
}
30
/// A single registered channel: its sending half plus the tags it was
/// created with.
struct Channel<M, T> {
    /// Sending half; messages travel as shared `Arc<M>` values.
    tx: Sender<Arc<M>>,
    /// Tags supplied at creation, kept so the tag index can be cleaned up
    /// when the channel is removed.
    tags: Box<[T]>,
}
35
/// Unregisters a channel from its [`TaggedChannels`] manager when dropped.
pub struct ChannelGuard<M, T>
where
    T: Clone + Eq + Hash + PartialEq,
{
    /// ID of the channel to remove on drop.
    channel_id: ChannelId,
    /// Handle to the manager the channel is registered with.
    manager: TaggedChannels<M, T>,
}
44
/// Receiving half of a tagged channel, bundled with the guard that
/// unregisters the channel once the receiver goes away.
pub struct GuardedReceiver<M, T>
where
    T: Clone + Eq + Hash + PartialEq,
{
    /// Underlying receiver for messages sent through the manager.
    rx: Receiver<Arc<M>>,
    /// Never read: the guard is stored only so that its `Drop` runs when this
    /// receiver is dropped, hence the `dead_code` allowance.
    #[allow(dead_code)]
    guard: ChannelGuard<M, T>,
}
54
55impl<M, T> TaggedChannels<M, T>
56where
57 T: Clone + Eq + Hash + PartialEq,
58{
59 pub fn new() -> Self {
61 Default::default()
62 }
63
64 pub fn create_channel(&self, tags: impl Into<Vec<T>>) -> GuardedReceiver<M, T> {
66 let tags = tags.into();
67 let (tx, rx) = mpsc::channel::<Arc<M>>(1);
68 let channel = Channel {
69 tx,
70 tags: tags.clone().into_boxed_slice(),
71 };
72
73 let mut inner = self.0.lock();
74 let channel_id = inner.last_id.overflowing_add(1).0;
75 inner.channels.insert(channel_id, channel);
76 for tag in tags {
77 inner
78 .tags
79 .entry(tag)
80 .and_modify(|set| {
81 set.insert(channel_id);
82 })
83 .or_insert(HashSet::from([channel_id]));
84 }
85 inner.last_id = channel_id;
86
87 let guard = ChannelGuard::new(channel_id, self.clone());
88 GuardedReceiver { rx, guard }
89 }
90
91 pub fn num_connections(&self) -> usize {
93 self.0.lock().channels.len()
94 }
95
96 pub async fn send_by_tag(&self, tag: &T, message: M) {
98 let msg = Arc::new(message);
99 for rx in self.tagged_senders(tag) {
100 rx.send(Arc::clone(&msg)).await.ok();
101 }
102 }
103
104 pub async fn broadcast(&self, message: M) {
106 let msg = Arc::new(message);
107 for rx in self.all_senders() {
108 rx.send(Arc::clone(&msg)).await.ok();
109 }
110 }
111
112 pub fn connected_tags(&self) -> Vec<T> {
114 self.0.lock().tags.keys().cloned().collect()
115 }
116
117 fn remove_channel(&self, channel_id: &ChannelId) {
119 let mut inner = self.0.lock();
120 if let Some(channel) = inner.channels.remove(channel_id) {
121 for tag in channel.tags.iter() {
122 inner.remove_channel_tag(channel_id, tag);
123 }
124 }
125 }
126
127 fn tagged_senders(&self, tag: &T) -> Vec<Sender<Arc<M>>> {
129 let inner = self.0.lock();
130 inner
131 .tags
132 .get(tag)
133 .map(|ids| ids.iter().filter_map(|id| inner.clone_tx(id)).collect())
134 .unwrap_or_default()
135 }
136
137 fn all_senders(&self) -> Vec<Sender<Arc<M>>> {
139 self.0
140 .lock()
141 .channels
142 .values()
143 .map(|c| c.tx.clone())
144 .collect()
145 }
146}
147
148impl<M, T> Default for TaggedChannels<M, T> {
149 fn default() -> Self {
150 let inner = ChannelsInner {
151 last_id: 0,
152 channels: HashMap::new(),
153 tags: HashMap::new(),
154 };
155 Self(Arc::new(Mutex::new(inner)))
156 }
157}
158
159impl<M, T> ChannelGuard<M, T>
160where
161 T: Clone + Eq + Hash + PartialEq,
162{
163 fn new(channel_id: ChannelId, manager: TaggedChannels<M, T>) -> Self {
164 Self {
165 channel_id,
166 manager,
167 }
168 }
169}
170
171impl<M, T> Drop for ChannelGuard<M, T>
172where
173 T: Clone + Eq + Hash + PartialEq,
174{
175 fn drop(&mut self) {
176 self.manager.remove_channel(&self.channel_id);
177 }
178}
179
180impl<M, T> ChannelsInner<M, T>
181where
182 T: Eq + Hash + PartialEq,
183{
184 fn clone_tx(&self, channel_id: &ChannelId) -> Option<Sender<Arc<M>>> {
185 self.channels.get(channel_id).map(|c| c.tx.clone())
186 }
187
188 fn remove_channel_tag(&mut self, channel_id: &ChannelId, tag: &T) {
189 let empty = if let Some(ids) = self.tags.get_mut(tag) {
190 ids.remove(channel_id);
191 ids.is_empty()
192 } else {
193 false
194 };
195 if empty {
196 self.tags.remove(tag);
197 }
198 }
199}
200
201impl<M, T> GuardedReceiver<M, T>
202where
203 T: Clone + Eq + Hash + PartialEq,
204{
205 pub async fn recv(&mut self) -> Option<Arc<M>> {
207 self.rx.recv().await
208 }
209}