sunset_async/
async_sunset.rs

1#[allow(unused_imports)]
2pub use log::{debug, error, info, log, trace, warn};
3
4use core::future::{poll_fn, Future};
5use core::sync::atomic::AtomicBool;
6use core::sync::atomic::Ordering::{Relaxed, SeqCst};
7use core::task::{Context, Poll, Poll::Pending, Poll::Ready};
8
9use embassy_futures::join;
10use embassy_futures::select::select;
11#[allow(unused_imports)]
12use embassy_sync::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex};
13use embassy_sync::mutex::{Mutex, MutexGuard};
14use embassy_sync::signal::Signal;
15use embassy_sync::waitqueue::WakerRegistration;
16use embedded_io_async::{BufRead, Read, Write};
17
18use pin_utils::pin_mut;
19
20use sunset::config::MAX_CHANNELS;
21use sunset::event::Event;
22use sunset::{error, ChanData, ChanHandle, ChanNum, CliServ, Error, Result, Runner};
23
24#[cfg(feature = "multi-thread")]
25pub type SunsetRawMutex = CriticalSectionRawMutex;
26#[cfg(not(feature = "multi-thread"))]
27pub type SunsetRawMutex = NoopRawMutex;
28
29pub type SunsetMutex<T> = Mutex<SunsetRawMutex, T>;
30
/// Per-channel waker storage, one registration slot per possible channel.
///
/// Wakers are registered by pending channel operations and woken from
/// `wake_channels()` when the runner reports readiness.
#[derive(Debug)]
struct Wakers {
    // Woken when normal channel data is readable.
    chan_read: [WakerRegistration; MAX_CHANNELS],

    // Woken when normal channel data may be written.
    chan_write: [WakerRegistration; MAX_CHANNELS],

    /// Will be a stderr read waker for a client, or stderr write waker for
    /// a server.
    chan_ext: [WakerRegistration; MAX_CHANNELS],

    // TODO: do we need a separate waker for this?
    // Woken when a channel becomes closed.
    chan_close: [WakerRegistration; MAX_CHANNELS],
}
44
/// Mutable session state, guarded by the `AsyncSunset::inner` mutex.
struct Inner<'a, CS: CliServ> {
    // The sunset core protocol state machine.
    runner: Runner<'a, CS>,

    // Wakers for tasks blocked on per-channel operations.
    wakers: Wakers,

    // May only be safely modified when the corresponding
    // `chan_refcounts` is zero.
    // Channel handles owned on behalf of ChanIO users, indexed by
    // channel number.
    chan_handles: [Option<ChanHandle>; MAX_CHANNELS],
}
54
55impl<'a, CS: CliServ> Inner<'a, CS> {
56    /// Helper to lookup the corresponding ChanHandle
57    ///
58    /// Returns split references that will be required by many callers
59    fn fetch(
60        &mut self,
61        num: ChanNum,
62    ) -> Result<(&mut Runner<'a, CS>, &ChanHandle, &mut Wakers)> {
63        let h = self
64            .chan_handles
65            .get(num.0 as usize)
66            .ok_or(Error::BadChannel { num })?;
67        h.as_ref()
68            .map(|ch| (&mut self.runner, ch, &mut self.wakers))
69            .ok_or_else(Error::bug)
70    }
71}
72
/// A handle used for storage from a [`SSHClient::progress()`](crate::SSHClient::progress)
/// or [`SSHServer::progress()`](crate::SSHServer::progress) call.
pub struct ProgressHolder<'g, 'a, CS: CliServ> {
    // Holds the inner mutex guard while a returned `Event` borrows from it.
    guard: Option<MutexGuard<'g, SunsetRawMutex, Inner<'a, CS>>>,
}
78
impl<'g, 'a, CS: CliServ> ProgressHolder<'g, 'a, CS> {
    /// Creates an empty holder, with no mutex guard held.
    pub fn new() -> Self {
        Self { guard: None }
    }
}
84
impl<CS: CliServ> Default for ProgressHolder<'_, '_, CS> {
    /// Equivalent to [`ProgressHolder::new()`].
    fn default() -> Self {
        Self::new()
    }
}
90
/// Provides an async wrapper for Sunset core
///
/// A [`ChanHandle`] provided by sunset core must be added with [`add_channel()`] before
/// a method can be called with the equivalent ChanNum.
///
/// Applications use `async_sunset::{Client,Server}`.
pub(crate) struct AsyncSunset<'a, CS: CliServ> {
    // Runner, wakers and channel handles, behind one lock.
    inner: SunsetMutex<Inner<'a, CS>>,

    // Signalled by wake_progress() to run another progress() iteration.
    progress_notify: Signal<SunsetRawMutex, ()>,
    // Set when the last progress() call returned Event::None, so the next
    // call waits on progress_notify first.
    last_progress_idled: AtomicBool,

    // wake_progress() should be called after modifying these atomics, to
    // trigger the progress loop to handle state changes

    // When draining the last events
    moribund: AtomicBool,

    // Refcount for `Inner::chan_handles`. Must be non-async so it can be
    // decremented on `ChanIn::drop()` etc.
    // The pending chan_refcount=0 handling occurs in the `progress()` loop.
    //
    // thumbv6m has no atomic usize add/sub.
    chan_refcounts: [portable_atomic::AtomicUsize; MAX_CHANNELS],
}
116
117impl<CS: CliServ> core::fmt::Debug for AsyncSunset<'_, CS> {
118    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
119        let mut d = f.debug_struct("AsyncSunset");
120        if let Ok(i) = self.inner.try_lock() {
121            d.field("runner", &i.runner);
122        } else {
123            d.field("inner", &"(locked)");
124        }
125        d.finish_non_exhaustive()
126    }
127}
128
129impl<'a, CS: CliServ> AsyncSunset<'a, CS> {
130    pub fn new(runner: Runner<'a, CS>) -> Self {
131        let wakers = Wakers {
132            chan_read: Default::default(),
133            chan_write: Default::default(),
134            chan_ext: Default::default(),
135            chan_close: Default::default(),
136        };
137        let inner = Inner { runner, wakers, chan_handles: Default::default() };
138        let inner = Mutex::new(inner);
139
140        let progress_notify = Signal::new();
141
142        Self {
143            inner,
144            moribund: AtomicBool::new(false),
145            progress_notify,
146            chan_refcounts: Default::default(),
147            last_progress_idled: AtomicBool::new(false),
148        }
149    }
150
    /// Runs the session to completion
    ///
    /// Drives two concurrent loops: `tx` writes runner output to `wsock`,
    /// `rx` reads from `rsock` and feeds the runner. Returns once both
    /// loops have stopped.
    pub async fn run(
        &self,
        rsock: &mut impl Read,
        wsock: &mut impl Write,
    ) -> Result<()> {
        // Some loops need to terminate other loops on completion.
        // prog finish -> stop rx
        // rx finish -> stop tx
        let tx_stop = Signal::<SunsetRawMutex, ()>::new();
        let rx_stop = Signal::<SunsetRawMutex, ()>::new();

        // Transmit loop: runner output -> socket, until told to stop.
        let tx = async {
            let r = self
                .output_loop(wsock)
                .await
                .inspect(|r| warn!("tx complete {r:?}"));
            r
        };
        let tx = select(tx, tx_stop.wait());

        // rxbuf outside the async block avoids an extraneous copy somehow
        let mut rxbuf = [0; 1024];
        let rx = async {
            loop {
                // TODO: make sunset read directly from socket, no intermediate buffer.
                let l = match rsock.read(&mut rxbuf).await {
                    Ok(0) => {
                        // Clean EOF from the network: let the runner drain
                        // and flag shutdown for the progress loop.
                        debug!("net EOF");
                        self.with_runner(|r| r.close_input()).await;
                        self.moribund.store(true, Relaxed);
                        self.wake_progress();
                        break Ok(());
                    }
                    Ok(l) => l,
                    Err(_) => {
                        info!("socket read error");
                        self.with_runner(|r| r.close_input()).await;
                        break Err(Error::ChannelEOF);
                    }
                };
                // Feed the received bytes to the runner, possibly in
                // several chunks if the runner can't take them all at once.
                let mut rxbuf = &rxbuf[..l];
                while !rxbuf.is_empty() {
                    let n = self.input(rxbuf).await?;
                    self.wake_progress();
                    rxbuf = &rxbuf[n..];
                }
            }
            .inspect(|r| warn!("rx complete {r:?}"))
        };

        // TODO: if RX fails (bad decrypt etc) it doesn't cancel prog, so gets stuck
        let rx = async {
            let r = select(rx, rx_stop.wait()).await;
            // rx finished (either way), so stop the tx loop too.
            tx_stop.signal(());
            r
        };

        // TODO: we might want to let `prog` run until buffers are drained
        // in case a disconnect message was received.
        // TODO Is there a nice way than this?
        let f = join::join(rx, tx).await;
        let (_frx, _ftx) = f;

        // debug!("frx {_frx:?}");
        // debug!("ftx {_ftx:?}");

        // TODO: is this a good way to do cancellation...?
        // self.with_runner(|runner| runner.close()).await;
        // // Wake any channels that were awoken after the runner closed
        // let mut inner = self.inner.lock().await;
        // self.wake_channels(&mut inner)?;
        Ok(())
    }
225
    /// Signals the progress loop to run another iteration.
    ///
    /// Called after anything that may have changed runner state (input fed,
    /// channel data produced/consumed, a refcount reaching zero).
    fn wake_progress(&self) {
        self.progress_notify.signal(())
    }
229
    /// Wakes any channel tasks whose pending operation is now ready.
    ///
    /// Called from `progress()` with the inner lock held.
    fn wake_channels(&self, inner: &mut Inner<CS>) -> Result<()> {
        // Read wakers
        let w = &mut inner.wakers;
        if let Some((num, dt, _len)) = inner.runner.read_channel_ready() {
            let waker = match dt {
                ChanData::Normal => &mut w.chan_read[num.0 as usize],
                ChanData::Stderr => &mut w.chan_ext[num.0 as usize],
            };
            if waker.occupied() {
                waker.wake();
            } else {
                // No waker waiting for this packet, so drop it.
                // This avoids the case where for example a client application
                // is only reading from a Stdin ChanIn, but some data arrives
                // over the write for Stderr. Something needs to mark it done,
                // since the session can't proceed until it's consumed.
                if let Some(h) = &inner.chan_handles[num.0 as usize] {
                    inner.runner.discard_read_channel(h)?
                }
            }
        }

        for (idx, c) in inner.chan_handles.iter().enumerate() {
            let ch = if let Some(ch) = c.as_ref() { ch } else { continue };

            // Write wakers

            // TODO: if this is slow we could be smarter about aggregating dt vs standard,
            // or handling the case of full out payload buffers.
            if inner.runner.write_channel_ready(ch, ChanData::Normal)?.unwrap_or(0)
                > 0
            {
                w.chan_write[idx].wake()
            }

            // Stderr writing only applies on the server side.
            if !CS::is_client()
                && inner
                    .runner
                    .write_channel_ready(ch, ChanData::Stderr)?
                    .unwrap_or(0)
                    > 0
            {
                w.chan_ext[idx].wake()
            }

            // TODO: do we want to keep waking it?
            // Wake readers at EOF so they can observe the 0-length read.
            if inner.runner.is_channel_eof(ch) {
                w.chan_read[idx].wake();
                if CS::is_client() {
                    w.chan_ext[idx].wake();
                }
            }

            if inner.runner.is_channel_closed(ch) {
                w.chan_close[idx].wake();
            }
        }
        Ok(())
    }
289
290    /// Check for channels that have reached zero refcount
291    ///
292    /// When a ChanIO is dropped the refcount may reach 0, but
293    /// without "async Drop" it isn't possible to take the `inner` lock during
294    /// `drop()`.
295    /// Instead this runs periodically from an async context to release channels.
296    fn clear_refcounts(&self, inner: &mut Inner<CS>) -> Result<()> {
297        for (ch, count) in
298            inner.chan_handles.iter_mut().zip(self.chan_refcounts.iter())
299        {
300            let count = count.load(Relaxed);
301            if count > 0 {
302                debug_assert!(ch.is_some());
303                continue;
304            }
305            if let Some(ch) = ch.take() {
306                // done with the channel
307                inner.runner.channel_done(ch)?;
308            }
309        }
310        Ok(())
311    }
312
    /// Returns an `Event`.
    ///
    /// The returned `Event` borrows from the mutex locked in `ph`.
    ///
    /// Waits for a wakeup if the previous call was idle, then takes the
    /// inner lock, releases finished channels, wakes ready channel tasks,
    /// and asks the runner for the next event.
    pub(crate) async fn progress<'g, 'f>(
        &'g self,
        ph: &'f mut ProgressHolder<'g, 'a, CS>,
    ) -> Result<Event<'f, 'a>> {
        // In case a ProgressHolder was reused, release any guard.
        *ph = ProgressHolder::default();

        // Ideally we would .wait() after calling .progress() below when
        // Event::None is returned, but the borrow checker won't allow that.
        // Instead we wait at the start of the next progress() call,
        // but will return immediately if something external
        // has woken the progress_notify in the interim.
        //
        // TODO: rework once rustc's polonius is stable.
        // https://github.com/rust-lang/rust/issues/54663
        //
        // This is a non-atomic swap since thumbv6m won't support it.
        // Only one task should be calling progress(), so that's OK.
        let need_wait = self.last_progress_idled.load(Relaxed);
        if need_wait {
            self.last_progress_idled.store(false, Relaxed);
            self.progress_notify.wait().await;
        }

        // The returned event borrows from a guard inside ProgressHolder
        let inner = ph.guard.insert(self.inner.lock().await);

        // Drop deferred finished channels
        self.clear_refcounts(inner)?;
        // Wake channels
        self.wake_channels(inner)?;

        if self.moribund.load(Relaxed) {
            // if we're flushing, we exit once there is no progress
            debug!("All data flushed")
            // TODO make this do something!
        }

        let ev = inner.runner.progress();
        if matches!(ev, Ok(Event::None)) {
            // nothing happened, will progress_notify.wait() next progress() call, see above.
            self.last_progress_idled.store(true, Relaxed);
        }
        ev
    }
361
362    pub(crate) async fn with_runner<F, R>(&self, f: F) -> R
363    where
364        F: FnOnce(&mut Runner<CS>) -> R,
365    {
366        let mut inner = self.inner.lock().await;
367        f(&mut inner.runner)
368    }
369
370    /// helper to perform a function on the `inner`, returning a `Poll` value
371    async fn poll_inner<F, T>(&self, mut f: F) -> T
372    where
373        F: FnMut(&mut Inner<CS>, &mut Context) -> Poll<T>,
374    {
375        poll_fn(|cx| {
376            // Attempt to lock .inner
377            let i = self.inner.lock();
378            pin_mut!(i);
379            match i.poll(cx) {
380                Poll::Ready(mut inner) => f(&mut inner, cx),
381                Poll::Pending => {
382                    // .inner lock is busy
383                    Poll::Pending
384                }
385            }
386        })
387        .await
388    }
389
390    pub async fn output_loop(&self, wsock: &mut impl Write) -> Result<()> {
391        poll_fn(|cx| {
392            // Attempt to lock .inner
393            let i = self.inner.lock();
394            pin_mut!(i);
395            let Ready(mut inner) = i.poll(cx) else {
396                return Pending;
397            };
398
399            loop {
400                let buf = inner.runner.output_buf();
401                if buf.is_empty() {
402                    // no output ready
403                    inner.runner.set_output_waker(cx.waker());
404                    return Pending;
405                }
406
407                let res = {
408                    let w = wsock.write(buf);
409                    pin_mut!(w);
410                    w.poll(cx)
411                };
412
413                return match res {
414                    Pending => Pending,
415                    Ready(Ok(0)) => {
416                        info!("socket EOF");
417                        inner.runner.close_output();
418                        Ready(error::ChannelEOF.fail())
419                    }
420                    Ready(Ok(write_len)) => {
421                        let buf_len = buf.len();
422                        inner.runner.consume_output(write_len);
423                        if write_len < buf_len {
424                            // Must keep going until either wsock
425                            // or output_buf returns Pending and
426                            // registers a waker.
427                            continue;
428                        }
429                        Pending
430                    }
431                    Ready(Err(_e)) => {
432                        info!("socket write error");
433                        inner.runner.close_output();
434                        Ready(error::ChannelEOF.fail())
435                    }
436                };
437            }
438        })
439        .await
440    }
441
442    pub async fn input(&self, buf: &[u8]) -> Result<usize> {
443        let res = self
444            .poll_inner(|inner, cx| {
445                if inner.runner.is_input_ready() {
446                    match inner.runner.input(buf) {
447                        Ok(0) => {
448                            inner.runner.set_input_waker(cx.waker());
449                            Poll::Pending
450                        }
451                        Ok(n) => Poll::Ready(Ok(n)),
452                        Err(e) => Poll::Ready(Err(e)),
453                    }
454                } else {
455                    inner.runner.set_input_waker(cx.waker());
456                    Poll::Pending
457                }
458            })
459            .await;
460        self.wake_progress();
461        res
462    }
463
464    /// Adds a new channel handle provided by sunset core.
465    ///
466    /// AsyncSunset will take ownership of the handle. An initial refcount
467    /// must be provided, this will match the number of ChanIO that
468    /// will be created. (A zero initial refcount would be prone to immediate
469    /// garbage collection).
470    /// ChanIO will take care of `inc_chan()` on clone, `dec_chan()` on drop.
471    pub(crate) async fn add_channel(
472        &self,
473        handle: ChanHandle,
474        init_refcount: usize,
475    ) -> Result<()> {
476        let mut inner = self.inner.lock().await;
477        let idx = handle.num().0 as usize;
478        if inner.chan_handles[idx].is_some() {
479            return error::Bug.fail();
480        }
481
482        debug_assert_eq!(self.chan_refcounts[idx].load(Relaxed), 0);
483
484        inner.chan_handles[idx] = Some(handle);
485        self.chan_refcounts[idx].store(init_refcount, Relaxed);
486        Ok(())
487    }
488}
489
// necessary for the &dyn ChanCore
/// With the "multi-thread" feature the trait object may be shared across
/// threads, so `Sync` is required; otherwise no bound is needed.
#[cfg(feature = "multi-thread")]
pub(crate) trait MaybeSend: Sync {}
#[cfg(not(feature = "multi-thread"))]
pub(crate) trait MaybeSend {}

impl<'a, CS: CliServ> MaybeSend for AsyncSunset<'a, CS> {}
497
// Ideally the poll_...() methods would be async, but that isn't
// dyn compatible at present. Instead run poll_fn in the ChanIO caller.
/// Channel operations used by ChanIO through a `&dyn ChanCore`.
pub(crate) trait ChanCore: MaybeSend {
    /// Increments a channel's refcount (called on ChanIO clone).
    fn inc_chan(&self, num: ChanNum);
    /// Decrements a channel's refcount (called on ChanIO drop).
    fn dec_chan(&self, num: ChanNum);

    /// Completes once the channel is closed.
    fn poll_until_channel_closed(
        &self,
        cx: &mut Context,
        num: ChanNum,
    ) -> Poll<Result<()>>;

    /// Reads channel data into `buf`.
    fn poll_read_channel(
        &self,
        cx: &mut Context,
        num: ChanNum,
        dt: ChanData,
        buf: &mut [u8],
    ) -> Poll<Result<usize>>;

    /// Writes channel data from `buf`.
    fn poll_write_channel(
        &self,
        cx: &mut Context,
        num: ChanNum,
        dt: ChanData,
        buf: &[u8],
    ) -> Poll<Result<usize>>;

    /// Sends a terminal window-size change report on the channel.
    fn poll_term_window_change(
        &self,
        cx: &mut Context,
        num: ChanNum,
        winch: &sunset::packets::WinChange,
    ) -> Poll<Result<()>>;
}
533
534impl<'a, CS: CliServ> ChanCore for AsyncSunset<'a, CS> {
535    fn inc_chan(&self, num: ChanNum) {
536        let c = self.chan_refcounts[num.0 as usize].fetch_add(1, SeqCst);
537        debug_assert_ne!(c, 0);
538        // overflow shouldn't be possible unless ChanIn etc is leaking
539        debug_assert_ne!(c, usize::MAX);
540    }
541
542    fn dec_chan(&self, num: ChanNum) {
543        // refcounts that hit zero will be cleaned up later in clear_refcounts()
544        let c = self.chan_refcounts[num.0 as usize].fetch_sub(1, SeqCst);
545        debug_assert_ne!(c, 0);
546        if c == 1 {
547            // refcount hit zero, progress() will clean it up
548            // in an async context
549            self.wake_progress();
550        }
551    }
552
553    fn poll_until_channel_closed(
554        &self,
555        cx: &mut Context,
556        num: ChanNum,
557    ) -> Poll<Result<()>> {
558        // Attempt to lock .inner
559        let i = self.inner.lock();
560        pin_mut!(i);
561        let Ready(mut inner) = i.poll(cx) else {
562            return Pending;
563        };
564
565        let (runner, h, wakers) = inner.fetch(num)?;
566        if runner.is_channel_closed(h) {
567            Poll::Ready(Ok(()))
568        } else {
569            wakers.chan_close[num.0 as usize].register(cx.waker());
570            Poll::Pending
571        }
572    }
573
    /// Reads channel data.
    ///
    /// Returns `Ready(Ok(0))` at channel EOF; pends (registering the
    /// matching read waker) while no data is available.
    fn poll_read_channel(
        &self,
        cx: &mut Context,
        num: ChanNum,
        dt: ChanData,
        buf: &mut [u8],
    ) -> Poll<Result<usize>> {
        // Attempt to lock .inner
        let i = self.inner.lock();
        pin_mut!(i);
        let Ready(mut inner) = i.poll(cx) else {
            return Pending;
        };

        let (runner, h, wakers) = inner.fetch(num)?;
        let i = match runner.read_channel(h, dt, buf) {
            Ok(0) => {
                // 0 bytes read, pending
                match dt {
                    ChanData::Normal => {
                        wakers.chan_read[num.0 as usize].register(cx.waker());
                    }
                    ChanData::Stderr => {
                        wakers.chan_ext[num.0 as usize].register(cx.waker());
                    }
                }
                Poll::Pending
            }
            // EOF is surfaced to the caller as a 0-length read.
            Err(Error::ChannelEOF) => Poll::Ready(Ok(0)),
            r => Poll::Ready(r),
        };
        if matches!(i, Poll::Ready(_)) {
            // Consuming data may unblock the runner.
            self.wake_progress()
        }
        i
    }
611
    /// Writes channel data.
    ///
    /// Pends (registering the matching write waker) while the channel
    /// cannot accept any bytes.
    fn poll_write_channel(
        &self,
        cx: &mut Context,
        num: ChanNum,
        dt: ChanData,
        buf: &[u8],
    ) -> Poll<Result<usize>> {
        // Attempt to lock .inner
        let i = self.inner.lock();
        pin_mut!(i);
        let Ready(mut inner) = i.poll(cx) else {
            return Pending;
        };

        let (runner, h, wakers) = inner.fetch(num)?;
        let l = runner.write_channel(h, dt, buf);
        if let Ok(0) = l {
            // 0 bytes written, pending
            match dt {
                ChanData::Normal => {
                    wakers.chan_write[num.0 as usize].register(cx.waker());
                }
                ChanData::Stderr => {
                    wakers.chan_ext[num.0 as usize].register(cx.waker());
                }
            }
            Poll::Pending
        } else {
            // Queued data gives the progress loop work to do.
            self.wake_progress();
            Poll::Ready(l)
        }
    }
644
645    fn poll_term_window_change(
646        &self,
647        cx: &mut Context,
648        num: ChanNum,
649        winch: &sunset::packets::WinChange,
650    ) -> Poll<Result<()>> {
651        // Attempt to lock .inner
652        let i = self.inner.lock();
653        pin_mut!(i);
654        let Ready(mut inner) = i.poll(cx) else {
655            return Pending;
656        };
657        let (runner, h, _) = inner.fetch(num)?;
658        Poll::Ready(runner.term_window_change(h, winch))
659    }
660}
661
662pub async fn io_copy<const B: usize, R, W>(r: &mut R, w: &mut W) -> Result<()>
663where
664    R: Read<Error = sunset::Error>,
665    W: Write<Error = sunset::Error>,
666{
667    let mut b = [0u8; B];
668    loop {
669        let n = r.read(&mut b).await?;
670        if n == 0 {
671            return sunset::error::ChannelEOF.fail();
672        }
673        let b = &b[..n];
674        w.write_all(b).await?
675    }
676    #[allow(unreachable_code)]
677    Ok::<_, Error>(())
678}
679
680pub async fn io_copy_nowriteerror<const B: usize, R, W>(
681    r: &mut R,
682    w: &mut W,
683) -> Result<()>
684where
685    R: Read<Error = sunset::Error>,
686    W: Write,
687{
688    let mut b = [0u8; B];
689    loop {
690        let n = r.read(&mut b).await?;
691        if n == 0 {
692            return sunset::error::ChannelEOF.fail();
693        }
694        let b = &b[..n];
695        if let Err(_) = w.write_all(b).await {
696            info!("write error");
697        }
698    }
699    #[allow(unreachable_code)]
700    Ok::<_, Error>(())
701}
702
703pub async fn io_buf_copy<R, W>(r: &mut R, w: &mut W) -> Result<()>
704where
705    R: BufRead<Error = sunset::Error>,
706    W: Write<Error = sunset::Error>,
707{
708    loop {
709        let b = r.fill_buf().await?;
710        if b.is_empty() {
711            return sunset::error::ChannelEOF.fail();
712        }
713        let n = b.len();
714        w.write_all(b).await?;
715        r.consume(n)
716    }
717    #[allow(unreachable_code)]
718    Ok::<_, Error>(())
719}
720
721pub async fn io_buf_copy_noreaderror<R, W>(r: &mut R, w: &mut W) -> Result<()>
722where
723    R: BufRead,
724    W: Write<Error = sunset::Error>,
725{
726    loop {
727        let b = match r.fill_buf().await {
728            Ok(b) => b,
729            Err(_) => {
730                info!("read error");
731                embassy_futures::yield_now().await;
732                continue;
733            }
734        };
735        if b.is_empty() {
736            return sunset::error::ChannelEOF.fail();
737        }
738        let n = b.len();
739        w.write_all(b).await?;
740        r.consume(n)
741    }
742    #[allow(unreachable_code)]
743    Ok::<_, Error>(())
744}