1#[allow(unused_imports)]
2#[cfg(not(feature = "defmt"))]
3pub use log::{debug, error, info, log, trace, warn};
4
5#[allow(unused_imports)]
6#[cfg(feature = "defmt")]
7pub use defmt::{debug, error, info, panic, trace, warn};
8
9use core::future::{poll_fn, Future};
10use core::task::{Poll, Context};
11use core::ops::{ControlFlow, DerefMut};
12use core::sync::atomic::AtomicBool;
13use core::sync::atomic::Ordering::{Relaxed, SeqCst};
14
15use embassy_sync::waitqueue::WakerRegistration;
16use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex};
17use embassy_sync::mutex::Mutex;
18use embassy_sync::signal::Signal;
19use embassy_futures::select::select;
20use embassy_futures::join;
21use embedded_io_async::{Read, Write, BufRead};
22
23use atomic_polyfill::AtomicUsize;
25
26use pin_utils::pin_mut;
27
28use sunset::{Runner, Result, Error, error, Behaviour, ChanData, ChanHandle, ChanNum, CliBehaviour, ServBehaviour};
29use sunset::config::MAX_CHANNELS;
30
/// Raw mutex type used by [`SunsetMutex`] and by the `Signal`s in this module.
///
/// Currently an alias for `NoopRawMutex`.
/// NOTE(review): per the embassy-sync documentation `NoopRawMutex` does no
/// real locking and is only meant for use within a single executor --
/// confirm before sharing an instance across executors or threads.
pub type SunsetRawMutex = NoopRawMutex;

/// Async mutex protecting a `T`, built on [`SunsetRawMutex`].
pub type SunsetMutex<T> = Mutex<SunsetRawMutex, T>;
38
/// Waker slots for tasks blocked on channel I/O.
///
/// Every array is indexed by channel number.
struct Wakers {
    /// Registered by `read_channel` for `ChanData::Normal` reads.
    chan_read: [WakerRegistration; MAX_CHANNELS],

    /// Registered by `write_channel` for `ChanData::Normal` writes.
    chan_write: [WakerRegistration; MAX_CHANNELS],

    /// Registered for `ChanData::Stderr` traffic, by both `read_channel`
    /// and `write_channel`.
    chan_ext: [WakerRegistration; MAX_CHANNELS],

    /// Registered by `until_channel_closed`.
    chan_close: [WakerRegistration; MAX_CHANNELS],
}
51
/// State kept behind the `EmbassySunset::inner` mutex.
struct Inner<'a> {
    /// The sunset protocol runner that all I/O is delegated to.
    runner: Runner<'a>,

    /// Wakers of tasks waiting on channel events.
    wakers: Wakers,

    /// Handle for each channel that is currently in use, indexed by channel
    /// number. A slot is filled by `add_channel` and emptied again by
    /// `clear_refcounts`.
    chan_handles: [Option<ChanHandle>; MAX_CHANNELS],
}
61
62impl<'a> Inner<'a> {
63 fn fetch(&mut self, num: ChanNum) -> Result<(&mut Runner<'a>, &ChanHandle, &mut Wakers)> {
67 self.chan_handles[num.0 as usize].as_ref().map(|ch| {
68 (&mut self.runner, ch, &mut self.wakers)
69 })
70 .ok_or_else(Error::bug)
71 }
72}
73
/// Async wrapper around a sunset `Runner`.
///
/// Socket and channel I/O methods lock `inner` for the duration of a single
/// poll; `run` drives the whole connection.
pub(crate) struct EmbassySunset<'a> {
    /// Runner, wakers and channel handles, guarded by an async mutex.
    inner: SunsetMutex<Inner<'a>>,

    /// Signalled (via `wake_progress`) whenever the progress loop should
    /// run another iteration.
    progress_notify: Signal<SunsetRawMutex, ()>,

    /// Set by `exit()`; `progress` returns `Break` as soon as it sees it.
    exit: AtomicBool,
    /// Set by `run` once the read socket reports EOF; `progress` then
    /// returns `Break` on the first iteration that makes no progress.
    flushing: AtomicBool,

    /// Reference count per channel number, adjusted by `inc_chan` /
    /// `dec_chan`. `clear_refcounts` releases the handle of every channel
    /// whose count is zero.
    chan_refcounts: [AtomicUsize; MAX_CHANNELS],
}
96
97impl<'a> EmbassySunset<'a> {
98 pub fn new(runner: Runner<'a>) -> Self {
99 let wakers = Wakers {
100 chan_read: Default::default(),
101 chan_write: Default::default(),
102 chan_ext: Default::default(),
103 chan_close: Default::default(),
104 };
105 let inner = Inner {
106 runner,
107 wakers,
108 chan_handles: Default::default(),
109 };
110 let inner = Mutex::new(inner);
111
112 let progress_notify = Signal::new();
113
114 Self {
115 inner,
116 exit: AtomicBool::new(false),
117 flushing: AtomicBool::new(false),
118 progress_notify,
119 chan_refcounts: Default::default(),
120 }
121 }
122
/// Drives the connection until the progress loop finishes.
///
/// Three futures are run concurrently:
///
/// * `tx` copies runner output to `wsock`,
/// * `rx` feeds bytes read from `rsock` into the runner,
/// * `prog` repeatedly calls `progress` with the behaviour locked from `b`.
///
/// Returns the result of the progress loop. The results of `rx` and `tx`
/// are discarded.
///
/// NOTE(review): when `rx` ends with an error only `tx` is told to stop;
/// nothing visible here stops `prog`, so it keeps running until `progress`
/// itself breaks or fails -- confirm that is intended.
pub async fn run<B: ?Sized, M: RawMutex, C: CliBehaviour, S: ServBehaviour>(&self,
    rsock: &mut impl Read,
    wsock: &mut impl Write,
    b: &Mutex<M, B>) -> Result<()>
    where
        for<'f> Behaviour<'f, C, S>: From<&'f mut B>
{
    // Used to cancel the tx and rx loops from the other futures.
    let tx_stop = Signal::<SunsetRawMutex, ()>::new();
    let rx_stop = Signal::<SunsetRawMutex, ()>::new();

    // Runner output -> socket.
    let tx = async {
        loop {
            let mut buf = [0; 1024];
            let l = self.output(&mut buf).await?;
            wsock.write_all(&buf[..l]).await
            .map_err(|_| {
                info!("socket write error");
                Error::ChannelEOF
            })?;
        }
        // Only present to give the async block a concrete result type.
        #[allow(unreachable_code)]
        Ok::<_, sunset::Error>(())
    };
    // tx ends when its loop fails or when tx_stop is signalled.
    let tx = select(tx, tx_stop.wait());

    // Socket -> runner input.
    let rx = async {
        loop {
            let mut buf = [0; 1024];
            let l = rsock.read(&mut buf).await
            .map_err(|_| {
                info!("socket read error");
                Error::ChannelEOF
            })?;
            if l == 0 {
                // A zero-length read is treated as EOF: flag it and wake
                // the progress loop so it can notice the flag.
                debug!("net EOF");
                self.flushing.store(true, Relaxed);
                self.wake_progress();
                break
            }
            // input() may accept only part of the data, so keep feeding
            // the remainder until everything has been consumed.
            let mut buf = &buf[..l];
            while !buf.is_empty() {
                let n = self.input(buf).await?;
                buf = &buf[n..];
            }
        }
        Ok::<_, sunset::Error>(())
    };

    // rx ends when its loop finishes/fails or when rx_stop is signalled.
    let rx = select(rx, rx_stop.wait());
    // Whenever rx ends, stop tx as well.
    let rx = async {
        let r = rx.await;
        tx_stop.signal(());
        r
    };

    // Step the runner until progress() asks to stop or returns an error.
    let prog = async {
        loop {
            if self.progress(b).await?.is_break() {
                break Ok(())
            }
        }
    };

    // Once the progress loop has ended, close the runner and stop rx
    // (which in turn stops tx).
    let prog = async {
        let r = prog.await;
        self.with_runner(|runner| runner.close()).await;
        rx_stop.signal(());
        r
    };

    // Wait for all three futures to complete.
    let f = join::join3(prog, rx, tx).await;
    let (fp, _frx, _ftx) = f;

    fp
}
222
/// Signals `progress_notify`, which `progress` waits on when an iteration
/// made no progress.
fn wake_progress(&self) {
    self.progress_notify.signal(())
}
226
/// Requests that the progress loop stop.
///
/// Sets the `exit` flag and wakes the progress loop; `progress` checks the
/// flag at the start of each call and returns `Break` when it is set.
pub async fn exit(&self) {
    self.exit.store(true, Relaxed);
    self.wake_progress()
}
231
232 fn wake_channels(&self, inner: &mut Inner) -> Result<()> {
233 let w = &mut inner.wakers;
235 if let Some((num, dt, _len)) = inner.runner.ready_channel_input() {
236 match dt {
238 ChanData::Normal => w.chan_read[num.0 as usize].wake(),
239 ChanData::Stderr => w.chan_ext[num.0 as usize].wake(),
240 }
241 }
242
243 for (idx, c) in inner.chan_handles.iter().enumerate() {
244 let ch = if let Some(ch) = c.as_ref() {
245 ch
246 } else {
247 continue
248 };
249
250 if inner.runner.ready_channel_send(ch, ChanData::Normal)?.unwrap_or(0) > 0 {
256 w.chan_write[idx].wake()
257 }
258
259 if !inner.runner.is_client() {
260 if inner.runner.ready_channel_send(ch, ChanData::Stderr)?.unwrap_or(0) > 0 {
261 w.chan_ext[idx].wake()
262 }
263 }
264
265 if inner.runner.is_channel_eof(ch) {
267 w.chan_read[idx].wake();
268 if inner.runner.is_client() {
269 w.chan_ext[idx].wake();
270 }
271 }
272
273 if inner.runner.is_channel_closed(ch) {
274 w.chan_close[idx].wake();
275 }
276 }
277 Ok(())
278 }
279
280 fn clear_refcounts(&self, inner: &mut Inner) -> Result<()> {
287 for (ch, count) in inner.chan_handles.iter_mut().zip(self.chan_refcounts.iter()) {
288 let count = count.load(Relaxed);
289 if count > 0 {
290 debug_assert!(ch.is_some());
291 continue;
292 }
293 if let Some(ch) = ch.take() {
294 inner.runner.channel_done(ch)?;
296 }
297 }
298 Ok(())
299 }
300
301 async fn progress<B: ?Sized, M: RawMutex, C: CliBehaviour, S: ServBehaviour>(&self,
305 b: &Mutex<M, B>)
306 -> Result<ControlFlow<()>>
307 where
308 for<'f> Behaviour<'f, C, S>: From<&'f mut B>
309 {
310 let ret;
311
312 {
313 if self.exit.load(Relaxed) {
314 return Ok(ControlFlow::Break(()))
315 }
316
317 let mut inner = self.inner.lock().await;
318 {
319 {
320 let mut b = b.lock().await;
322 let b = b.deref_mut();
324 let mut b: Behaviour<C, S> = b.into();
327 ret = inner.runner.progress(&mut b).await?;
328 }
330
331 self.wake_channels(&mut inner)?;
332
333 self.clear_refcounts(&mut inner)?;
334 }
335 }
337
338 if ret.disconnected {
339 return Ok(ControlFlow::Break(()))
340 }
341
342 if !ret.progressed {
343 if self.flushing.load(Relaxed) {
344 return Ok(ControlFlow::Break(()))
346 }
347 self.progress_notify.wait().await;
350 }
351
352 Ok(ControlFlow::Continue(()))
353 }
354
355 pub(crate) async fn with_runner<F, R>(&self, f: F) -> R
356 where F: FnOnce(&mut Runner) -> R {
357 let mut inner = self.inner.lock().await;
358 f(&mut inner.runner)
359 }
360
361 async fn poll_inner<F, T>(&self, mut f: F) -> T
363 where F: FnMut(&mut Inner, &mut Context) -> Poll<T> {
364 poll_fn(|cx| {
365 let i = self.inner.lock();
367 pin_mut!(i);
368 match i.poll(cx) {
369 Poll::Ready(mut inner) => {
370 f(&mut inner, cx)
371 }
372 Poll::Pending => {
373 Poll::Pending
375 }
376 }
377 })
378 .await
379 }
380
381 pub async fn output(&self, buf: &mut [u8]) -> Result<usize> {
382 self.poll_inner(|inner, cx| {
383 match inner.runner.output(buf) {
384 Ok(0) => {
386 inner.runner.set_output_waker(cx.waker());
387 Poll::Pending
388 }
389 Ok(n) => Poll::Ready(Ok(n)),
390 Err(e) => Poll::Ready(Err(e)),
391 }
392 }).await
393 }
394
395 pub async fn input(&self, buf: &[u8]) -> Result<usize> {
396 self.poll_inner(|inner, cx| {
397 if inner.runner.is_input_ready() {
398 let r = match inner.runner.input(buf) {
399 Ok(0) => {
400 inner.runner.set_input_waker(cx.waker());
401 Poll::Pending
402 },
403 Ok(n) => Poll::Ready(Ok(n)),
404 Err(e) => Poll::Ready(Err(e)),
405 };
406 if r.is_ready() {
407 self.wake_progress()
408 }
409 r
410 } else {
411 inner.runner.set_input_waker(cx.waker());
412 Poll::Pending
413 }
414 }).await
415 }
416
417 pub(crate) async fn read_channel(&self, num: ChanNum, dt: ChanData, buf: &mut [u8]) -> Result<usize> {
419 if num.0 as usize > MAX_CHANNELS {
420 return sunset::error::BadChannel { num }.fail()
421 }
422 self.poll_inner(|inner, cx| {
423 let (runner, h, wakers) = inner.fetch(num)?;
424 let i = match runner.channel_input(h, dt, buf) {
425 Ok(0) => {
426 match dt {
428 ChanData::Normal => {
429 wakers.chan_read[num.0 as usize].register(cx.waker());
430 }
431 ChanData::Stderr => {
432 wakers.chan_ext[num.0 as usize].register(cx.waker());
433 }
434 }
435 Poll::Pending
436 }
437 Err(Error::ChannelEOF) => {
438 Poll::Ready(Ok(0))
439 }
440 r => Poll::Ready(r),
441 };
442 if matches!(i, Poll::Ready(_)) {
443 self.wake_progress()
444 }
445 i
446 }).await
447 }
448
449 pub(crate) async fn write_channel(&self, num: ChanNum, dt: ChanData, buf: &[u8]) -> Result<usize> {
450 if num.0 as usize > MAX_CHANNELS {
451 return sunset::error::BadChannel { num }.fail()
452 }
453 self.poll_inner(|inner, cx| {
454 let (runner, h, wakers) = inner.fetch(num)?;
455 let l = runner.channel_send(h, dt, buf);
456 if let Ok(0) = l {
457 match dt {
459 ChanData::Normal => {
460 wakers.chan_write[num.0 as usize].register(cx.waker());
461 }
462 ChanData::Stderr => {
463 wakers.chan_ext[num.0 as usize].register(cx.waker());
464 }
465 }
466 Poll::Pending
467 } else {
468 self.wake_progress();
469 Poll::Ready(l)
470 }
471 }).await
472 }
473
474 pub(crate) async fn until_channel_closed(&self, num: ChanNum) -> Result<()> {
475 self.poll_inner(|inner, cx| {
476 let (runner, h, wakers) = inner.fetch(num)?;
477 if runner.is_channel_closed(h) {
478 Poll::Ready(Ok(()))
479 } else {
480 wakers.chan_close[num.0 as usize].register(cx.waker());
481 Poll::Pending
482 }
483 }).await
484 }
485
486 pub async fn term_window_change(&self, num: ChanNum, winch: sunset::packets::WinChange) -> Result<()> {
487 let mut inner = self.inner.lock().await;
488 let (runner, h, _) = inner.fetch(num)?;
489 runner.term_window_change(h, winch)
490 }
491
492 pub(crate) async fn add_channel(&self, handle: ChanHandle, init_refcount: usize) -> Result<()> {
500 let mut inner = self.inner.lock().await;
501 let idx = handle.num().0 as usize;
502 if inner.chan_handles[idx].is_some() {
503 return error::Bug.fail()
504 }
505
506 debug_assert_eq!(self.chan_refcounts[idx].load(Relaxed), 0);
507
508 inner.chan_handles[idx] = Some(handle);
509 self.chan_refcounts[idx].store(init_refcount, Relaxed);
510 Ok(())
511 }
512
513 pub(crate) fn inc_chan(&self, num: ChanNum) {
514 let c = self.chan_refcounts[num.0 as usize].fetch_add(1, SeqCst);
515 debug_assert_ne!(c, 0);
516 debug_assert_ne!(c, usize::MAX);
518 self.wake_progress();
520 }
521
522 pub(crate) fn dec_chan(&self, num: ChanNum) {
523 let c = self.chan_refcounts[num.0 as usize].fetch_sub(1, SeqCst);
525 debug_assert_ne!(c, 0);
526 self.wake_progress();
528 }
529}
530
531
532pub async fn io_copy<const B: usize, R, W>(r: &mut R, w: &mut W) -> Result<()>
533 where R: Read<Error=sunset::Error>,
534 W: Write<Error=sunset::Error>
535{
536 let mut b = [0u8; B];
537 loop {
538 let n = r.read(&mut b).await?;
539 if n == 0 {
540 return sunset::error::ChannelEOF.fail();
541 }
542 let b = &b[..n];
543 w.write_all(b).await?
544 }
545 #[allow(unreachable_code)]
546 Ok::<_, Error>(())
547}
548
549pub async fn io_copy_nowriteerror<const B: usize, R, W>(r: &mut R, w: &mut W) -> Result<()>
550 where R: Read<Error=sunset::Error>,
551 W: Write,
552{
553 let mut b = [0u8; B];
554 loop {
555 let n = r.read(&mut b).await?;
556 if n == 0 {
557 return sunset::error::ChannelEOF.fail();
558 }
559 let b = &b[..n];
560 if let Err(_) = w.write_all(b).await {
561 info!("write error");
562 }
563 }
564 #[allow(unreachable_code)]
565 Ok::<_, Error>(())
566}
567
568pub async fn io_buf_copy<R, W>(r: &mut R, w: &mut W) -> Result<()>
569 where R: BufRead<Error=sunset::Error>,
570 W: Write<Error=sunset::Error>
571{
572 loop {
573 let b = r.fill_buf().await?;
574 if b.is_empty() {
575 return sunset::error::ChannelEOF.fail();
576 }
577 let n = b.len();
578 w.write_all(b).await?;
579 r.consume(n)
580 }
581 #[allow(unreachable_code)]
582 Ok::<_, Error>(())
583}
584
585pub async fn io_buf_copy_noreaderror<R, W>(r: &mut R, w: &mut W) -> Result<()>
586 where R: BufRead,
587 W: Write<Error=sunset::Error>
588{
589 loop {
590 let b = match r.fill_buf().await {
591 Ok(b) => b,
592 Err(_) => {
593 info!("read error");
594 embassy_futures::yield_now().await;
595 continue;
596 }
597 };
598 if b.is_empty() {
599 return sunset::error::ChannelEOF.fail();
600 }
601 let n = b.len();
602 w.write_all(b).await?;
603 r.consume(n)
604 }
605 #[allow(unreachable_code)]
606 Ok::<_, Error>(())
607}