sans_io_runtime/
bus.rs

1use parking_lot::RwLock;
2use std::{collections::HashMap, fmt::Debug, hash::Hash, sync::Arc};
3
4mod leg;
5mod local_hub;
6
7pub use leg::*;
8pub use local_hub::*;
9
10use crate::backend::Awaker;
11
12pub trait BusSendSingleFeature<MSG> {
13    fn send_safe(&self, dest_leg: usize, msg: MSG) -> usize;
14    fn send(&self, dest_leg: usize, safe: bool, msg: MSG) -> Result<usize, BusLegSenderErr>;
15}
16
/// Fan-out delivery of one message to every leg on the bus.
///
/// `MSG` must be `Clone` because implementations may need a separate copy per leg.
pub trait BusSendMultiFeature<MSG>
where
    MSG: Clone,
{
    /// Sends `msg` to all legs, forwarding the `safe` flag to each leg's send.
    fn broadcast(&self, safe: bool, msg: MSG);
}
20
/// Channel-based publish/subscribe on top of the bus.
///
/// `MSG` must be `Clone` because a published message may go to several subscribers.
pub trait BusPubSubFeature<ChannelId, MSG>
where
    MSG: Clone,
{
    /// Registers the implementor as a subscriber of `channel`.
    fn subscribe(&self, channel: ChannelId);

    /// Removes the implementor from the subscribers of `channel`.
    fn unsubscribe(&self, channel: ChannelId);

    /// Delivers `msg` to the current subscribers of `channel`, forwarding `safe`.
    fn publish(&self, channel: ChannelId, safe: bool, msg: MSG);
}
26
27pub struct BusSystemBuilder<ChannelId, MSG, const STACK_SIZE: usize> {
28    legs: Arc<RwLock<Vec<BusLegSender<ChannelId, MSG, STACK_SIZE>>>>,
29    channels: Arc<RwLock<HashMap<ChannelId, Vec<usize>>>>,
30}
31
32impl<ChannelId, MSG, const STACK_SIZE: usize> Default
33    for BusSystemBuilder<ChannelId, MSG, STACK_SIZE>
34{
35    fn default() -> Self {
36        Self {
37            legs: Default::default(),
38            channels: Default::default(),
39        }
40    }
41}
42
43impl<ChannelId, MSG, const STACK_SIZE: usize> BusSystemBuilder<ChannelId, MSG, STACK_SIZE> {
44    pub fn new_worker(&mut self) -> BusWorker<ChannelId, MSG, STACK_SIZE> {
45        let mut legs = self.legs.write();
46        let leg_index = legs.len();
47        let (sender, recv) = create_bus_leg();
48        legs.push(sender);
49
50        BusWorker {
51            leg_index,
52            receiver: recv,
53            legs: self.legs.clone(),
54            channels: self.channels.clone(),
55        }
56    }
57}
58
59impl<ChannelId, MSG, const STACK_SIZE: usize> BusSendSingleFeature<MSG>
60    for BusSystemBuilder<ChannelId, MSG, STACK_SIZE>
61{
62    fn send_safe(&self, dest_leg: usize, msg: MSG) -> usize {
63        let legs = self.legs.read();
64        legs[dest_leg].send_safe(BusEventSource::External, msg)
65    }
66
67    fn send(&self, dest_leg: usize, safe: bool, msg: MSG) -> Result<usize, BusLegSenderErr> {
68        let legs = self.legs.read();
69        legs[dest_leg].send(BusEventSource::External, safe, msg)
70    }
71}
72
73impl<ChannelId, MSG: Clone, const STACK_SIZE: usize> BusSendMultiFeature<MSG>
74    for BusSystemBuilder<ChannelId, MSG, STACK_SIZE>
75{
76    fn broadcast(&self, safe: bool, msg: MSG) {
77        let legs = self.legs.read();
78        match legs.len() {
79            0 => log::warn!("No leg to broadcast"),
80            1 => {
81                let _ = legs[0].send(BusEventSource::External, safe, msg);
82            }
83            _ => {
84                for leg in &*legs {
85                    let _ = leg.send(BusEventSource::External, safe, msg.clone());
86                }
87            }
88        }
89    }
90}
91
92pub struct BusWorker<ChannelId, MSG, const STACK_SIZE: usize> {
93    leg_index: usize,
94    receiver: BusLegReceiver<ChannelId, MSG, STACK_SIZE>,
95    legs: Arc<RwLock<Vec<BusLegSender<ChannelId, MSG, STACK_SIZE>>>>,
96    channels: Arc<RwLock<HashMap<ChannelId, Vec<usize>>>>,
97}
98
99impl<ChannelId, MSG, const STACK_SIZE: usize> BusWorker<ChannelId, MSG, STACK_SIZE> {
100    pub fn leg_index(&self) -> usize {
101        self.leg_index
102    }
103
104    pub fn recv(&self) -> Option<(BusEventSource<ChannelId>, MSG)> {
105        self.receiver.recv()
106    }
107
108    pub fn set_awaker(&self, awaker: Arc<dyn Awaker>) {
109        self.receiver.set_awaker(awaker);
110    }
111}
112
113impl<ChannelId, MSG, const STACK_SIZE: usize> BusSendSingleFeature<MSG>
114    for BusWorker<ChannelId, MSG, STACK_SIZE>
115{
116    fn send_safe(&self, dest_leg: usize, msg: MSG) -> usize {
117        let legs = self.legs.read();
118        legs[dest_leg].send_safe(BusEventSource::Direct(self.leg_index), msg)
119    }
120
121    fn send(&self, dest_leg: usize, safe: bool, msg: MSG) -> Result<usize, BusLegSenderErr> {
122        let legs = self.legs.read();
123        legs[dest_leg].send(BusEventSource::Direct(self.leg_index), safe, msg)
124    }
125}
126
127impl<ChannelId, MSG: Clone, const STACK_SIZE: usize> BusSendMultiFeature<MSG>
128    for BusWorker<ChannelId, MSG, STACK_SIZE>
129{
130    fn broadcast(&self, safe: bool, msg: MSG) {
131        let legs = self.legs.read();
132        match legs.len() {
133            0 => log::warn!("No leg to broadcast"),
134            1 => {
135                let _ = legs[0].send(BusEventSource::Broadcast(self.leg_index), safe, msg);
136            }
137            _ => {
138                for leg in &*legs {
139                    let _ = leg.send(BusEventSource::Broadcast(self.leg_index), safe, msg.clone());
140                }
141            }
142        }
143    }
144}
145
146impl<ChannelId: Debug + Copy + Hash + PartialEq + Eq, MSG: Clone, const STACK_SIZE: usize>
147    BusPubSubFeature<ChannelId, MSG> for BusWorker<ChannelId, MSG, STACK_SIZE>
148{
149    fn subscribe(&self, channel: ChannelId) {
150        let mut channels = self.channels.write();
151        let entry = channels.entry(channel).or_default();
152        if entry.contains(&self.leg_index) {
153            return;
154        }
155        entry.push(self.leg_index);
156    }
157
158    fn unsubscribe(&self, channel: ChannelId) {
159        let mut channels = self.channels.write();
160        if let Some(entry) = channels.get_mut(&channel) {
161            if let Some(index) = entry.iter().position(|x| *x == self.leg_index) {
162                entry.swap_remove(index);
163            }
164        }
165    }
166
167    fn publish(&self, channel: ChannelId, safe: bool, msg: MSG) {
168        let legs = self.legs.read();
169        let channels = self.channels.read();
170        if let Some(entry) = channels.get(&channel) {
171            if entry.len() == 1 {
172                let _ = legs[entry[0]].send(
173                    BusEventSource::Channel(self.leg_index, channel),
174                    safe,
175                    msg,
176                );
177            } else {
178                for &leg_index in entry {
179                    let _ = legs[leg_index].send(
180                        BusEventSource::Channel(self.leg_index, channel),
181                        safe,
182                        msg.clone(),
183                    );
184                }
185            }
186        }
187    }
188}