1use std::collections::HashMap;
2use std::fmt;
3use std::sync::Arc;
4
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7
8use rustvello_proto::call::SerializedArguments;
9use rustvello_proto::config::TaskConfig;
10use rustvello_proto::identifiers::TaskId;
11
12use crate::error::{RustvelloError, RustvelloResult};
13
/// A strongly-typed task: serde-(de)serializable parameters in, a
/// serde-(de)serializable result out.
///
/// Every `Task` is also a [`DynTask`] through the blanket impl in this module,
/// which deserializes `Params` from JSON (via [`serialized_args_to_json`]) and
/// serializes `Result` back to a JSON string. Register one with
/// [`TaskRegistry::register_typed`].
pub trait Task: Send + Sync + 'static {
    /// Parameter type; deserialized from the JSON built out of the call's
    /// [`SerializedArguments`].
    type Params: Serialize + DeserializeOwned + Send + Sync + 'static;
    /// Result type; serialized to a JSON string after [`Task::run`] succeeds.
    type Result: Serialize + DeserializeOwned + Send + Sync + 'static;

    /// Identifier this task is registered under.
    fn task_id(&self) -> &TaskId;

    /// Execution configuration for this task.
    fn config(&self) -> &TaskConfig;

    /// Runs the task with already-deserialized parameters.
    fn run(&self, params: Self::Params) -> RustvelloResult<Self::Result>;
}
74
/// Type-erased, object-safe view of a task, as stored in [`TaskRegistry`].
///
/// Arguments arrive as [`SerializedArguments`] and the result leaves as a
/// `String` (JSON for the blanket [`Task`] impl), so typed, foreign and legacy
/// tasks can all live in one `Arc<dyn DynTask>` map.
pub trait DynTask: Send + Sync {
    /// Identifier the task is registered under.
    fn task_id(&self) -> &TaskId;

    /// Execution configuration for the task.
    fn config(&self) -> &TaskConfig;

    /// Executes the task with serialized arguments and returns its
    /// serialized result.
    ///
    /// Error behavior is implementation-specific; e.g. the adapter for
    /// foreign tasks in this module always returns an error.
    fn execute(&self, args: &SerializedArguments) -> RustvelloResult<String>;
}
94
95pub fn serialized_args_to_json(
101 args: &SerializedArguments,
102) -> RustvelloResult<std::borrow::Cow<'_, str>> {
103 use std::borrow::Cow;
104 if args.0.len() == 1 && args.0.contains_key("__args__") {
105 return Ok(Cow::Borrowed(&args.0["__args__"]));
107 }
108 use std::fmt::Write;
110 let mut buf = String::with_capacity(args.0.len() * 32 + 2);
111 buf.push('{');
112 for (i, (k, v)) in args.0.iter().enumerate() {
113 if i > 0 {
114 buf.push(',');
115 }
116 let escaped_key =
118 serde_json::to_string(k.as_str()).map_err(|e| RustvelloError::Serialization {
119 message: format!("failed to escape JSON key: {e}"),
120 })?;
121 serde_json::from_str::<serde_json::Value>(v).map_err(|e| {
123 RustvelloError::Serialization {
124 message: format!("invalid JSON value for key {k}: {e}"),
125 }
126 })?;
127 write!(buf, "{}:{}", escaped_key, v).map_err(|e| RustvelloError::Serialization {
128 message: format!("failed to build JSON: {e}"),
129 })?;
130 }
131 buf.push('}');
132 Ok(Cow::Owned(buf))
133}
134
135impl<T: Task> DynTask for T {
136 #[inline]
137 fn task_id(&self) -> &TaskId {
138 Task::task_id(self)
139 }
140
141 #[inline]
142 fn config(&self) -> &TaskConfig {
143 Task::config(self)
144 }
145
146 fn execute(&self, args: &SerializedArguments) -> RustvelloResult<String> {
147 let json_str = serialized_args_to_json(args)?;
148 let params: T::Params =
149 serde_json::from_str(&json_str).map_err(|e| RustvelloError::Serialization {
150 message: e.to_string(),
151 })?;
152 let result = self.run(params)?;
153 serde_json::to_string(&result).map_err(|e| RustvelloError::Serialization {
154 message: e.to_string(),
155 })
156 }
157}
158
159impl fmt::Debug for dyn DynTask {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 f.debug_struct("DynTask")
162 .field("task_id", &self.task_id())
163 .finish()
164 }
165}
166
167pub trait CrossLanguageSafe: Serialize + DeserializeOwned {}
180
181impl CrossLanguageSafe for String {}
183impl CrossLanguageSafe for bool {}
184impl CrossLanguageSafe for i32 {}
185impl CrossLanguageSafe for i64 {}
186impl CrossLanguageSafe for u32 {}
187impl CrossLanguageSafe for u64 {}
188impl CrossLanguageSafe for f32 {}
189impl CrossLanguageSafe for f64 {}
190impl<T: CrossLanguageSafe> CrossLanguageSafe for Vec<T> {}
191impl<T: CrossLanguageSafe> CrossLanguageSafe for Option<T> {}
192impl<K: CrossLanguageSafe + Ord, V: CrossLanguageSafe> CrossLanguageSafe
193 for std::collections::BTreeMap<K, V>
194{
195}
196impl<K: CrossLanguageSafe + Eq + std::hash::Hash, V: CrossLanguageSafe> CrossLanguageSafe
197 for std::collections::HashMap<K, V>
198{
199}
200
/// Declaration of a task implemented in another language.
///
/// On the Rust side only its id, config and parameter/result types are known.
/// [`TaskRegistry::register_foreign`] stores it behind an adapter whose
/// `execute` always returns an error, because the task must be processed by a
/// worker for its own language.
pub trait ForeignTask: Send + Sync + 'static {
    /// Parameter type; restricted to [`CrossLanguageSafe`] types.
    type Params: CrossLanguageSafe + Send + Sync + 'static;
    /// Result type; restricted to [`CrossLanguageSafe`] types.
    type Result: CrossLanguageSafe + Send + Sync + 'static;

    /// Identifier of the foreign task. `register_foreign` rejects ids for
    /// which `TaskId::is_foreign()` is false.
    fn task_id(&self) -> TaskId;

    /// Execution configuration; defaults to `TaskConfig::default()`.
    fn config(&self) -> TaskConfig {
        TaskConfig::default()
    }
}
253
254struct ForeignTaskAdapter<F: ForeignTask> {
262 _inner: F,
263 task_id: TaskId,
264 config: TaskConfig,
265}
266
267impl<F: ForeignTask> ForeignTaskAdapter<F> {
268 fn new(task: F) -> Self {
269 let task_id = task.task_id();
270 let config = task.config();
271 Self {
272 _inner: task,
273 task_id,
274 config,
275 }
276 }
277}
278
279impl<F: ForeignTask> DynTask for ForeignTaskAdapter<F> {
280 fn task_id(&self) -> &TaskId {
281 &self.task_id
282 }
283
284 fn config(&self) -> &TaskConfig {
285 &self.config
286 }
287
288 fn execute(&self, _args: &SerializedArguments) -> RustvelloResult<String> {
289 Err(RustvelloError::Configuration {
290 message: format!(
291 "foreign task {} cannot be executed locally — must be processed by a {} worker",
292 self.task_id,
293 self.task_id.language(),
294 ),
295 })
296 }
297}
298
/// Type-erased body of a legacy task.
///
/// When invoked through the registry, it receives the call's arguments as a
/// JSON string (the `SerializedArguments` map serialized with `serde_json`) and
/// returns the task's result string.
pub type TaskFn = Arc<dyn Fn(String) -> RustvelloResult<String> + Send + Sync>;
311
/// Registration record for a legacy, closure-based task.
///
/// [`TaskRegistry::register`] wraps it so it executes like any other
/// [`DynTask`], and it also stays retrievable through [`TaskRegistry::get`].
pub struct TaskDefinition {
    /// Identifier the task is registered under.
    pub task_id: TaskId,
    /// Execution configuration for the task.
    pub config: TaskConfig,
    /// The task body; see [`TaskFn`] for its calling convention.
    pub func: TaskFn,
}
321
322impl TaskDefinition {
323 pub fn new(task_id: TaskId, config: TaskConfig, func: TaskFn) -> Self {
324 Self {
325 task_id,
326 config,
327 func,
328 }
329 }
330}
331
332impl fmt::Debug for TaskDefinition {
333 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334 f.debug_struct("TaskDefinition")
335 .field("task_id", &self.task_id)
336 .field("config", &self.config)
337 .finish()
338 }
339}
340
/// Adapts a legacy [`TaskDefinition`] to the [`DynTask`] interface.
struct LegacyTaskAdapter {
    // Shared with `TaskRegistry::legacy_tasks`: `register` clones the same Arc
    // into both maps.
    definition: Arc<TaskDefinition>,
}
345
346impl DynTask for LegacyTaskAdapter {
347 fn task_id(&self) -> &TaskId {
348 &self.definition.task_id
349 }
350
351 fn config(&self) -> &TaskConfig {
352 &self.definition.config
353 }
354
355 fn execute(&self, args: &SerializedArguments) -> RustvelloResult<String> {
356 let args_json =
358 serde_json::to_string(&args.0).map_err(|e| RustvelloError::Serialization {
359 message: e.to_string(),
360 })?;
361 (self.definition.func)(args_json)
362 }
363}
364
/// Registry of all tasks known to this process, keyed by [`TaskId`].
///
/// Typed, foreign and legacy tasks all live in `tasks` behind
/// `Arc<dyn DynTask>`. Ids are unique across all three kinds.
#[derive(Default)]
pub struct TaskRegistry {
    // Every registered task, regardless of how it was registered.
    tasks: HashMap<TaskId, Arc<dyn DynTask>>,
    // Only tasks added via `register`, kept so `get` can return the original
    // `TaskDefinition`.
    legacy_tasks: HashMap<TaskId, Arc<TaskDefinition>>,
}
379
380impl TaskRegistry {
381 pub fn new() -> Self {
382 Self::default()
383 }
384
385 pub fn register_typed<T: Task>(&mut self, task: T) -> RustvelloResult<()> {
387 let task_id = task.task_id().clone();
388 if self.tasks.contains_key(&task_id) {
389 return Err(RustvelloError::Configuration {
390 message: format!("task already registered: {}", task_id),
391 });
392 }
393 self.tasks.insert(task_id, Arc::new(task));
394 Ok(())
395 }
396
397 pub fn register_foreign<F: ForeignTask>(&mut self, task: F) -> RustvelloResult<()> {
400 let task_id = task.task_id();
401 if !task_id.is_foreign() {
402 return Err(RustvelloError::Configuration {
403 message: format!(
404 "ForeignTask must have a non-empty language, got: {}",
405 task_id
406 ),
407 });
408 }
409 if self.tasks.contains_key(&task_id) {
410 return Err(RustvelloError::Configuration {
411 message: format!("task already registered: {}", task_id),
412 });
413 }
414 self.tasks
415 .insert(task_id, Arc::new(ForeignTaskAdapter::new(task)));
416 Ok(())
417 }
418
419 pub fn register(&mut self, definition: TaskDefinition) -> RustvelloResult<()> {
421 let task_id = definition.task_id.clone();
422 if self.tasks.contains_key(&task_id) {
423 return Err(RustvelloError::Configuration {
424 message: format!("task already registered: {}", task_id),
425 });
426 }
427 let def = Arc::new(definition);
428 let adapter = LegacyTaskAdapter {
429 definition: Arc::clone(&def),
430 };
431 self.tasks.insert(task_id.clone(), Arc::new(adapter));
432 self.legacy_tasks.insert(task_id, def);
433 Ok(())
434 }
435
436 pub fn get_dyn(&self, task_id: &TaskId) -> Option<Arc<dyn DynTask>> {
438 self.tasks.get(task_id).cloned()
439 }
440
441 pub fn get(&self, task_id: &TaskId) -> Option<Arc<TaskDefinition>> {
443 self.legacy_tasks.get(task_id).cloned()
444 }
445
446 pub fn contains(&self, task_id: &TaskId) -> bool {
448 self.tasks.contains_key(task_id)
449 }
450
451 pub fn task_ids(&self) -> Vec<&TaskId> {
453 self.tasks.keys().collect()
454 }
455
456 pub fn len(&self) -> usize {
458 self.tasks.len()
459 }
460
461 pub fn is_empty(&self) -> bool {
462 self.tasks.is_empty()
463 }
464}
465
466impl fmt::Debug for TaskRegistry {
467 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
468 f.debug_struct("TaskRegistry")
469 .field("tasks", &self.tasks.keys().collect::<Vec<_>>())
470 .finish()
471 }
472}
473
/// A named group of tasks that knows how to register itself into a
/// [`TaskRegistry`].
pub trait TaskModule: Send + Sync {
    /// Name of this module (not used within this file; presumably for
    /// diagnostics — confirm with callers).
    fn name(&self) -> &str;

    /// Registers this module's tasks into `registry`.
    ///
    /// Implementations are expected to propagate registry errors such as
    /// duplicate registration.
    fn register(&self, registry: &mut TaskRegistry) -> RustvelloResult<()>;
}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Legacy task function that ignores its input and returns `"ok"`.
    fn dummy_fn() -> TaskFn {
        Arc::new(|_| Ok("ok".to_string()))
    }

    /// Legacy task function that returns its JSON input unchanged.
    fn echo_fn() -> TaskFn {
        Arc::new(|input: String| Ok(input))
    }

    /// Builds `SerializedArguments` from `(key, raw JSON value)` pairs.
    fn args_from(pairs: &[(&str, &str)]) -> SerializedArguments {
        let mut args = SerializedArguments::default();
        for (key, value) in pairs {
            args.0.insert((*key).into(), (*value).into());
        }
        args
    }

    #[test]
    fn registry_new_is_empty() {
        let reg = TaskRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn register_and_get() {
        let mut reg = TaskRegistry::new();
        let tid = TaskId::new("mod", "func");
        reg.register(TaskDefinition::new(
            tid.clone(),
            TaskConfig::default(),
            dummy_fn(),
        ))
        .unwrap();

        assert_eq!(reg.len(), 1);
        assert!(!reg.is_empty());
        assert!(reg.contains(&tid));
        assert!(reg.get(&tid).is_some());
        assert_eq!(reg.get(&tid).unwrap().task_id, tid);
    }

    #[test]
    fn register_duplicate_errors() {
        let mut reg = TaskRegistry::new();
        let tid = TaskId::new("mod", "func");
        reg.register(TaskDefinition::new(
            tid.clone(),
            TaskConfig::default(),
            dummy_fn(),
        ))
        .unwrap();
        let result = reg.register(TaskDefinition::new(tid, TaskConfig::default(), dummy_fn()));
        assert!(result.is_err());
    }

    #[test]
    fn get_nonexistent_returns_none() {
        let reg = TaskRegistry::new();
        let tid = TaskId::new("no", "such");
        assert!(!reg.contains(&tid));
        assert!(reg.get(&tid).is_none());
    }

    #[test]
    fn task_ids_lists_all() {
        let mut reg = TaskRegistry::new();
        let t1 = TaskId::new("mod", "a");
        let t2 = TaskId::new("mod", "b");
        reg.register(TaskDefinition::new(
            t1.clone(),
            TaskConfig::default(),
            dummy_fn(),
        ))
        .unwrap();
        reg.register(TaskDefinition::new(
            t2.clone(),
            TaskConfig::default(),
            dummy_fn(),
        ))
        .unwrap();

        let ids = reg.task_ids();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&&t1));
        assert!(ids.contains(&&t2));
    }

    #[test]
    fn task_definition_debug() {
        let def = TaskDefinition::new(
            TaskId::new("mod", "func"),
            TaskConfig::default(),
            dummy_fn(),
        );
        let debug_str = format!("{:?}", def);
        assert!(debug_str.contains("mod"));
        assert!(debug_str.contains("func"));
    }

    #[test]
    fn legacy_task_receives_args_as_json() {
        let mut reg = TaskRegistry::new();
        let tid = TaskId::new("mod", "echo");
        reg.register(TaskDefinition::new(
            tid.clone(),
            TaskConfig::default(),
            echo_fn(),
        ))
        .unwrap();

        let dyn_task = reg.get_dyn(&tid).unwrap();
        // Values are raw JSON text, so they reach the legacy function as JSON
        // strings inside the serialized map.
        let out = dyn_task.execute(&args_from(&[("a", "1")])).unwrap();
        assert_eq!(out, r#"{"a":"1"}"#);
    }

    // --- serialized_args_to_json ---------------------------------------

    #[test]
    fn args_to_json_empty_is_empty_object() {
        let args = SerializedArguments::default();
        assert_eq!(serialized_args_to_json(&args).unwrap(), "{}");
    }

    #[test]
    fn args_to_json_passes_through_positional_args() {
        let args = args_from(&[("__args__", "[1,2]")]);
        let json = serialized_args_to_json(&args).unwrap();
        assert!(matches!(json, std::borrow::Cow::Borrowed(_)));
        assert_eq!(json, "[1,2]");
    }

    #[test]
    fn args_to_json_escapes_keys_and_splices_values() {
        let args = args_from(&[("a\"b", "{\"x\":[1,null]}")]);
        assert_eq!(
            serialized_args_to_json(&args).unwrap(),
            r#"{"a\"b":{"x":[1,null]}}"#
        );
    }

    #[test]
    fn args_to_json_multiple_keys_round_trip() {
        let args = args_from(&[("a", "1"), ("b", "\"two\"")]);
        let json = serialized_args_to_json(&args).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value, serde_json::json!({"a": 1, "b": "two"}));
    }

    #[test]
    fn args_to_json_rejects_invalid_values() {
        assert!(serialized_args_to_json(&args_from(&[("a", "not json")])).is_err());
        // Trailing data after a valid value must also be rejected, otherwise
        // it could inject extra members into the spliced object.
        assert!(serialized_args_to_json(&args_from(&[("a", "1 2")])).is_err());
    }

    // --- typed tasks ----------------------------------------------------

    #[derive(serde::Serialize, serde::Deserialize)]
    struct AddParams {
        a: i64,
        b: i64,
    }

    struct AddTask {
        id: TaskId,
        config: TaskConfig,
    }

    impl Task for AddTask {
        type Params = AddParams;
        type Result = i64;

        fn task_id(&self) -> &TaskId {
            &self.id
        }

        fn config(&self) -> &TaskConfig {
            &self.config
        }

        fn run(&self, params: AddParams) -> RustvelloResult<i64> {
            Ok(params.a + params.b)
        }
    }

    fn add_task() -> AddTask {
        AddTask {
            id: TaskId::new("math", "add"),
            config: TaskConfig::default(),
        }
    }

    #[test]
    fn typed_task_executes_through_dyn() {
        let mut reg = TaskRegistry::new();
        reg.register_typed(add_task()).unwrap();

        let tid = TaskId::new("math", "add");
        let dyn_task = reg.get_dyn(&tid).unwrap();
        assert_eq!(dyn_task.task_id(), &tid);

        let out = dyn_task
            .execute(&args_from(&[("a", "2"), ("b", "3")]))
            .unwrap();
        assert_eq!(out, "5");

        // Typed tasks are not exposed through the legacy accessor.
        assert!(reg.get(&tid).is_none());
    }

    #[test]
    fn typed_task_reports_bad_params() {
        let mut reg = TaskRegistry::new();
        reg.register_typed(add_task()).unwrap();

        let dyn_task = reg.get_dyn(&TaskId::new("math", "add")).unwrap();
        let err = dyn_task.execute(&args_from(&[("a", "2")])).unwrap_err();
        assert!(matches!(err, RustvelloError::Serialization { .. }));
    }

    #[test]
    fn duplicate_detected_across_registration_kinds() {
        let mut reg = TaskRegistry::new();
        reg.register_typed(add_task()).unwrap();

        let tid = TaskId::new("math", "add");
        let result = reg.register(TaskDefinition::new(
            tid.clone(),
            TaskConfig::default(),
            dummy_fn(),
        ));
        assert!(result.is_err());
        // A rejected legacy registration must not leak into `legacy_tasks`.
        assert!(reg.get(&tid).is_none());
        assert_eq!(reg.len(), 1);
        assert!(reg.register_typed(add_task()).is_err());
    }

    // --- foreign tasks --------------------------------------------------

    #[derive(serde::Serialize, serde::Deserialize)]
    struct TestParams {
        value: String,
    }
    impl CrossLanguageSafe for TestParams {}

    struct TestForeignTask;

    impl ForeignTask for TestForeignTask {
        type Params = TestParams;
        type Result = String;

        fn task_id(&self) -> TaskId {
            TaskId::foreign("python", "analytics.tasks", "train_model")
        }
    }

    #[test]
    fn register_foreign_task() {
        let mut reg = TaskRegistry::new();
        reg.register_foreign(TestForeignTask).unwrap();

        let tid = TaskId::foreign("python", "analytics.tasks", "train_model");
        assert!(reg.contains(&tid));
        assert_eq!(reg.len(), 1);

        let dyn_task = reg.get_dyn(&tid).unwrap();
        assert_eq!(dyn_task.task_id(), &tid);
        assert!(dyn_task.task_id().is_foreign());
    }

    #[test]
    fn foreign_task_execute_returns_error() {
        let mut reg = TaskRegistry::new();
        reg.register_foreign(TestForeignTask).unwrap();

        let tid = TaskId::foreign("python", "analytics.tasks", "train_model");
        let dyn_task = reg.get_dyn(&tid).unwrap();

        let args = SerializedArguments::default();
        let result = dyn_task.execute(&args);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("foreign task"));
        assert!(err_msg.contains("python"));
    }

    #[test]
    fn register_foreign_duplicate_errors() {
        let mut reg = TaskRegistry::new();
        reg.register_foreign(TestForeignTask).unwrap();
        let result = reg.register_foreign(TestForeignTask);
        assert!(result.is_err());
    }

    #[test]
    fn cross_language_safe_primitives() {
        fn assert_cls<T: CrossLanguageSafe>() {}
        assert_cls::<String>();
        assert_cls::<bool>();
        assert_cls::<i32>();
        assert_cls::<i64>();
        assert_cls::<u32>();
        assert_cls::<u64>();
        assert_cls::<f32>();
        assert_cls::<f64>();
        assert_cls::<Vec<String>>();
        assert_cls::<Option<i64>>();
        assert_cls::<std::collections::BTreeMap<String, i64>>();
        assert_cls::<std::collections::HashMap<String, String>>();
    }
}