1use crate::codec::{Codec, sealed};
58use crate::error::BoxError;
59use crate::task::{
60 BranchOutputs, BytesFuture, CoreTask, TaskMetadata, UntypedCoreTask,
61 to_heterogeneous_join_task_arc,
62};
63use bytes::Bytes;
64use std::collections::HashMap;
65use std::future::Future;
66use std::marker::PhantomData;
67use std::sync::Arc;
68
69pub type TaskFactory = Box<dyn Fn() -> UntypedCoreTask + Send + Sync>;
71
/// A single registry slot: the factory that builds the task plus the task's metadata.
pub struct TaskEntry {
    /// Invoked on every lookup (`get`, `get_with_metadata`) to produce a new
    /// type-erased task instance.
    factory: TaskFactory,
    /// Metadata returned by `get_metadata` / `get_with_metadata`; replaced in
    /// place by `set_metadata`.
    metadata: TaskMetadata,
}
77
78pub struct TaskRegistry {
89 tasks: HashMap<String, TaskEntry>,
90}
91
92impl Default for TaskRegistry {
93 fn default() -> Self {
94 Self::new()
95 }
96}
97
98impl TaskRegistry {
99 #[must_use]
101 pub fn new() -> Self {
102 Self {
103 tasks: HashMap::new(),
104 }
105 }
106
107 pub fn register<T, C>(&mut self, id: &str, codec: Arc<C>, task: T)
153 where
154 T: CoreTask + 'static,
155 T::Input: Send + 'static,
156 T::Output: Send + 'static,
157 T::Future: Send + 'static,
158 C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output> + 'static,
159 {
160 self.register_with_metadata(id, codec, task, TaskMetadata::default());
161 }
162
163 pub fn register_with_metadata<T, C>(
168 &mut self,
169 id: &str,
170 codec: Arc<C>,
171 task: T,
172 metadata: TaskMetadata,
173 ) where
174 T: CoreTask + 'static,
175 T::Input: Send + 'static,
176 T::Output: Send + 'static,
177 T::Future: Send + 'static,
178 C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output> + 'static,
179 {
180 self.register_task_arc(id, codec, Arc::new(task), metadata);
181 }
182
183 pub(crate) fn register_task_arc<T, C>(
185 &mut self,
186 id: &str,
187 codec: Arc<C>,
188 task: Arc<T>,
189 metadata: TaskMetadata,
190 ) where
191 T: CoreTask + 'static,
192 T::Input: Send + 'static,
193 T::Output: Send + 'static,
194 T::Future: Send + 'static,
195 C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output> + 'static,
196 {
197 let factory = Box::new(move || -> UntypedCoreTask {
198 let task = Arc::clone(&task);
199 let codec = Arc::clone(&codec);
200 Box::new(TaskWrapper { task, codec })
201 });
202 self.tasks
203 .insert(id.to_string(), TaskEntry { factory, metadata });
204 }
205
206 pub fn register_fn<I, O, F, Fut, C>(&mut self, id: &str, codec: Arc<C>, func: F)
232 where
233 F: Fn(I) -> Fut + Send + Sync + 'static,
234 I: Send + 'static,
235 O: Send + 'static,
236 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
237 C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
238 {
239 self.register_fn_with_metadata(id, codec, func, TaskMetadata::default());
240 }
241
242 pub fn register_fn_with_metadata<I, O, F, Fut, C>(
246 &mut self,
247 id: &str,
248 codec: Arc<C>,
249 func: F,
250 metadata: TaskMetadata,
251 ) where
252 F: Fn(I) -> Fut + Send + Sync + 'static,
253 I: Send + 'static,
254 O: Send + 'static,
255 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
256 C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
257 {
258 self.register_fn_arc(id, codec, Arc::new(func), metadata);
259 }
260
261 pub(crate) fn register_fn_arc<I, O, F, Fut, C>(
263 &mut self,
264 id: &str,
265 codec: Arc<C>,
266 func: Arc<F>,
267 metadata: TaskMetadata,
268 ) where
269 F: Fn(I) -> Fut + Send + Sync + 'static,
270 I: Send + 'static,
271 O: Send + 'static,
272 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
273 C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
274 {
275 let factory = Box::new(move || -> UntypedCoreTask {
276 let func = Arc::clone(&func);
277 let codec = Arc::clone(&codec);
278 Box::new(FnTaskWrapper {
279 func,
280 codec,
281 _phantom: PhantomData,
282 })
283 });
284 self.tasks
285 .insert(id.to_string(), TaskEntry { factory, metadata });
286 }
287
288 pub fn register_join<O, F, Fut, C>(&mut self, id: &str, codec: Arc<C>, func: F)
290 where
291 F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
292 O: Send + 'static,
293 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
294 C: Codec
295 + sealed::EncodeValue<O>
296 + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
297 + Send
298 + Sync
299 + 'static,
300 {
301 self.register_join_with_metadata(id, codec, func, TaskMetadata::default());
302 }
303
304 pub fn register_join_with_metadata<O, F, Fut, C>(
306 &mut self,
307 id: &str,
308 codec: Arc<C>,
309 func: F,
310 metadata: TaskMetadata,
311 ) where
312 F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
313 O: Send + 'static,
314 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
315 C: Codec
316 + sealed::EncodeValue<O>
317 + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
318 + Send
319 + Sync
320 + 'static,
321 {
322 self.register_arc_join(id, codec, Arc::new(func), metadata);
323 }
324
325 pub(crate) fn register_arc_join<O, F, Fut, C>(
327 &mut self,
328 id: &str,
329 codec: Arc<C>,
330 func: Arc<F>,
331 metadata: TaskMetadata,
332 ) where
333 F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
334 O: Send + 'static,
335 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
336 C: Codec
337 + sealed::EncodeValue<O>
338 + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
339 + Send
340 + Sync
341 + 'static,
342 {
343 let factory = Box::new(move || -> UntypedCoreTask {
344 to_heterogeneous_join_task_arc(Arc::clone(&func), Arc::clone(&codec))
345 });
346 self.tasks
347 .insert(id.to_string(), TaskEntry { factory, metadata });
348 }
349
350 #[must_use]
354 pub fn get(&self, id: &str) -> Option<UntypedCoreTask> {
355 self.tasks.get(id).map(|entry| (entry.factory)())
356 }
357
358 #[must_use]
362 pub fn get_metadata(&self, id: &str) -> Option<&TaskMetadata> {
363 self.tasks.get(id).map(|entry| &entry.metadata)
364 }
365
366 #[must_use]
370 pub fn get_with_metadata(&self, id: &str) -> Option<(UntypedCoreTask, &TaskMetadata)> {
371 self.tasks
372 .get(id)
373 .map(|entry| ((entry.factory)(), &entry.metadata))
374 }
375
376 pub fn set_metadata(&mut self, id: &str, metadata: TaskMetadata) -> bool {
380 if let Some(entry) = self.tasks.get_mut(id) {
381 entry.metadata = metadata;
382 true
383 } else {
384 false
385 }
386 }
387
388 #[must_use]
390 pub fn contains(&self, id: &str) -> bool {
391 self.tasks.contains_key(id)
392 }
393
394 #[must_use]
396 pub fn len(&self) -> usize {
397 self.tasks.len()
398 }
399
400 #[must_use]
402 pub fn is_empty(&self) -> bool {
403 self.tasks.is_empty()
404 }
405
406 pub fn task_ids(&self) -> impl Iterator<Item = &str> {
408 self.tasks.keys().map(std::string::String::as_str)
409 }
410
411 pub fn with_codec<C>(codec: Arc<C>) -> RegistryBuilder<C>
436 where
437 C: Codec,
438 {
439 RegistryBuilder {
440 codec,
441 registry: TaskRegistry::new(),
442 }
443 }
444}
445
/// Fluent builder that registers several tasks against one shared codec.
///
/// Created by `TaskRegistry::with_codec`. Each `register*` method consumes and
/// returns the builder, and `build` yields the accumulated `TaskRegistry`.
pub struct RegistryBuilder<C> {
    /// Codec handed (as an `Arc` clone) to every registration made through the builder.
    codec: Arc<C>,
    /// Registry receiving the registrations; returned by `build`.
    registry: TaskRegistry,
}
454
455impl<C: Codec> RegistryBuilder<C> {
456 #[must_use]
460 pub fn register<T>(mut self, id: &str, task: T) -> Self
461 where
462 T: CoreTask + 'static,
463 T::Input: Send + 'static,
464 T::Output: Send + 'static,
465 T::Future: Send + 'static,
466 C: sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output> + 'static,
467 {
468 self.registry.register(id, Arc::clone(&self.codec), task);
469 self
470 }
471
472 #[must_use]
474 pub fn register_with_metadata<T>(mut self, id: &str, task: T, metadata: TaskMetadata) -> Self
475 where
476 T: CoreTask + 'static,
477 T::Input: Send + 'static,
478 T::Output: Send + 'static,
479 T::Future: Send + 'static,
480 C: sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output> + 'static,
481 {
482 self.registry
483 .register_with_metadata(id, Arc::clone(&self.codec), task, metadata);
484 self
485 }
486
487 #[must_use]
489 pub fn register_fn<I, O, F, Fut>(mut self, id: &str, func: F) -> Self
490 where
491 F: Fn(I) -> Fut + Send + Sync + 'static,
492 I: Send + 'static,
493 O: Send + 'static,
494 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
495 C: sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
496 {
497 self.registry.register_fn(id, Arc::clone(&self.codec), func);
498 self
499 }
500
501 #[must_use]
503 pub fn register_fn_with_metadata<I, O, F, Fut>(
504 mut self,
505 id: &str,
506 func: F,
507 metadata: TaskMetadata,
508 ) -> Self
509 where
510 F: Fn(I) -> Fut + Send + Sync + 'static,
511 I: Send + 'static,
512 O: Send + 'static,
513 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
514 C: sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
515 {
516 self.registry
517 .register_fn_with_metadata(id, Arc::clone(&self.codec), func, metadata);
518 self
519 }
520
521 #[must_use]
523 pub fn register_join<O, F, Fut>(mut self, id: &str, func: F) -> Self
524 where
525 F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
526 O: Send + 'static,
527 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
528 C: sealed::EncodeValue<O>
529 + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
530 + Send
531 + Sync
532 + 'static,
533 {
534 self.registry
535 .register_join(id, Arc::clone(&self.codec), func);
536 self
537 }
538
539 #[must_use]
541 pub fn register_join_with_metadata<O, F, Fut>(
542 mut self,
543 id: &str,
544 func: F,
545 metadata: TaskMetadata,
546 ) -> Self
547 where
548 F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
549 O: Send + 'static,
550 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
551 C: sealed::EncodeValue<O>
552 + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
553 + Send
554 + Sync
555 + 'static,
556 {
557 self.registry
558 .register_join_with_metadata(id, Arc::clone(&self.codec), func, metadata);
559 self
560 }
561
562 #[must_use]
564 pub fn build(self) -> TaskRegistry {
565 self.registry
566 }
567}
568
/// Byte-level adapter around an async function `F: Fn(I) -> Fut`.
///
/// Built by the registry factory for the `register_fn*` paths. Its `CoreTask`
/// impl decodes an `I` with `codec`, awaits `func`, and encodes the resulting `O`.
struct FnTaskWrapper<F, I, O, C> {
    /// The user function, shared by every wrapper built from the same registration.
    func: Arc<F>,
    /// Codec used to decode the input and encode the output.
    codec: Arc<C>,
    // `fn(I) -> O` records the input/output types without storing values of them.
    // Function pointers are always `Send + Sync`, so this marker does not make
    // the wrapper's auto traits depend on `I` or `O`.
    _phantom: PhantomData<fn(I) -> O>,
}
575
576impl<F, I, O, Fut, C> CoreTask for FnTaskWrapper<F, I, O, C>
577where
578 F: Fn(I) -> Fut + Send + Sync + 'static,
579 I: Send + 'static,
580 O: Send + 'static,
581 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
582 C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
583{
584 type Input = Bytes;
585 type Output = Bytes;
586 type Future = BytesFuture;
587
588 fn run(&self, input: Bytes) -> Self::Future {
589 let func = Arc::clone(&self.func);
590 let codec = Arc::clone(&self.codec);
591 BytesFuture::new(async move {
592 let decoded_input = codec.decode::<I>(input)?;
593 let output = func(decoded_input).await?;
594 codec.encode(&output)
595 })
596 }
597}
598
/// Byte-level adapter around a typed [`CoreTask`] `T`.
///
/// Built by the registry factory for `register` / `register_with_metadata`. Its
/// `CoreTask` impl decodes `T::Input` with `codec`, runs `task`, and encodes
/// `T::Output`.
struct TaskWrapper<T, C> {
    /// The registered task, shared by every wrapper built from the same registration.
    task: Arc<T>,
    /// Codec used to decode the input and encode the output.
    codec: Arc<C>,
}
604
605impl<T, C> CoreTask for TaskWrapper<T, C>
606where
607 T: CoreTask + Send + Sync + 'static,
608 T::Input: Send + 'static,
609 T::Output: Send + 'static,
610 T::Future: Send + 'static,
611 C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output>,
612{
613 type Input = Bytes;
614 type Output = Bytes;
615 type Future = BytesFuture;
616
617 fn run(&self, input: Bytes) -> Self::Future {
618 let task = Arc::clone(&self.task);
619 let codec = Arc::clone(&self.codec);
620 BytesFuture::new(async move {
621 let decoded_input = codec.decode::<T::Input>(input)?;
622 let output = task.run(decoded_input).await?;
623 codec.encode(&output)
624 })
625 }
626}
627
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{Decoder, Encoder};

    /// Codec stub: encodes every `u32` as `b"encoded"` and decodes any bytes to `42`.
    struct DummyCodec;
    impl Encoder for DummyCodec {}
    impl Decoder for DummyCodec {}
    impl sealed::EncodeValue<u32> for DummyCodec {
        fn encode_value(&self, _: &u32) -> Result<Bytes, BoxError> {
            Ok(Bytes::from_static(b"encoded"))
        }
    }
    impl sealed::DecodeValue<u32> for DummyCodec {
        fn decode_value(&self, _: Bytes) -> Result<u32, BoxError> {
            Ok(42)
        }
    }

    #[test]
    fn test_registry_register() {
        let mut registry = TaskRegistry::new();
        let codec = Arc::new(DummyCodec);

        registry.register_fn("double", codec, |input: u32| async move { Ok(input * 2) });

        assert!(registry.contains("double"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_get() {
        let mut registry = TaskRegistry::new();
        let codec = Arc::new(DummyCodec);

        registry.register_fn("double", codec, |input: u32| async move { Ok(input * 2) });

        let task = registry.get("double");
        assert!(task.is_some());

        let missing = registry.get("nonexistent");
        assert!(missing.is_none());
    }

    #[test]
    fn test_registry_task_ids() {
        let mut registry = TaskRegistry::new();
        let codec = Arc::new(DummyCodec);

        registry.register_fn("task_a", codec.clone(), |i: u32| async move { Ok(i) });
        registry.register_fn("task_b", codec.clone(), |i: u32| async move { Ok(i) });
        registry.register_fn("task_c", codec, |i: u32| async move { Ok(i) });

        let mut ids: Vec<_> = registry.task_ids().collect();
        ids.sort();
        assert_eq!(ids, vec!["task_a", "task_b", "task_c"]);
    }

    #[test]
    fn test_registry_default_is_empty() {
        let registry = TaskRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert_eq!(registry.task_ids().count(), 0);
    }

    #[test]
    fn test_registry_duplicate_id_replaces_entry() {
        let mut registry = TaskRegistry::new();
        let codec = Arc::new(DummyCodec);

        registry.register_fn("dup", codec.clone(), |i: u32| async move { Ok(i) });
        registry.register_fn("dup", codec, |i: u32| async move { Ok(i + 1) });

        // Re-registering an id overwrites the entry instead of adding a second one.
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("dup"));
    }

    #[test]
    fn test_registry_metadata_accessors() {
        let mut registry = TaskRegistry::new();
        let codec = Arc::new(DummyCodec);

        registry.register_fn_with_metadata(
            "meta",
            codec,
            |i: u32| async move { Ok(i) },
            TaskMetadata::default(),
        );

        assert!(registry.get_metadata("meta").is_some());
        assert!(registry.get_with_metadata("meta").is_some());
        assert!(registry.get_metadata("missing").is_none());
        assert!(registry.get_with_metadata("missing").is_none());

        assert!(registry.set_metadata("meta", TaskMetadata::default()));
        assert!(!registry.set_metadata("missing", TaskMetadata::default()));
        // `set_metadata` must never create an entry for an unknown id.
        assert!(!registry.contains("missing"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_builder() {
        let registry = TaskRegistry::with_codec(Arc::new(DummyCodec))
            .register_fn("a", |i: u32| async move { Ok(i) })
            .register_fn_with_metadata(
                "b",
                |i: u32| async move { Ok(i * 3) },
                TaskMetadata::default(),
            )
            .build();

        let mut ids: Vec<_> = registry.task_ids().collect();
        ids.sort_unstable();
        assert_eq!(ids, vec!["a", "b"]);
        assert!(registry.get("a").is_some());
        assert!(registry.get_metadata("b").is_some());
    }
}