1#[allow(unused_imports)]
2pub use log::{debug, error, info, log, trace, warn};
3
4use core::future::{poll_fn, Future};
5use core::sync::atomic::AtomicBool;
6use core::sync::atomic::Ordering::{Relaxed, SeqCst};
7use core::task::{Context, Poll, Poll::Pending, Poll::Ready};
8
9use embassy_futures::join;
10use embassy_futures::select::select;
11#[allow(unused_imports)]
12use embassy_sync::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex};
13use embassy_sync::mutex::{Mutex, MutexGuard};
14use embassy_sync::signal::Signal;
15use embassy_sync::waitqueue::WakerRegistration;
16use embedded_io_async::{BufRead, Read, Write};
17
18use pin_utils::pin_mut;
19
20use sunset::config::MAX_CHANNELS;
21use sunset::event::Event;
22use sunset::{error, ChanData, ChanHandle, ChanNum, CliServ, Error, Result, Runner};
23
/// Raw mutex flavour used by every lock and signal in this module.
///
/// With the `multi-thread` feature enabled this is [`CriticalSectionRawMutex`].
#[cfg(feature = "multi-thread")]
pub type SunsetRawMutex = CriticalSectionRawMutex;
/// Raw mutex flavour used by every lock and signal in this module.
///
/// Without the `multi-thread` feature this is [`NoopRawMutex`].
#[cfg(not(feature = "multi-thread"))]
pub type SunsetRawMutex = NoopRawMutex;

/// An embassy [`Mutex`] parameterised with [`SunsetRawMutex`].
pub type SunsetMutex<T> = Mutex<SunsetRawMutex, T>;
30
/// Per-channel waker slots, each array indexed by channel number.
#[derive(Debug)]
struct Wakers {
    /// Tasks waiting to read normal channel data.
    /// Woken when normal data is ready to read or the channel reaches EOF.
    chan_read: [WakerRegistration; MAX_CHANNELS],

    /// Tasks waiting to write normal channel data.
    /// Woken when the runner reports space to write normal data.
    chan_write: [WakerRegistration; MAX_CHANNELS],

    /// Tasks waiting on stderr data.
    /// Registered by both stderr reads and stderr writes; woken for stderr
    /// read data, for stderr write space when not a client, and for EOF
    /// when a client.
    chan_ext: [WakerRegistration; MAX_CHANNELS],

    /// Tasks waiting for the channel to be closed.
    chan_close: [WakerRegistration; MAX_CHANNELS],
}
44
/// State protected by the `AsyncSunset::inner` mutex.
struct Inner<'a, CS: CliServ> {
    /// The underlying SSH state machine.
    runner: Runner<'a, CS>,

    /// Wakers for tasks blocked on channel I/O.
    wakers: Wakers,

    /// Handles of the channels currently in use, indexed by channel number.
    /// A slot is filled by `add_channel()` and emptied by `clear_refcounts()`
    /// once the matching reference count is zero.
    chan_handles: [Option<ChanHandle>; MAX_CHANNELS],
}
54
55impl<'a, CS: CliServ> Inner<'a, CS> {
56 fn fetch(
60 &mut self,
61 num: ChanNum,
62 ) -> Result<(&mut Runner<'a, CS>, &ChanHandle, &mut Wakers)> {
63 let h = self
64 .chan_handles
65 .get(num.0 as usize)
66 .ok_or(Error::BadChannel { num })?;
67 h.as_ref()
68 .map(|ch| (&mut self.runner, ch, &mut self.wakers))
69 .ok_or_else(Error::bug)
70 }
71}
72
73pub struct ProgressHolder<'g, 'a, CS: CliServ> {
76 guard: Option<MutexGuard<'g, SunsetRawMutex, Inner<'a, CS>>>,
77}
78
79impl<'g, 'a, CS: CliServ> ProgressHolder<'g, 'a, CS> {
80 pub fn new() -> Self {
81 Self { guard: None }
82 }
83}
84
85impl<CS: CliServ> Default for ProgressHolder<'_, '_, CS> {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
/// Async wrapper around a `sunset` [`Runner`].
pub(crate) struct AsyncSunset<'a, CS: CliServ> {
    /// The runner together with its wakers and channel handles.
    inner: SunsetMutex<Inner<'a, CS>>,

    /// Signalled by `wake_progress()`; waited on by `progress()` when the
    /// previous call idled.
    progress_notify: Signal<SunsetRawMutex, ()>,
    /// Set when the last `progress()` call returned `Event::None`.
    last_progress_idled: AtomicBool,

    /// Set by `run()` once the read socket has reported EOF.
    moribund: AtomicBool,

    /// Reference count for each channel, indexed by channel number.
    /// `clear_refcounts()` releases a channel's handle when its count is zero.
    chan_refcounts: [portable_atomic::AtomicUsize; MAX_CHANNELS],
}
116
117impl<CS: CliServ> core::fmt::Debug for AsyncSunset<'_, CS> {
118 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
119 let mut d = f.debug_struct("AsyncSunset");
120 if let Ok(i) = self.inner.try_lock() {
121 d.field("runner", &i.runner);
122 } else {
123 d.field("inner", &"(locked)");
124 }
125 d.finish_non_exhaustive()
126 }
127}
128
129impl<'a, CS: CliServ> AsyncSunset<'a, CS> {
130 pub fn new(runner: Runner<'a, CS>) -> Self {
131 let wakers = Wakers {
132 chan_read: Default::default(),
133 chan_write: Default::default(),
134 chan_ext: Default::default(),
135 chan_close: Default::default(),
136 };
137 let inner = Inner { runner, wakers, chan_handles: Default::default() };
138 let inner = Mutex::new(inner);
139
140 let progress_notify = Signal::new();
141
142 Self {
143 inner,
144 moribund: AtomicBool::new(false),
145 progress_notify,
146 chan_refcounts: Default::default(),
147 last_progress_idled: AtomicBool::new(false),
148 }
149 }
150
    /// Drives socket I/O for the session until both directions have stopped.
    ///
    /// Bytes read from `rsock` are fed to the runner through [`Self::input`],
    /// while runner output is written to `wsock` through
    /// [`Self::output_loop`]. Both directions run concurrently.
    ///
    /// The outcomes of the two loops are only logged; this function returns
    /// `Ok(())` however they finished.
    pub async fn run(
        &self,
        rsock: &mut impl Read,
        wsock: &mut impl Write,
    ) -> Result<()> {
        // Signalled when the receive side finishes, to stop the transmit side.
        let tx_stop = Signal::<SunsetRawMutex, ()>::new();
        // NOTE(review): nothing in this function signals `rx_stop` - confirm
        // whether that is intended.
        let rx_stop = Signal::<SunsetRawMutex, ()>::new();

        let tx = async {
            let r = self
                .output_loop(wsock)
                .await
                .inspect(|r| warn!("tx complete {r:?}"));
            r
        };
        // The transmit side ends when it fails or when `tx_stop` fires.
        let tx = select(tx, tx_stop.wait());

        // Declared outside the async block, which borrows it.
        let mut rxbuf = [0; 1024];
        let rx = async {
            loop {
                let l = match rsock.read(&mut rxbuf).await {
                    // A zero-length read is treated as the peer closing.
                    Ok(0) => {
                        debug!("net EOF");
                        self.with_runner(|r| r.close_input()).await;
                        self.moribund.store(true, Relaxed);
                        self.wake_progress();
                        break Ok(());
                    }
                    Ok(l) => l,
                    Err(_) => {
                        info!("socket read error");
                        self.with_runner(|r| r.close_input()).await;
                        break Err(Error::ChannelEOF);
                    }
                };
                // Feed everything that was read; `input()` may accept only
                // part of the slice per call.
                let mut rxbuf = &rxbuf[..l];
                while !rxbuf.is_empty() {
                    let n = self.input(rxbuf).await?;
                    self.wake_progress();
                    rxbuf = &rxbuf[n..];
                }
            }
            .inspect(|r| warn!("rx complete {r:?}"))
        };

        let rx = async {
            let r = select(rx, rx_stop.wait()).await;
            // Receive side is done, so stop the transmit side as well.
            tx_stop.signal(());
            r
        };

        let f = join::join(rx, tx).await;
        // Results of both sides are intentionally unused.
        let (_frx, _ftx) = f;

        Ok(())
    }
225
226 fn wake_progress(&self) {
227 self.progress_notify.signal(())
228 }
229
230 fn wake_channels(&self, inner: &mut Inner<CS>) -> Result<()> {
231 let w = &mut inner.wakers;
233 if let Some((num, dt, _len)) = inner.runner.read_channel_ready() {
234 let waker = match dt {
235 ChanData::Normal => &mut w.chan_read[num.0 as usize],
236 ChanData::Stderr => &mut w.chan_ext[num.0 as usize],
237 };
238 if waker.occupied() {
239 waker.wake();
240 } else {
241 if let Some(h) = &inner.chan_handles[num.0 as usize] {
247 inner.runner.discard_read_channel(h)?
248 }
249 }
250 }
251
252 for (idx, c) in inner.chan_handles.iter().enumerate() {
253 let ch = if let Some(ch) = c.as_ref() { ch } else { continue };
254
255 if inner.runner.write_channel_ready(ch, ChanData::Normal)?.unwrap_or(0)
260 > 0
261 {
262 w.chan_write[idx].wake()
263 }
264
265 if !CS::is_client()
266 && inner
267 .runner
268 .write_channel_ready(ch, ChanData::Stderr)?
269 .unwrap_or(0)
270 > 0
271 {
272 w.chan_ext[idx].wake()
273 }
274
275 if inner.runner.is_channel_eof(ch) {
277 w.chan_read[idx].wake();
278 if CS::is_client() {
279 w.chan_ext[idx].wake();
280 }
281 }
282
283 if inner.runner.is_channel_closed(ch) {
284 w.chan_close[idx].wake();
285 }
286 }
287 Ok(())
288 }
289
290 fn clear_refcounts(&self, inner: &mut Inner<CS>) -> Result<()> {
297 for (ch, count) in
298 inner.chan_handles.iter_mut().zip(self.chan_refcounts.iter())
299 {
300 let count = count.load(Relaxed);
301 if count > 0 {
302 debug_assert!(ch.is_some());
303 continue;
304 }
305 if let Some(ch) = ch.take() {
306 inner.runner.channel_done(ch)?;
308 }
309 }
310 Ok(())
311 }
312
    /// Runs one step of the runner and returns the resulting event.
    ///
    /// The lock guard taken here is stored in `ph`, and the returned
    /// [`Event`] is tied to the borrow of `ph`. Any guard left in `ph` by an
    /// earlier call is dropped first.
    ///
    /// If the previous call returned `Event::None`, this waits for
    /// `progress_notify` to be signalled before taking the lock.
    ///
    /// # Errors
    ///
    /// Propagates errors from `clear_refcounts()`, `wake_channels()` and the
    /// runner's `progress()`.
    pub(crate) async fn progress<'g, 'f>(
        &'g self,
        ph: &'f mut ProgressHolder<'g, 'a, CS>,
    ) -> Result<Event<'f, 'a>> {
        // Drop any guard kept from a previous call.
        *ph = ProgressHolder::default();

        let need_wait = self.last_progress_idled.load(Relaxed);
        if need_wait {
            // The flag is cleared before waiting, so a call that is cancelled
            // while waiting leaves the flag unset for the next call.
            self.last_progress_idled.store(false, Relaxed);
            self.progress_notify.wait().await;
        }

        // Take the lock and keep the guard in `ph`.
        let inner = ph.guard.insert(self.inner.lock().await);

        self.clear_refcounts(inner)?;
        self.wake_channels(inner)?;

        // Only logs; `moribund` does not alter control flow here.
        if self.moribund.load(Relaxed) {
            debug!("All data flushed")
        }

        let ev = inner.runner.progress();
        if matches!(ev, Ok(Event::None)) {
            // Nothing happened, so the next call waits for a wake first.
            self.last_progress_idled.store(true, Relaxed);
        }
        ev
    }
361
362 pub(crate) async fn with_runner<F, R>(&self, f: F) -> R
363 where
364 F: FnOnce(&mut Runner<CS>) -> R,
365 {
366 let mut inner = self.inner.lock().await;
367 f(&mut inner.runner)
368 }
369
370 async fn poll_inner<F, T>(&self, mut f: F) -> T
372 where
373 F: FnMut(&mut Inner<CS>, &mut Context) -> Poll<T>,
374 {
375 poll_fn(|cx| {
376 let i = self.inner.lock();
378 pin_mut!(i);
379 match i.poll(cx) {
380 Poll::Ready(mut inner) => f(&mut inner, cx),
381 Poll::Pending => {
382 Poll::Pending
384 }
385 }
386 })
387 .await
388 }
389
390 pub async fn output_loop(&self, wsock: &mut impl Write) -> Result<()> {
391 poll_fn(|cx| {
392 let i = self.inner.lock();
394 pin_mut!(i);
395 let Ready(mut inner) = i.poll(cx) else {
396 return Pending;
397 };
398
399 loop {
400 let buf = inner.runner.output_buf();
401 if buf.is_empty() {
402 inner.runner.set_output_waker(cx.waker());
404 return Pending;
405 }
406
407 let res = {
408 let w = wsock.write(buf);
409 pin_mut!(w);
410 w.poll(cx)
411 };
412
413 return match res {
414 Pending => Pending,
415 Ready(Ok(0)) => {
416 info!("socket EOF");
417 inner.runner.close_output();
418 Ready(error::ChannelEOF.fail())
419 }
420 Ready(Ok(write_len)) => {
421 let buf_len = buf.len();
422 inner.runner.consume_output(write_len);
423 if write_len < buf_len {
424 continue;
428 }
429 Pending
430 }
431 Ready(Err(_e)) => {
432 info!("socket write error");
433 inner.runner.close_output();
434 Ready(error::ChannelEOF.fail())
435 }
436 };
437 }
438 })
439 .await
440 }
441
442 pub async fn input(&self, buf: &[u8]) -> Result<usize> {
443 let res = self
444 .poll_inner(|inner, cx| {
445 if inner.runner.is_input_ready() {
446 match inner.runner.input(buf) {
447 Ok(0) => {
448 inner.runner.set_input_waker(cx.waker());
449 Poll::Pending
450 }
451 Ok(n) => Poll::Ready(Ok(n)),
452 Err(e) => Poll::Ready(Err(e)),
453 }
454 } else {
455 inner.runner.set_input_waker(cx.waker());
456 Poll::Pending
457 }
458 })
459 .await;
460 self.wake_progress();
461 res
462 }
463
464 pub(crate) async fn add_channel(
472 &self,
473 handle: ChanHandle,
474 init_refcount: usize,
475 ) -> Result<()> {
476 let mut inner = self.inner.lock().await;
477 let idx = handle.num().0 as usize;
478 if inner.chan_handles[idx].is_some() {
479 return error::Bug.fail();
480 }
481
482 debug_assert_eq!(self.chan_refcounts[idx].load(Relaxed), 0);
483
484 inner.chan_handles[idx] = Some(handle);
485 self.chan_refcounts[idx].store(init_refcount, Relaxed);
486 Ok(())
487 }
488}
489
/// Marker trait whose bound depends on the `multi-thread` feature.
///
/// With the feature enabled, implementors must be `Sync`.
#[cfg(feature = "multi-thread")]
pub(crate) trait MaybeSend: Sync {}
/// Marker trait whose bound depends on the `multi-thread` feature.
///
/// Without the feature there is no additional bound.
#[cfg(not(feature = "multi-thread"))]
pub(crate) trait MaybeSend {}

impl<'a, CS: CliServ> MaybeSend for AsyncSunset<'a, CS> {}
497
/// Channel operations that do not depend on the `CliServ` type parameter.
///
/// Every method takes the channel number rather than a handle. The `poll_*`
/// methods follow the usual polling convention: they return `Poll::Pending`
/// when the operation cannot complete yet.
pub(crate) trait ChanCore: MaybeSend {
    /// Increments the reference count of channel `num`.
    fn inc_chan(&self, num: ChanNum);
    /// Decrements the reference count of channel `num`.
    fn dec_chan(&self, num: ChanNum);

    /// Completes once channel `num` has been closed.
    fn poll_until_channel_closed(
        &self,
        cx: &mut Context,
        num: ChanNum,
    ) -> Poll<Result<()>>;

    /// Reads data of kind `dt` from channel `num` into `buf`, returning the
    /// number of bytes read.
    fn poll_read_channel(
        &self,
        cx: &mut Context,
        num: ChanNum,
        dt: ChanData,
        buf: &mut [u8],
    ) -> Poll<Result<usize>>;

    /// Writes `buf` as data of kind `dt` to channel `num`, returning the
    /// number of bytes written.
    fn poll_write_channel(
        &self,
        cx: &mut Context,
        num: ChanNum,
        dt: ChanData,
        buf: &[u8],
    ) -> Poll<Result<usize>>;

    /// Sends a terminal window change for channel `num`.
    fn poll_term_window_change(
        &self,
        cx: &mut Context,
        num: ChanNum,
        winch: &sunset::packets::WinChange,
    ) -> Poll<Result<()>>;
}
533
534impl<'a, CS: CliServ> ChanCore for AsyncSunset<'a, CS> {
535 fn inc_chan(&self, num: ChanNum) {
536 let c = self.chan_refcounts[num.0 as usize].fetch_add(1, SeqCst);
537 debug_assert_ne!(c, 0);
538 debug_assert_ne!(c, usize::MAX);
540 }
541
542 fn dec_chan(&self, num: ChanNum) {
543 let c = self.chan_refcounts[num.0 as usize].fetch_sub(1, SeqCst);
545 debug_assert_ne!(c, 0);
546 if c == 1 {
547 self.wake_progress();
550 }
551 }
552
553 fn poll_until_channel_closed(
554 &self,
555 cx: &mut Context,
556 num: ChanNum,
557 ) -> Poll<Result<()>> {
558 let i = self.inner.lock();
560 pin_mut!(i);
561 let Ready(mut inner) = i.poll(cx) else {
562 return Pending;
563 };
564
565 let (runner, h, wakers) = inner.fetch(num)?;
566 if runner.is_channel_closed(h) {
567 Poll::Ready(Ok(()))
568 } else {
569 wakers.chan_close[num.0 as usize].register(cx.waker());
570 Poll::Pending
571 }
572 }
573
574 fn poll_read_channel(
576 &self,
577 cx: &mut Context,
578 num: ChanNum,
579 dt: ChanData,
580 buf: &mut [u8],
581 ) -> Poll<Result<usize>> {
582 let i = self.inner.lock();
584 pin_mut!(i);
585 let Ready(mut inner) = i.poll(cx) else {
586 return Pending;
587 };
588
589 let (runner, h, wakers) = inner.fetch(num)?;
590 let i = match runner.read_channel(h, dt, buf) {
591 Ok(0) => {
592 match dt {
594 ChanData::Normal => {
595 wakers.chan_read[num.0 as usize].register(cx.waker());
596 }
597 ChanData::Stderr => {
598 wakers.chan_ext[num.0 as usize].register(cx.waker());
599 }
600 }
601 Poll::Pending
602 }
603 Err(Error::ChannelEOF) => Poll::Ready(Ok(0)),
604 r => Poll::Ready(r),
605 };
606 if matches!(i, Poll::Ready(_)) {
607 self.wake_progress()
608 }
609 i
610 }
611
612 fn poll_write_channel(
613 &self,
614 cx: &mut Context,
615 num: ChanNum,
616 dt: ChanData,
617 buf: &[u8],
618 ) -> Poll<Result<usize>> {
619 let i = self.inner.lock();
621 pin_mut!(i);
622 let Ready(mut inner) = i.poll(cx) else {
623 return Pending;
624 };
625
626 let (runner, h, wakers) = inner.fetch(num)?;
627 let l = runner.write_channel(h, dt, buf);
628 if let Ok(0) = l {
629 match dt {
631 ChanData::Normal => {
632 wakers.chan_write[num.0 as usize].register(cx.waker());
633 }
634 ChanData::Stderr => {
635 wakers.chan_ext[num.0 as usize].register(cx.waker());
636 }
637 }
638 Poll::Pending
639 } else {
640 self.wake_progress();
641 Poll::Ready(l)
642 }
643 }
644
645 fn poll_term_window_change(
646 &self,
647 cx: &mut Context,
648 num: ChanNum,
649 winch: &sunset::packets::WinChange,
650 ) -> Poll<Result<()>> {
651 let i = self.inner.lock();
653 pin_mut!(i);
654 let Ready(mut inner) = i.poll(cx) else {
655 return Pending;
656 };
657 let (runner, h, _) = inner.fetch(num)?;
658 Poll::Ready(runner.term_window_change(h, winch))
659 }
660}
661
662pub async fn io_copy<const B: usize, R, W>(r: &mut R, w: &mut W) -> Result<()>
663where
664 R: Read<Error = sunset::Error>,
665 W: Write<Error = sunset::Error>,
666{
667 let mut b = [0u8; B];
668 loop {
669 let n = r.read(&mut b).await?;
670 if n == 0 {
671 return sunset::error::ChannelEOF.fail();
672 }
673 let b = &b[..n];
674 w.write_all(b).await?
675 }
676 #[allow(unreachable_code)]
677 Ok::<_, Error>(())
678}
679
680pub async fn io_copy_nowriteerror<const B: usize, R, W>(
681 r: &mut R,
682 w: &mut W,
683) -> Result<()>
684where
685 R: Read<Error = sunset::Error>,
686 W: Write,
687{
688 let mut b = [0u8; B];
689 loop {
690 let n = r.read(&mut b).await?;
691 if n == 0 {
692 return sunset::error::ChannelEOF.fail();
693 }
694 let b = &b[..n];
695 if let Err(_) = w.write_all(b).await {
696 info!("write error");
697 }
698 }
699 #[allow(unreachable_code)]
700 Ok::<_, Error>(())
701}
702
703pub async fn io_buf_copy<R, W>(r: &mut R, w: &mut W) -> Result<()>
704where
705 R: BufRead<Error = sunset::Error>,
706 W: Write<Error = sunset::Error>,
707{
708 loop {
709 let b = r.fill_buf().await?;
710 if b.is_empty() {
711 return sunset::error::ChannelEOF.fail();
712 }
713 let n = b.len();
714 w.write_all(b).await?;
715 r.consume(n)
716 }
717 #[allow(unreachable_code)]
718 Ok::<_, Error>(())
719}
720
721pub async fn io_buf_copy_noreaderror<R, W>(r: &mut R, w: &mut W) -> Result<()>
722where
723 R: BufRead,
724 W: Write<Error = sunset::Error>,
725{
726 loop {
727 let b = match r.fill_buf().await {
728 Ok(b) => b,
729 Err(_) => {
730 info!("read error");
731 embassy_futures::yield_now().await;
732 continue;
733 }
734 };
735 if b.is_empty() {
736 return sunset::error::ChannelEOF.fail();
737 }
738 let n = b.len();
739 w.write_all(b).await?;
740 r.consume(n)
741 }
742 #[allow(unreachable_code)]
743 Ok::<_, Error>(())
744}