sunset_embassy/
embassy_sunset.rs

1#[allow(unused_imports)]
2#[cfg(not(feature = "defmt"))]
3pub use log::{debug, error, info, log, trace, warn};
4
5#[allow(unused_imports)]
6#[cfg(feature = "defmt")]
7pub use defmt::{debug, error, info, panic, trace, warn};
8
9use core::future::{poll_fn, Future};
10use core::task::{Poll, Context};
11use core::ops::{ControlFlow, DerefMut};
12use core::sync::atomic::AtomicBool;
13use core::sync::atomic::Ordering::{Relaxed, SeqCst};
14
15use embassy_sync::waitqueue::WakerRegistration;
16use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex};
17use embassy_sync::mutex::Mutex;
18use embassy_sync::signal::Signal;
19use embassy_futures::select::select;
20use embassy_futures::join;
21use embedded_io_async::{Read, Write, BufRead};
22
23// thumbv6m has no atomic usize add/sub
24use atomic_polyfill::AtomicUsize;
25
26use pin_utils::pin_mut;
27
28use sunset::{Runner, Result, Error, error, Behaviour, ChanData, ChanHandle, ChanNum, CliBehaviour, ServBehaviour};
29use sunset::config::MAX_CHANNELS;
30
31// For now we only support single-threaded executors.
32// In future this could be behind a cfg to allow different
33// RawMutex for std executors or other situations.
34// Also requires making CliBehaviour : Send, etc.
35pub type SunsetRawMutex = NoopRawMutex;
36
37pub type SunsetMutex<T> = Mutex<SunsetRawMutex, T>;
38
/// Per-channel waker registrations, each array indexed by channel number.
struct Wakers {
    /// Wakers for tasks waiting to read normal channel data. These are also
    /// woken when the channel reaches EOF.
    chan_read: [WakerRegistration; MAX_CHANNELS],

    /// Wakers for tasks waiting for send space for normal channel data.
    chan_write: [WakerRegistration; MAX_CHANNELS],

    /// Will be a stderr read waker for a client, or stderr write waker for
    /// a server.
    chan_ext: [WakerRegistration; MAX_CHANNELS],

    /// Wakers for tasks waiting in `until_channel_closed()`.
    // TODO: do we need a separate waker for this?
    chan_close: [WakerRegistration; MAX_CHANNELS],
}
51
/// State guarded by `EmbassySunset::inner`.
struct Inner<'a> {
    /// The sunset core runner driving the session.
    runner: Runner<'a>,

    /// Wakers for tasks blocked on per-channel events.
    wakers: Wakers,

    /// Channel handles indexed by channel number; `None` for unused slots.
    ///
    /// May only be safely modified when the corresponding
    /// `chan_refcounts` is zero.
    chan_handles: [Option<ChanHandle>; MAX_CHANNELS],
}
61
62impl<'a> Inner<'a> {
63    /// Helper to lookup the corresponding ChanHandle
64    ///
65    /// Returns split references that will be required by many callers
66    fn fetch(&mut self, num: ChanNum) -> Result<(&mut Runner<'a>, &ChanHandle, &mut Wakers)> {
67        self.chan_handles[num.0 as usize].as_ref().map(|ch| {
68            (&mut self.runner, ch, &mut self.wakers)
69        })
70        .ok_or_else(Error::bug)
71    }
72}
73
/// Provides an async wrapper for Sunset core
///
/// A [`ChanHandle`] provided by sunset core must be added with
/// [`add_channel()`](Self::add_channel) before a method can be called with
/// the equivalent ChanNum.
///
/// Applications use `embassy_sunset::{Client,Server}`.
pub(crate) struct EmbassySunset<'a> {
    /// Runner, wakers and channel handles, behind an async mutex.
    inner: SunsetMutex<Inner<'a>>,

    /// Signalled by `wake_progress()`; an idle `progress()` loop waits on it.
    progress_notify: Signal<SunsetRawMutex, ()>,

    // wake_progress() should be called after modifying these atomics, to
    // trigger the progress loop to handle state changes

    /// Set by `exit()`; `progress()` returns `Break` once it observes it.
    exit: AtomicBool,
    /// Set when the network socket reads EOF; `progress()` then returns
    /// `Break` as soon as the runner reports no further progress.
    flushing: AtomicBool,

    // Refcount for `Inner::chan_handles`. Must be non-async so it can be
    // decremented on `ChanIn::drop()` etc.
    // The pending chan_refcount=0 handling occurs in the `progress()` loop.
    chan_refcounts: [AtomicUsize; MAX_CHANNELS],
}
96
97impl<'a> EmbassySunset<'a> {
98    pub fn new(runner: Runner<'a>) -> Self {
99        let wakers = Wakers {
100            chan_read: Default::default(),
101            chan_write: Default::default(),
102            chan_ext: Default::default(),
103            chan_close: Default::default(),
104        };
105        let inner = Inner {
106            runner,
107            wakers,
108            chan_handles: Default::default(),
109        };
110        let inner = Mutex::new(inner);
111
112        let progress_notify = Signal::new();
113
114        Self {
115            inner,
116            exit: AtomicBool::new(false),
117            flushing: AtomicBool::new(false),
118            progress_notify,
119            chan_refcounts: Default::default(),
120         }
121    }
122
123    /// Runs the session to completion
124    ///
125    /// `b` is a bit tricky, it allows passing either a Mutex<CliBehaviour> or
126    /// Mutex<ServBehaviour> (to be converted into a Behaviour in progress() after 
127    /// the mutex is locked).
128    pub async fn run<B: ?Sized, M: RawMutex, C: CliBehaviour, S: ServBehaviour>(&self,
129        rsock: &mut impl Read,
130        wsock: &mut impl Write,
131        b: &Mutex<M, B>) -> Result<()>
132        where
133            for<'f> Behaviour<'f, C, S>: From<&'f mut B>
134    {
135        // Some loops need to terminate other loops on completion.
136        // prog finish -> stop rx
137        // rx finish -> stop tx
138        let tx_stop = Signal::<SunsetRawMutex, ()>::new();
139        let rx_stop = Signal::<SunsetRawMutex, ()>::new();
140
141        let tx = async {
142            loop {
143                // TODO: make sunset read directly from socket, no intermediate buffer?
144                // Perhaps not possible async, might deadlock.
145                let mut buf = [0; 1024];
146                let l = self.output(&mut buf).await?;
147                wsock.write_all(&buf[..l]).await
148                .map_err(|_| {
149                    info!("socket write error");
150                    Error::ChannelEOF
151                })?;
152            }
153            #[allow(unreachable_code)]
154            Ok::<_, sunset::Error>(())
155        };
156        let tx = select(tx, tx_stop.wait());
157
158        let rx = async {
159            loop {
160                // TODO: make sunset read directly from socket, no intermediate buffer.
161                let mut buf = [0; 1024];
162                let l = rsock.read(&mut buf).await
163                .map_err(|_| {
164                    info!("socket read error");
165                    Error::ChannelEOF
166                })?;
167                if l == 0 {
168                    debug!("net EOF");
169                    self.flushing.store(true, Relaxed);
170                    self.wake_progress();
171                    break
172                }
173                let mut buf = &buf[..l];
174                while !buf.is_empty() {
175                    let n = self.input(buf).await?;
176                    buf = &buf[n..];
177                }
178            }
179            Ok::<_, sunset::Error>(())
180        };
181
182        // TODO: if RX fails (bad decrypt etc) it doesn't cancel prog, so gets stuck
183        let rx = select(rx, rx_stop.wait());
184        let rx = async {
185            let r = rx.await;
186            tx_stop.signal(());
187            r
188        };
189
190        let prog = async {
191            loop {
192                if self.progress(b).await?.is_break() {
193                    break Ok(())
194                }
195            }
196        };
197
198        let prog = async {
199            let r = prog.await;
200            self.with_runner(|runner| runner.close()).await;
201            rx_stop.signal(());
202            r
203        };
204
205        // TODO: we might want to let `prog` run until buffers are drained
206        // in case a disconnect message was received.
207        // TODO Is there a nice way than this?
208        let f = join::join3(prog, rx, tx).await;
209        let (fp, _frx, _ftx) = f;
210
211        // debug!("fp {fp:?}");
212        // debug!("frx {_frx:?}");
213        // debug!("ftx {_ftx:?}");
214
215        // TODO: is this a good way to do cancellation...?
216        // self.with_runner(|runner| runner.close()).await;
217        // // Wake any channels that were awoken after the runner closed
218        // let mut inner = self.inner.lock().await;
219        // self.wake_channels(&mut inner)?;
220        fp
221    }
222
223    fn wake_progress(&self) {
224        self.progress_notify.signal(())
225    }
226
227    pub async fn exit(&self) {
228        self.exit.store(true, Relaxed);
229        self.wake_progress()
230    }
231
232    fn wake_channels(&self, inner: &mut Inner) -> Result<()> {
233        // Read wakers
234        let w = &mut inner.wakers;
235        if let Some((num, dt, _len)) = inner.runner.ready_channel_input() {
236            // TODO: if there isn't any waker waiting, could we just drop the packet?
237            match dt {
238                ChanData::Normal => w.chan_read[num.0 as usize].wake(),
239                ChanData::Stderr => w.chan_ext[num.0 as usize].wake(),
240            }
241        }
242
243        for (idx, c) in inner.chan_handles.iter().enumerate() {
244            let ch = if let Some(ch) = c.as_ref() {
245                ch
246            } else {
247                continue
248            };
249
250            // Write wakers
251
252
253            // TODO: if this is slow we could be smarter about aggregating dt vs standard,
254            // or handling the case of full out payload buffers.
255            if inner.runner.ready_channel_send(ch, ChanData::Normal)?.unwrap_or(0) > 0 {
256                w.chan_write[idx].wake()
257            }
258
259            if !inner.runner.is_client() {
260                if inner.runner.ready_channel_send(ch, ChanData::Stderr)?.unwrap_or(0) > 0 {
261                    w.chan_ext[idx].wake()
262                }
263            }
264
265            // TODO: do we want to keep waking it?
266            if inner.runner.is_channel_eof(ch) {
267                w.chan_read[idx].wake();
268                if inner.runner.is_client() {
269                    w.chan_ext[idx].wake();
270                }
271            }
272
273            if inner.runner.is_channel_closed(ch) {
274                w.chan_close[idx].wake();
275            }
276        }
277        Ok(())
278    }
279
280    /// Check for channels that have reached zero refcount
281    ///
282    /// When a ChanIO is dropped the refcount may reach 0, but
283    /// without "async Drop" it isn't possible to take the `inner` lock during
284    /// `drop()`.
285    /// Instead this runs periodically from an async context to release channels.
286    fn clear_refcounts(&self, inner: &mut Inner) -> Result<()> {
287        for (ch, count) in inner.chan_handles.iter_mut().zip(self.chan_refcounts.iter()) {
288            let count = count.load(Relaxed);
289            if count > 0 {
290                debug_assert!(ch.is_some());
291                continue;
292            }
293            if let Some(ch) = ch.take() {
294                // done with the channel
295                inner.runner.channel_done(ch)?;
296            }
297        }
298        Ok(())
299    }
300
301    /// Returns ControlFlow::Break on session exit.
302    ///
303    /// B will be either a CliBehaviour or ServBehaviour
304    async fn progress<B: ?Sized, M: RawMutex, C: CliBehaviour, S: ServBehaviour>(&self,
305        b: &Mutex<M, B>)
306        -> Result<ControlFlow<()>>
307        where
308            for<'f> Behaviour<'f, C, S>: From<&'f mut B>
309        {
310            let ret;
311
312        {
313            if self.exit.load(Relaxed) {
314                return Ok(ControlFlow::Break(()))
315            }
316
317            let mut inner = self.inner.lock().await;
318            {
319                {
320                    // lock the Mutex around the CliBehaviour or ServBehaviour
321                    let mut b = b.lock().await;
322                    // dereference the MutexGuard
323                    let b = b.deref_mut();
324                    // create either Behaviour<C, UnusedServ> or Behaviour<UnusedCli, S>
325                    // to pass to the runner.
326                    let mut b: Behaviour<C, S> = b.into();
327                    ret = inner.runner.progress(&mut b).await?;
328                    // b is dropped, allowing other users
329                }
330
331                self.wake_channels(&mut inner)?;
332
333                self.clear_refcounts(&mut inner)?;
334            }
335            // inner dropped
336        }
337
338        if ret.disconnected {
339            return Ok(ControlFlow::Break(()))
340        }
341
342        if !ret.progressed {
343            if self.flushing.load(Relaxed) {
344                // if we're flushing, we exit once there is no progress
345                return Ok(ControlFlow::Break(()))
346            }
347            // Idle until input is received
348            // TODO do we also want to wake in other situations?
349            self.progress_notify.wait().await;
350        }
351
352        Ok(ControlFlow::Continue(()))
353    }
354
355    pub(crate) async fn with_runner<F, R>(&self, f: F) -> R
356        where F: FnOnce(&mut Runner) -> R {
357        let mut inner = self.inner.lock().await;
358        f(&mut inner.runner)
359    }
360
361    /// helper to perform a function on the `inner`, returning a `Poll` value
362    async fn poll_inner<F, T>(&self, mut f: F) -> T
363        where F: FnMut(&mut Inner, &mut Context) -> Poll<T> {
364        poll_fn(|cx| {
365            // Attempt to lock .inner
366            let i = self.inner.lock();
367            pin_mut!(i);
368            match i.poll(cx) {
369                Poll::Ready(mut inner) => {
370                    f(&mut inner, cx)
371                }
372                Poll::Pending => {
373                    // .inner lock is busy
374                    Poll::Pending
375                }
376            }
377        })
378        .await
379    }
380
381    pub async fn output(&self, buf: &mut [u8]) -> Result<usize> {
382        self.poll_inner(|inner, cx| {
383            match inner.runner.output(buf) {
384                // no output ready
385                Ok(0) => {
386                    inner.runner.set_output_waker(cx.waker());
387                    Poll::Pending
388                }
389                Ok(n) => Poll::Ready(Ok(n)),
390                Err(e) => Poll::Ready(Err(e)),
391            }
392        }).await
393    }
394
395    pub async fn input(&self, buf: &[u8]) -> Result<usize> {
396        self.poll_inner(|inner, cx| {
397            if inner.runner.is_input_ready() {
398                let r = match inner.runner.input(buf) {
399                    Ok(0) => {
400                        inner.runner.set_input_waker(cx.waker());
401                        Poll::Pending
402                    },
403                    Ok(n) => Poll::Ready(Ok(n)),
404                    Err(e) => Poll::Ready(Err(e)),
405                };
406                if r.is_ready() {
407                    self.wake_progress()
408                }
409                r
410            } else {
411                inner.runner.set_input_waker(cx.waker());
412                Poll::Pending
413            }
414        }).await
415    }
416
417    /// Reads channel data.
418    pub(crate) async fn read_channel(&self, num: ChanNum, dt: ChanData, buf: &mut [u8]) -> Result<usize> {
419        if num.0 as usize > MAX_CHANNELS {
420            return sunset::error::BadChannel { num }.fail()
421        }
422        self.poll_inner(|inner, cx| {
423            let (runner, h, wakers) = inner.fetch(num)?;
424            let i = match runner.channel_input(h, dt, buf) {
425                Ok(0) => {
426                    // 0 bytes read, pending
427                    match dt {
428                        ChanData::Normal => {
429                            wakers.chan_read[num.0 as usize].register(cx.waker());
430                        }
431                        ChanData::Stderr => {
432                            wakers.chan_ext[num.0 as usize].register(cx.waker());
433                        }
434                    }
435                    Poll::Pending
436                }
437                Err(Error::ChannelEOF) => {
438                    Poll::Ready(Ok(0))
439                }
440                r => Poll::Ready(r),
441            };
442            if matches!(i, Poll::Ready(_)) {
443                self.wake_progress()
444            }
445            i
446        }).await
447    }
448
449    pub(crate) async fn write_channel(&self, num: ChanNum, dt: ChanData, buf: &[u8]) -> Result<usize> {
450        if num.0 as usize > MAX_CHANNELS {
451            return sunset::error::BadChannel { num }.fail()
452        }
453        self.poll_inner(|inner, cx| {
454            let (runner, h, wakers) = inner.fetch(num)?;
455            let l = runner.channel_send(h, dt, buf);
456            if let Ok(0) = l {
457                // 0 bytes written, pending
458                match dt {
459                    ChanData::Normal => {
460                        wakers.chan_write[num.0 as usize].register(cx.waker());
461                    }
462                    ChanData::Stderr => {
463                        wakers.chan_ext[num.0 as usize].register(cx.waker());
464                    }
465                }
466                Poll::Pending
467            } else {
468                self.wake_progress();
469                Poll::Ready(l)
470            }
471        }).await
472    }
473
474    pub(crate) async fn until_channel_closed(&self, num: ChanNum) -> Result<()> {
475        self.poll_inner(|inner, cx| {
476            let (runner, h, wakers) = inner.fetch(num)?;
477            if runner.is_channel_closed(h) {
478                Poll::Ready(Ok(()))
479            } else {
480                wakers.chan_close[num.0 as usize].register(cx.waker());
481                Poll::Pending
482            }
483        }).await
484    }
485
486    pub async fn term_window_change(&self, num: ChanNum, winch: sunset::packets::WinChange) -> Result<()> {
487        let mut inner = self.inner.lock().await;
488        let (runner, h, _) = inner.fetch(num)?;
489        runner.term_window_change(h, winch)
490    }
491
492    /// Adds a new channel handle provided by sunset core.
493    ///
494    /// EmbassySunset will take ownership of the handle. An initial refcount
495    /// must be provided, this will match the number of ChanIO that
496    /// will be created. (A zero initial refcount would be prone to immediate
497    /// garbage collection).
498    /// ChanIO will take care of `inc_chan()` on clone, `dec_chan()` on drop.
499    pub(crate) async fn add_channel(&self, handle: ChanHandle, init_refcount: usize) -> Result<()> {
500        let mut inner = self.inner.lock().await;
501        let idx = handle.num().0 as usize;
502        if inner.chan_handles[idx].is_some() {
503            return error::Bug.fail()
504        }
505
506        debug_assert_eq!(self.chan_refcounts[idx].load(Relaxed), 0);
507
508        inner.chan_handles[idx] = Some(handle);
509        self.chan_refcounts[idx].store(init_refcount, Relaxed);
510        Ok(())
511    }
512
513    pub(crate) fn inc_chan(&self, num: ChanNum) {
514        let c = self.chan_refcounts[num.0 as usize].fetch_add(1, SeqCst);
515        debug_assert_ne!(c, 0);
516        // overflow shouldn't be possible unless ChanIn etc is leaking
517        debug_assert_ne!(c, usize::MAX);
518        // perhaps not necessary? is cheap?
519        self.wake_progress();
520    }
521
522    pub(crate) fn dec_chan(&self, num: ChanNum) {
523        // refcounts that hit zero will be cleaned up later in clear_refcounts()
524        let c = self.chan_refcounts[num.0 as usize].fetch_sub(1, SeqCst);
525        debug_assert_ne!(c, 0);
526        // perhaps not necessary? is cheap?
527        self.wake_progress();
528    }
529}
530
531
532pub async fn io_copy<const B: usize, R, W>(r: &mut R, w: &mut W) -> Result<()>
533    where R: Read<Error=sunset::Error>,
534        W: Write<Error=sunset::Error>
535{
536    let mut b = [0u8; B];
537    loop {
538        let n = r.read(&mut b).await?;
539        if n == 0 {
540            return sunset::error::ChannelEOF.fail();
541        }
542        let b = &b[..n];
543        w.write_all(b).await?
544    }
545    #[allow(unreachable_code)]
546    Ok::<_, Error>(())
547}
548
549pub async fn io_copy_nowriteerror<const B: usize, R, W>(r: &mut R, w: &mut W) -> Result<()>
550    where R: Read<Error=sunset::Error>,
551        W: Write,
552{
553    let mut b = [0u8; B];
554    loop {
555        let n = r.read(&mut b).await?;
556        if n == 0 {
557            return sunset::error::ChannelEOF.fail();
558        }
559        let b = &b[..n];
560        if let Err(_) = w.write_all(b).await {
561            info!("write error");
562        }
563    }
564    #[allow(unreachable_code)]
565    Ok::<_, Error>(())
566}
567
568pub async fn io_buf_copy<R, W>(r: &mut R, w: &mut W) -> Result<()>
569    where R: BufRead<Error=sunset::Error>,
570        W: Write<Error=sunset::Error>
571{
572    loop {
573        let b = r.fill_buf().await?;
574        if b.is_empty() {
575            return sunset::error::ChannelEOF.fail();
576        }
577        let n = b.len();
578        w.write_all(b).await?;
579        r.consume(n)
580    }
581    #[allow(unreachable_code)]
582    Ok::<_, Error>(())
583}
584
585pub async fn io_buf_copy_noreaderror<R, W>(r: &mut R, w: &mut W) -> Result<()>
586    where R: BufRead,
587        W: Write<Error=sunset::Error>
588{
589    loop {
590        let b = match r.fill_buf().await {
591            Ok(b) => b,
592            Err(_) => {
593                info!("read error");
594                embassy_futures::yield_now().await;
595                continue;
596            }
597        };
598        if b.is_empty() {
599            return sunset::error::ChannelEOF.fail();
600        }
601        let n = b.len();
602        w.write_all(b).await?;
603        r.consume(n)
604    }
605    #[allow(unreachable_code)]
606    Ok::<_, Error>(())
607}