serial_async/
lib.rs

1#![cfg_attr(not(test), no_std)]
2
3extern crate alloc;
4
5use alloc::{boxed::Box, sync::Arc};
6use core::{
7    sync::atomic::{AtomicBool, Ordering},
8    task::Poll,
9};
10use futures::{FutureExt, task::AtomicWaker};
11use rdif_serial::ErrorBase;
12pub use rdif_serial::{DriverGeneric, Interface, SerialError};
13
/// Snapshot of the serial interrupt conditions that are currently pending.
///
/// Obtained from [`Registers::get_irq_event`] and handed back to
/// [`Registers::clean_irq_event`] after the matching wakers have been woken.
#[derive(Clone, Copy, Debug, Default)]
pub struct IrqEvent {
    /// Receive condition pending; the IRQ path wakes the reader's waker.
    pub can_get: bool,
    /// Transmit condition pending; the IRQ path wakes the writer's waker.
    pub can_put: bool,
}
19
20pub trait Registers: Clone + 'static {
21    fn can_put(&self) -> bool;
22    fn put(&self, c: u8) -> Result<(), SerialError>;
23    fn can_get(&self) -> bool;
24    fn get(&self) -> Result<u8, SerialError>;
25    fn get_irq_event(&self) -> IrqEvent;
26    fn clean_irq_event(&self, event: IrqEvent);
27}
28
29pub struct Serial<R: Registers> {
30    registers: R,
31    tx: ChData,
32    rx: ChData,
33    pub irq_handler: Option<IrqHandler<R>>,
34}
35
36impl<R: Registers> Serial<R> {
37    pub fn new(registers: R) -> Self {
38        let tx = ChData::new();
39        let rx = ChData::new();
40
41        Self {
42            registers: registers.clone(),
43            tx: tx.clone(),
44            rx: rx.clone(),
45            irq_handler: Some(IrqHandler { registers, tx, rx }),
46        }
47    }
48
49    pub fn try_take_tx(&mut self) -> Option<Sender<R>> {
50        self.tx.try_take()?;
51
52        Some(Sender {
53            registers: self.registers.clone(),
54            data: self.tx.clone(),
55        })
56    }
57
58    pub fn try_take_rx(&mut self) -> Option<Receiver<R>> {
59        self.rx.try_take()?;
60
61        Some(Receiver {
62            registers: self.registers.clone(),
63            data: self.rx.clone(),
64        })
65    }
66
67    /// Returns the irq state of this [`Serial`].
68    ///
69    /// # Safety
70    ///
71    /// Only used in interrupt handler.
72    pub unsafe fn get_irq_event(&self) -> IrqEvent {
73        self.registers.get_irq_event()
74    }
75
76    /// Cleans the irq state of this [`Serial`].
77    ///
78    /// # Safety
79    ///
80    /// Only used in interrupt handler.
81    pub unsafe fn clean_irq_event(&self, state: IrqEvent) {
82        self.registers.clean_irq_event(state);
83    }
84}
85
86impl<R: Registers> DriverGeneric for Serial<R> {
87    fn open(&mut self) -> Result<(), ErrorBase> {
88        Ok(())
89    }
90
91    fn close(&mut self) -> Result<(), ErrorBase> {
92        Ok(())
93    }
94}
95
96impl<R: Registers> Interface for Serial<R> {
97    fn handle_irq(&mut self) {
98        unsafe { self.irq_handler.as_mut().unwrap().handle_irq() };
99    }
100
101    fn take_tx(&mut self) -> Option<Box<dyn rdif_serial::Sender>> {
102        Some(Box::new(self.try_take_tx()?))
103    }
104
105    fn take_rx(&mut self) -> Option<Box<dyn rdif_serial::Reciever>> {
106        Some(Box::new(self.try_take_rx()?))
107    }
108}
109
110impl<R: Registers> rdif_serial::Sender for Sender<R> {
111    fn write(&mut self, buf: &[u8]) -> Result<usize, SerialError> {
112        Sender::write(self, buf)
113    }
114
115    fn write_all<'a>(
116        &'a mut self,
117        buf: &'a [u8],
118    ) -> rdif_serial::LocalBoxFuture<'a, Result<(), SerialError>> {
119        Sender::write_all(self, buf).boxed_local()
120    }
121}
122
123impl<R: Registers> rdif_serial::Reciever for Receiver<R> {
124    fn read(&mut self, buf: &mut [u8]) -> Result<usize, SerialError> {
125        Receiver::read(self, buf)
126    }
127
128    fn read_all<'a>(
129        &'a mut self,
130        buf: &'a mut [u8],
131    ) -> rdif_serial::LocalBoxFuture<'a, Result<(), SerialError>> {
132        Receiver::read_all(self, buf).boxed_local()
133    }
134}
135
136pub struct IrqHandler<R: Registers> {
137    registers: R,
138    tx: ChData,
139    rx: ChData,
140}
141
142unsafe impl<R: Registers> Sync for IrqHandler<R> {}
143
144impl<R: Registers> IrqHandler<R> {
145    /// Handle interrupt
146    ///
147    /// #  Safety
148    /// Only used in interrupt handler.
149    pub unsafe fn handle_irq(&self) {
150        let state = self.registers.get_irq_event();
151
152        if state.can_get {
153            self.rx.waker.wake();
154        }
155        if state.can_put {
156            self.tx.waker.wake();
157        }
158
159        self.registers.clean_irq_event(state);
160    }
161}
162
163unsafe impl<R: Registers> Send for Serial<R> {}
164unsafe impl<R: Registers> Send for Sender<R> {}
165unsafe impl<R: Registers> Send for Receiver<R> {}
166
167#[derive(Clone)]
168struct ChData {
169    taken: Arc<AtomicBool>,
170    waker: Arc<AtomicWaker>,
171}
172
173impl ChData {
174    fn new() -> Self {
175        Self {
176            taken: Arc::new(AtomicBool::new(false)),
177            waker: Arc::new(AtomicWaker::new()),
178        }
179    }
180
181    fn try_take(&self) -> Option<()> {
182        match self
183            .taken
184            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
185        {
186            Ok(taken) => {
187                if taken {
188                    None
189                } else {
190                    Some(())
191                }
192            }
193            Err(_) => None,
194        }
195    }
196}
197
198pub struct Sender<R: Registers> {
199    registers: R,
200    data: ChData,
201}
202
203impl<R: Registers> Sender<R> {
204    pub fn write(&mut self, buf: &[u8]) -> Result<usize, SerialError> {
205        let mut written = 0;
206        for &byte in buf {
207            if !self.registers.can_put() {
208                break;
209            }
210            self.registers.put(byte)?;
211            written += 1;
212        }
213        Ok(written)
214    }
215
216    pub fn can_put(&self) -> bool {
217        self.registers.can_put()
218    }
219
220    pub fn write_all<'a>(
221        &'a mut self,
222        buf: &'a [u8],
223    ) -> impl Future<Output = Result<(), SerialError>> + 'a {
224        WaitForWriteAll {
225            waiter: self.data.waker.clone(),
226            sender: self,
227            buf,
228        }
229    }
230}
231
232impl<R: Registers> Drop for Sender<R> {
233    fn drop(&mut self) {
234        self.data.taken.store(false, Ordering::Release);
235    }
236}
237
238pub struct Receiver<R: Registers> {
239    registers: R,
240    data: ChData,
241}
242
243impl<R: Registers> Receiver<R> {
244    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, SerialError> {
245        let mut read = 0;
246        for byte in buf {
247            if !self.registers.can_get() {
248                break;
249            }
250            *byte = self.registers.get()?;
251            read += 1;
252        }
253        Ok(read)
254    }
255
256    pub fn read_all<'a>(
257        &'a mut self,
258        buf: &'a mut [u8],
259    ) -> impl Future<Output = Result<(), SerialError>> + 'a {
260        WaitForReadAll {
261            waiter: self.data.waker.clone(),
262            rx: self,
263            buf,
264            i: 0,
265        }
266    }
267}
268
269impl<R: Registers> Drop for Receiver<R> {
270    fn drop(&mut self) {
271        self.data.taken.store(false, Ordering::Release);
272    }
273}
274
275struct WaitForWriteAll<'a, R: Registers> {
276    waiter: Arc<AtomicWaker>,
277    sender: &'a mut Sender<R>,
278    buf: &'a [u8],
279}
280
281impl<R: Registers> Future for WaitForWriteAll<'_, R> {
282    type Output = Result<(), SerialError>;
283
284    fn poll(
285        mut self: core::pin::Pin<&mut Self>,
286        cx: &mut core::task::Context<'_>,
287    ) -> core::task::Poll<Self::Output> {
288        self.waiter.register(cx.waker());
289
290        let buf = self.buf;
291        match self.sender.write(buf) {
292            Ok(n) => {
293                self.buf = &buf[n..];
294                if n < buf.len() {
295                    Poll::Pending
296                } else {
297                    Poll::Ready(Ok(()))
298                }
299            }
300            Err(e) => Poll::Ready(Err(e)),
301        }
302    }
303}
304
305struct WaitForReadAll<'a, R: Registers> {
306    waiter: Arc<AtomicWaker>,
307    rx: &'a mut Receiver<R>,
308    buf: &'a mut [u8],
309    i: usize,
310}
311
312impl<R: Registers> Future for WaitForReadAll<'_, R> {
313    type Output = Result<(), SerialError>;
314
315    fn poll(
316        mut self: core::pin::Pin<&mut Self>,
317        cx: &mut core::task::Context<'_>,
318    ) -> core::task::Poll<Self::Output> {
319        self.waiter.register(cx.waker());
320        let begin = self.i;
321        for i in begin..self.buf.len() {
322            if self.rx.registers.can_get() {
323                match self.rx.registers.get() {
324                    Ok(b) => {
325                        self.buf[i] = b;
326                        self.i += 1;
327                    }
328                    Err(e) => {
329                        return Poll::Ready(Err(e));
330                    }
331                }
332            } else {
333                return Poll::Pending;
334            }
335        }
336        Poll::Ready(Ok(()))
337    }
338}