1use futures::TryStreamExt as _;
12use simploxide_api_types::events::{Event, EventData};
13#[cfg(feature = "cancellation")]
14use tokio_util::sync::CancellationToken;
15
16use std::{future::Future, pin::Pin, sync::Arc};
17
18use crate::{EventParser, EventStream, StreamEvents};
19
/// Builder entry point: owns the event stream and the handler context before any handler
/// has been registered.
///
/// Registering the first handler (`seq`, `seq_fallback`, `on` or `fallback`) configures the
/// stream's event-kind filter and turns the builder into a [`Dispatcher`].
pub struct DispatchChain<P, Ctx> {
    /// Stream the registered handlers will consume events from.
    events: EventStream<P>,
    /// Context value that is handed to the handlers.
    ctx: Ctx,
}
25
26impl<P, Ctx> DispatchChain<P, Ctx>
27where
28 P: EventParser,
29{
30 pub fn with_ctx(events: EventStream<P>, ctx: Ctx) -> Self {
31 Self { ctx, events }
32 }
33
34 pub fn seq_fallback<E, F>(mut self, f: F) -> Dispatcher<P, Ctx, Fallback<F>>
36 where
37 F: AsyncFnMut(Event, &mut Ctx) -> Result<StreamEvents, E>,
38 {
39 self.events.accept_all();
40 Dispatcher {
41 events: self.events,
42 ctx: self.ctx,
43 chain: Fallback { f },
44 }
45 }
46
47 pub fn fallback<E, F, Fut>(mut self, f: F) -> Dispatcher<P, Ctx, Fallback<F>>
52 where
53 Ctx: 'static + Send,
54 E: 'static + Send + From<P::Error>,
55 F: Fn(Event, Ctx) -> Fut,
56 Fut: 'static + Send + Future<Output = Result<StreamEvents, E>>,
57 {
58 self.events.accept_all();
59 Dispatcher {
60 events: self.events,
61 ctx: self.ctx,
62 chain: Fallback { f },
63 }
64 }
65
66 pub fn seq<Ev, E, F>(mut self, f: F) -> Dispatcher<P, Ctx, Match<Ev, F>>
70 where
71 Ev: EventData,
72 F: AsyncFnMut(Arc<Ev>, &mut Ctx) -> Result<StreamEvents, E>,
73 {
74 self.events.reject_all();
75 self.events.accept(Ev::KIND);
76 Dispatcher {
77 events: self.events,
78 ctx: self.ctx,
79 chain: Match {
80 f,
81 _phantom: std::marker::PhantomData,
82 },
83 }
84 }
85
86 pub fn on<Ev, E, F, Fut>(mut self, f: F) -> Dispatcher<P, Ctx, Match<Ev, F>>
92 where
93 Ctx: 'static + Send,
94 E: 'static + Send,
95 Ev: 'static + EventData,
96 F: Fn(Arc<Ev>, Ctx) -> Fut,
97 Fut: 'static + Send + Future<Output = Result<StreamEvents, E>>,
98 {
99 self.events.reject_all();
100 self.events.accept(Ev::KIND);
101 Dispatcher {
102 events: self.events,
103 ctx: self.ctx,
104 chain: Match {
105 f,
106 _phantom: std::marker::PhantomData,
107 },
108 }
109 }
110}
111
/// Event dispatcher produced by [`DispatchChain`]: owns the event stream, the handler context
/// and the chain `D` of registered handlers.
pub struct Dispatcher<P, Ctx, D> {
    /// Stream the handlers consume events from.
    events: EventStream<P>,
    /// Context handed to handlers (by `&mut` or by clone, depending on the dispatch method).
    ctx: Ctx,
    /// Registered handlers; the most recently registered one is tried first.
    chain: D,
}
122
123impl<P, Ctx, D> Dispatcher<P, Ctx, D>
124where
125 D: DispatchEvent<Ctx>,
126{
127 pub fn seq<Ev, F>(mut self, f: F) -> Dispatcher<P, Ctx, Intercept<Match<Ev, F>, D>>
176 where
177 Ev: EventData,
178 F: AsyncFnMut(Arc<Ev>, &mut Ctx) -> Result<StreamEvents, D::Error>,
179 {
180 self.events.accept(Ev::KIND);
181 Dispatcher {
182 events: self.events,
183 ctx: self.ctx,
184 chain: Intercept {
185 d1: Match {
186 f,
187 _phantom: std::marker::PhantomData,
188 },
189 d2: self.chain,
190 },
191 }
192 }
193
194 pub async fn sequential_dispatch(self) -> Result<(EventStream<P>, Ctx), D::Error>
201 where
202 P: EventParser,
203 D::Error: From<P::Error>,
204 {
205 let Self {
206 ctx,
207 events,
208 mut chain,
209 } = self;
210
211 events.stream_events_with_ctx_mut(async move |ev, ctx| {
212 let Ok(handler) = chain.dispatch_event(ev, ctx) else {
213 unreachable!("EventStream filters set by seq/fallback_seq drop events without handlers during parsing");
214 };
215
216 handler.await
217
218 }, ctx).await
219 }
220
221 #[cfg(feature = "cancellation")]
224 pub async fn sequential_dispatch_with_cancellation(
225 self,
226 token: CancellationToken,
227 ) -> Result<(EventStream<P>, Ctx), D::Error>
228 where
229 P: EventParser,
230 D::Error: From<P::Error>,
231 {
232 let Self {
233 mut ctx,
234 mut events,
235 mut chain,
236 } = self;
237
238 loop {
239 tokio::select! {
240 biased;
241 _ = token.cancelled() => break,
242 res = events.try_next() => match res {
243 Ok(Some(ev)) => {
244 let Ok(handler) = chain.dispatch_event(ev, &mut ctx) else {
245 unreachable!("EventStream filters set by seq/fallback_seq drop events without handlers during parsing");
246 };
247 if let StreamEvents::Break = handler.await? {
248 break;
249 }
250 }
251 Ok(None) => break,
252 Err(e) => return Err(e.into()),
253 }
254 }
255 }
256
257 Ok((events, ctx))
258 }
259}
260
261impl<P, Ctx, D> Dispatcher<P, Ctx, D>
262where
263 P: 'static + EventParser,
264 Ctx: 'static + Send + Clone,
265 D: ConcurrentDispatchEvent<Ctx>,
266 D::Error: From<P::Error>,
267{
268 pub fn on<Ev, F, Fut>(mut self, f: F) -> Dispatcher<P, Ctx, Intercept<Match<Ev, F>, D>>
316 where
317 Ev: 'static + EventData,
318 F: Fn(Arc<Ev>, Ctx) -> Fut,
319 Fut: 'static + Send + Future<Output = Result<StreamEvents, D::Error>>,
320 {
321 self.events.accept(Ev::KIND);
322 Dispatcher {
323 events: self.events,
324 ctx: self.ctx,
325 chain: Intercept {
326 d1: Match {
327 f,
328 _phantom: std::marker::PhantomData,
329 },
330 d2: self.chain,
331 },
332 }
333 }
334
335 pub async fn dispatch(self) -> Result<(EventStream<P>, Ctx, Vec<Event>), D::Error> {
345 let chain = self.chain;
346 let ctx = self.ctx;
347 let mut events = self.events;
348 let (event_buffer, result) =
349 run_concurrent_dispatch(&chain, &ctx, &mut events, std::future::pending::<()>()).await;
350 match result {
351 Ok(inner) => inner.map(move |_| (events, ctx, event_buffer)),
352 Err(e) => std::panic::resume_unwind(e.into_panic()),
353 }
354 }
355
356 #[cfg(feature = "cancellation")]
359 pub async fn dispatch_with_cancellation(
360 self,
361 token: CancellationToken,
362 ) -> Result<(EventStream<P>, Ctx, Vec<Event>), D::Error> {
363 let chain = self.chain;
364 let ctx = self.ctx;
365 let mut events = self.events;
366 let (event_buffer, result) =
367 run_concurrent_dispatch(&chain, &ctx, &mut events, token.cancelled()).await;
368 match result {
369 Ok(inner) => inner.map(move |_| (events, ctx, event_buffer)),
370 Err(e) => std::panic::resume_unwind(e.into_panic()),
371 }
372 }
373
374 pub async fn dispatch_sequentially(self) -> Result<(EventStream<P>, Ctx), D::Error> {
380 let ctx = self.ctx;
381 let events = self.events;
382 let chain = self.chain;
383
384 events.stream_events_with_ctx_cloned(async move |ev, ctx| {
385 let Ok(handler) = chain.concurrent_dispatch_event(ev, ctx) else {
386 unreachable!("EventStream filters set by on/fallback drop events without handlers during parsing");
387 };
388 handler.await
389 }, ctx).await
390 }
391
392 #[cfg(feature = "cancellation")]
394 pub async fn dispatch_sequentially_with_cancellation(
395 self,
396 token: CancellationToken,
397 ) -> Result<(EventStream<P>, Ctx), D::Error> {
398 let Self {
399 ctx,
400 mut events,
401 chain,
402 } = self;
403
404 loop {
405 tokio::select! {
406 biased;
407 _ = token.cancelled() => break,
408 res = events.try_next() => match res {
409 Ok(Some(ev)) => {
410 let Ok(handler) = chain.concurrent_dispatch_event(ev, ctx.clone()) else {
411 unreachable!("EventStream filters set by on/fallback drop events without handlers during parsing");
412 };
413 if let StreamEvents::Break = handler.await? {
414 break;
415 }
416 }
417 Ok(None) => break,
418 Err(e) => return Err(e.into()),
419 }
420 }
421 }
422
423 Ok((events, ctx))
424 }
425}
426
427async fn run_concurrent_dispatch<P, Ctx, D, Fut>(
433 chain: &D,
434 ctx: &Ctx,
435 events: &mut EventStream<P>,
436 stop: Fut,
437) -> (
438 Vec<Event>,
439 Result<Result<StreamEvents, D::Error>, tokio::task::JoinError>,
440)
441where
442 P: 'static + EventParser,
443 Ctx: 'static + Send + Clone,
444 D: ConcurrentDispatchEvent<Ctx>,
445 D::Error: From<P::Error>,
446 Fut: Future<Output = ()>,
447{
448 let mut join_set: tokio::task::JoinSet<Result<StreamEvents, D::Error>> =
449 tokio::task::JoinSet::new();
450 let (cancellator, cancellation) = tokio::sync::oneshot::channel::<()>();
451
452 join_set.spawn(async move {
454 let _ = cancellation.await;
455 Ok(StreamEvents::Continue)
456 });
457
458 let mut stop = std::pin::pin!(stop);
459
460 let mut result = loop {
461 tokio::select! {
462 _ = stop.as_mut() => break Ok(Ok(StreamEvents::Break)),
463 result = events.try_next() => match result {
464 Ok(Some(event)) => {
465 let Ok(handler) = chain.concurrent_dispatch_event(event, ctx.clone()) else {
466 unreachable!(
467 "EventStream filtering set by on and fallback methods drops events without handlers before parsing them"
468 );
469 };
470 join_set.spawn(handler);
471 }
472 Ok(None) => break Ok(Ok(StreamEvents::Break)),
473 Err(e) => break Ok(Err(e.into())),
474 },
475 result = join_set.join_next() => match result {
476 Some(Ok(Ok(StreamEvents::Continue))) => continue,
477 Some(Ok(Ok(StreamEvents::Break))) => break Ok(Ok(StreamEvents::Break)),
478 Some(err) => break err,
479 None => unreachable!("Dummy task must be running during the whole tokio select! loop"),
480 }
481 }
482 };
483
484 let _ = cancellator.send(());
485 let mut event_buffer = Vec::new();
486
487 loop {
488 tokio::select! {
489 joined = join_set.join_next() => match joined {
490 Some(next) => {
491 if matches!(result, Ok(Ok(_))) {
492 result = next;
493 }
494 }
495 None => break,
496 },
497 event = events.try_next() => match event {
498 Ok(Some(ev)) => event_buffer.push(ev),
499 Ok(None) => (),
500 Err(e) => {
501 result = Ok(Err(e.into()));
502 break;
503 }
504 }
505 }
506 }
507
508 (event_buffer, result)
509}
510
/// A handler, or chain of handlers, that is offered events one at a time together with
/// exclusive (`&mut`) access to the context.
pub trait DispatchEvent<Ctx> {
    /// Error type produced by the handler futures.
    type Error;
    /// Future returned for an accepted event. It may borrow both the dispatcher and the context
    /// for `'s`.
    type Future<'s>: Future<Output = Result<StreamEvents, Self::Error>>
    where
        Self: 's,
        Ctx: 's;

    /// Offers `ev` to this dispatcher.
    ///
    /// Returns the handler future if the event is handled here. Otherwise it hands the event and
    /// the context back as `Err`, so another dispatcher can try it (as `Intercept` does).
    fn dispatch_event<'s>(
        &'s mut self,
        ev: Event,
        ctx: &'s mut Ctx,
    ) -> Result<Self::Future<'s>, (Event, &'s mut Ctx)>;
}
524
/// A handler, or chain of handlers, whose futures own their context value and can be spawned
/// as independent tasks.
pub trait ConcurrentDispatchEvent<Ctx>
where
    Ctx: 'static + Send,
{
    /// Error type produced by the handler futures.
    type Error: 'static + Send;
    /// Handler future; `'static + Send` so it can be moved onto a spawned task.
    type Future: 'static + Send + Future<Output = Result<StreamEvents, Self::Error>>;

    /// Offers `ev` to this dispatcher, together with an owned `ctx` for the handler.
    ///
    /// Returns the handler future if the event is handled here. Otherwise it hands the event and
    /// the context back as `Err`, so another dispatcher can try it.
    fn concurrent_dispatch_event(&self, ev: Event, ctx: Ctx) -> Result<Self::Future, (Event, Ctx)>;
}
534
/// Catch-all handler: accepts every event offered to it and passes the raw `Event` to `f`.
pub struct Fallback<F> {
    /// The wrapped handler closure.
    f: F,
}
540
541impl<Ctx, E, F> DispatchEvent<Ctx> for Fallback<F>
542where
543 F: AsyncFnMut(Event, &mut Ctx) -> Result<StreamEvents, E>,
544{
545 type Error = E;
546 type Future<'s>
548 = Pin<Box<dyn 's + Future<Output = Result<StreamEvents, E>>>>
549 where
550 Self: 's,
551 Ctx: 's;
552
553 fn dispatch_event<'s>(
554 &'s mut self,
555 ev: Event,
556 ctx: &'s mut Ctx,
557 ) -> Result<Self::Future<'s>, (Event, &'s mut Ctx)> {
558 Ok(Box::pin((self.f)(ev, ctx)))
559 }
560}
561
562impl<Ctx, E, F, Fut> ConcurrentDispatchEvent<Ctx> for Fallback<F>
563where
564 Ctx: 'static + Send,
565 E: 'static + Send,
566 F: Fn(Event, Ctx) -> Fut,
569 Fut: 'static + Send + Future<Output = Result<StreamEvents, E>>,
570{
571 type Error = E;
572 type Future = Fut;
573
574 fn concurrent_dispatch_event(&self, ev: Event, ctx: Ctx) -> Result<Self::Future, (Event, Ctx)> {
575 Ok((self.f)(ev, ctx))
576 }
577}
578
/// Handler for a single event type `Ev`. Events that `Ev::from_event` rejects are handed back
/// to the caller.
pub struct Match<Ev, F> {
    /// The wrapped handler closure, called with the converted `Arc<Ev>`.
    f: F,
    /// Ties the handled event type to the struct without storing a value of it.
    _phantom: std::marker::PhantomData<Ev>,
}
583
584impl<Ctx, Ev, E, F> DispatchEvent<Ctx> for Match<Ev, F>
585where
586 Ev: EventData,
587 F: AsyncFnMut(Arc<Ev>, &mut Ctx) -> Result<StreamEvents, E>,
588{
589 type Error = E;
590 type Future<'s>
592 = Pin<Box<dyn 's + Future<Output = Result<StreamEvents, E>>>>
593 where
594 Self: 's,
595 Ctx: 's;
596
597 fn dispatch_event<'s>(
598 &'s mut self,
599 ev: Event,
600 ctx: &'s mut Ctx,
601 ) -> Result<Self::Future<'s>, (Event, &'s mut Ctx)> {
602 match Ev::from_event(ev) {
603 Ok(ev) => Ok(Box::pin((self.f)(ev, ctx))),
604 Err(ev) => Err((ev, ctx)),
605 }
606 }
607}
608
609impl<Ctx, Ev, E, F, Fut> ConcurrentDispatchEvent<Ctx> for Match<Ev, F>
610where
611 Ctx: 'static + Send,
612 Ev: 'static + EventData,
613 E: 'static + Send,
614 F: Fn(Arc<Ev>, Ctx) -> Fut,
617 Fut: 'static + Send + Future<Output = Result<StreamEvents, E>>,
618{
619 type Error = E;
620 type Future = Fut;
621
622 fn concurrent_dispatch_event(&self, ev: Event, ctx: Ctx) -> Result<Self::Future, (Event, Ctx)> {
623 match Ev::from_event(ev) {
624 Ok(ev) => Ok((self.f)(ev, ctx)),
625 Err(ev) => Err((ev, ctx)),
626 }
627 }
628}
629
/// Combines two dispatchers. `d1` is offered each event first, and `d2` only sees events that
/// `d1` hands back.
pub struct Intercept<D1, D2> {
    /// Dispatcher tried first.
    d1: D1,
    /// Dispatcher tried when `d1` declines the event.
    d2: D2,
}
634
635impl<Ctx, D1, D2> DispatchEvent<Ctx> for Intercept<D1, D2>
636where
637 D1: DispatchEvent<Ctx>,
638 D2: DispatchEvent<Ctx, Error = D1::Error>,
639{
640 type Error = D1::Error;
641 type Future<'s>
642 = futures::future::Either<D1::Future<'s>, D2::Future<'s>>
643 where
644 Self: 's,
645 Ctx: 's;
646
647 fn dispatch_event<'s>(
648 &'s mut self,
649 ev: Event,
650 ctx: &'s mut Ctx,
651 ) -> Result<Self::Future<'s>, (Event, &'s mut Ctx)> {
652 self.d1
653 .dispatch_event(ev, ctx)
654 .map(futures::future::Either::Left)
655 .or_else(|(ev, ctx)| {
656 self.d2
657 .dispatch_event(ev, ctx)
658 .map(futures::future::Either::Right)
659 })
660 }
661}
662
663impl<Ctx, D1, D2> ConcurrentDispatchEvent<Ctx> for Intercept<D1, D2>
664where
665 Ctx: 'static + Send,
666 D1: ConcurrentDispatchEvent<Ctx>,
667 D2: ConcurrentDispatchEvent<Ctx, Error = D1::Error>,
668{
669 type Error = D1::Error;
670 type Future = futures::future::Either<D1::Future, D2::Future>;
671
672 fn concurrent_dispatch_event(&self, ev: Event, ctx: Ctx) -> Result<Self::Future, (Event, Ctx)> {
673 self.d1
674 .concurrent_dispatch_event(ev, ctx)
675 .map(futures::future::Either::Left)
676 .or_else(|(ev, ctx)| {
677 self.d2
678 .concurrent_dispatch_event(ev, ctx)
679 .map(futures::future::Either::Right)
680 })
681 }
682}