Skip to main content

rustvello_core/
task.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::sync::Arc;
4
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7
8use rustvello_proto::call::SerializedArguments;
9use rustvello_proto::config::TaskConfig;
10use rustvello_proto::identifiers::TaskId;
11
12use crate::error::{RustvelloError, RustvelloResult};
13
14// ---------------------------------------------------------------------------
15// Typed Task trait
16// ---------------------------------------------------------------------------
17
18/// A distributable task with typed parameters and results.
19///
20/// This is the Rust equivalent of pynenc's `Task` class. Each task
21/// definition implements this trait, providing:
22/// - A unique identity ([`TaskId`])
23/// - Configuration (retries, concurrency, etc.)
24/// - Typed execution (`Params` → `Result` via serde)
25///
26/// Tasks are typically created via the `#[rustvello::task]` proc-macro, but
27/// can also be implemented manually for testing or advanced use cases.
28///
29/// # Example (manual implementation)
30///
31/// ```rust
32/// use rustvello_core::task::Task;
33/// use rustvello_proto::config::TaskConfig;
34/// use rustvello_proto::identifiers::TaskId;
35/// use rustvello_core::error::RustvelloResult;
36///
37/// struct AddTask {
38///     task_id: TaskId,
39///     config: TaskConfig,
40/// }
41///
42/// impl Task for AddTask {
43///     type Params = (i32, i32);
44///     type Result = i32;
45///
46///     fn task_id(&self) -> &TaskId {
47///         &self.task_id
48///     }
49///
50///     fn config(&self) -> &TaskConfig {
51///         &self.config
52///     }
53///
54///     fn run(&self, params: Self::Params) -> RustvelloResult<Self::Result> {
55///         Ok(params.0 + params.1)
56///     }
57/// }
58/// ```
59pub trait Task: Send + Sync + 'static {
60    /// The input parameters type (must be serializable).
61    type Params: Serialize + DeserializeOwned + Send + Sync + 'static;
62    /// The return type (must be serializable).
63    type Result: Serialize + DeserializeOwned + Send + Sync + 'static;
64
65    /// Unique identifier for this task.
66    fn task_id(&self) -> &TaskId;
67
68    /// Per-task configuration.
69    fn config(&self) -> &TaskConfig;
70
71    /// Execute the task with the given parameters.
72    fn run(&self, params: Self::Params) -> RustvelloResult<Self::Result>;
73}
74
75// ---------------------------------------------------------------------------
76// Type-erased DynTask for heterogeneous registry storage
77// ---------------------------------------------------------------------------
78
79/// Type-erased task interface for the [`TaskRegistry`].
80///
81/// Every `T: Task` automatically implements `DynTask` via a blanket impl.
82/// The registry stores `Arc<dyn DynTask>`, which handles serialization
83/// and deserialization internally.
84pub trait DynTask: Send + Sync {
85    /// The task's unique identifier.
86    fn task_id(&self) -> &TaskId;
87
88    /// The task's configuration.
89    fn config(&self) -> &TaskConfig;
90
91    /// Execute with [`SerializedArguments`], returns serialized JSON result.
92    fn execute(&self, args: &SerializedArguments) -> RustvelloResult<String>;
93}
94
95/// Reconstruct a single JSON string from per-key [`SerializedArguments`].
96///
97/// - If only `__args__` is present, returns its raw value (non-struct params).
98/// - Otherwise, builds a JSON object from the key-value pairs with proper
99///   key escaping and value validation to prevent structural injection.
100pub fn serialized_args_to_json(
101    args: &SerializedArguments,
102) -> RustvelloResult<std::borrow::Cow<'_, str>> {
103    use std::borrow::Cow;
104    if args.0.len() == 1 && args.0.contains_key("__args__") {
105        // Non-struct params (primitives, tuples) stored under __args__
106        return Ok(Cow::Borrowed(&args.0["__args__"]));
107    }
108    // Struct params: build a JSON object string directly
109    use std::fmt::Write;
110    let mut buf = String::with_capacity(args.0.len() * 32 + 2);
111    buf.push('{');
112    for (i, (k, v)) in args.0.iter().enumerate() {
113        if i > 0 {
114            buf.push(',');
115        }
116        // Escape keys to prevent JSON injection from arbitrary input.
117        let escaped_key =
118            serde_json::to_string(k.as_str()).map_err(|e| RustvelloError::Serialization {
119                message: format!("failed to escape JSON key: {e}"),
120            })?;
121        // Validate that the value is valid JSON to prevent structural injection
122        serde_json::from_str::<serde_json::Value>(v).map_err(|e| {
123            RustvelloError::Serialization {
124                message: format!("invalid JSON value for key {k}: {e}"),
125            }
126        })?;
127        write!(buf, "{}:{}", escaped_key, v).map_err(|e| RustvelloError::Serialization {
128            message: format!("failed to build JSON: {e}"),
129        })?;
130    }
131    buf.push('}');
132    Ok(Cow::Owned(buf))
133}
134
135impl<T: Task> DynTask for T {
136    #[inline]
137    fn task_id(&self) -> &TaskId {
138        Task::task_id(self)
139    }
140
141    #[inline]
142    fn config(&self) -> &TaskConfig {
143        Task::config(self)
144    }
145
146    fn execute(&self, args: &SerializedArguments) -> RustvelloResult<String> {
147        let json_str = serialized_args_to_json(args)?;
148        let params: T::Params =
149            serde_json::from_str(&json_str).map_err(|e| RustvelloError::Serialization {
150                message: e.to_string(),
151            })?;
152        let result = self.run(params)?;
153        serde_json::to_string(&result).map_err(|e| RustvelloError::Serialization {
154            message: e.to_string(),
155        })
156    }
157}
158
159impl fmt::Debug for dyn DynTask {
160    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161        f.debug_struct("DynTask")
162            .field("task_id", &self.task_id())
163            .finish()
164    }
165}
166
167// ---------------------------------------------------------------------------
168// Cross-language safety marker + ForeignTask trait
169// ---------------------------------------------------------------------------
170
171/// Marker trait for types that can safely cross language boundaries.
172///
173/// Types implementing this trait must serialize to/from JSON using only
174/// universally supported primitives: bool, number, string, array, object, null.
175/// This excludes language-specific types (Python objects, Rust enums with data, etc.).
176///
177/// Used as a bound on [`ForeignTask`] params and results to provide
178/// compile-time enforcement that cross-language calls use compatible types.
179pub trait CrossLanguageSafe: Serialize + DeserializeOwned {}
180
181// Blanket implementations for common JSON-safe types
182impl CrossLanguageSafe for String {}
183impl CrossLanguageSafe for bool {}
184impl CrossLanguageSafe for i32 {}
185impl CrossLanguageSafe for i64 {}
186impl CrossLanguageSafe for u32 {}
187impl CrossLanguageSafe for u64 {}
188impl CrossLanguageSafe for f32 {}
189impl CrossLanguageSafe for f64 {}
190impl<T: CrossLanguageSafe> CrossLanguageSafe for Vec<T> {}
191impl<T: CrossLanguageSafe> CrossLanguageSafe for Option<T> {}
192impl<K: CrossLanguageSafe + Ord, V: CrossLanguageSafe> CrossLanguageSafe
193    for std::collections::BTreeMap<K, V>
194{
195}
196impl<K: CrossLanguageSafe + Eq + std::hash::Hash, V: CrossLanguageSafe> CrossLanguageSafe
197    for std::collections::HashMap<K, V>
198{
199}
200
201/// A foreign task stub — represents a task implemented in another language.
202///
203/// Unlike [`Task`], a `ForeignTask` has no `run()` method because execution
204/// happens in the foreign language worker. The Rust side only creates
205/// invocations that the foreign worker picks up from its language queue.
206///
207/// The `CrossLanguageSafe` bound on `Params` and `Result` ensures that
208/// only JSON-compatible types are used for cross-language serialization.
209///
210/// # Example
211///
212/// ```rust
213/// use rustvello_core::task::{ForeignTask, CrossLanguageSafe};
214/// use rustvello_proto::config::TaskConfig;
215/// use rustvello_proto::identifiers::TaskId;
216/// use serde::{Serialize, Deserialize};
217///
218/// #[derive(Serialize, Deserialize)]
219/// struct TrainModelParams {
220///     dataset_path: String,
221///     epochs: u32,
222/// }
223/// impl CrossLanguageSafe for TrainModelParams {}
224///
225/// struct TrainModel {
226///     task_id: TaskId,
227/// }
228///
229/// impl ForeignTask for TrainModel {
230///     type Params = TrainModelParams;
231///     type Result = String;
232///
233///     fn task_id(&self) -> TaskId {
234///         self.task_id.clone()
235///     }
236/// }
237/// ```
238pub trait ForeignTask: Send + Sync + 'static {
239    /// The input parameters type (must be cross-language safe).
240    type Params: CrossLanguageSafe + Send + Sync + 'static;
241    /// The return type (must be cross-language safe).
242    type Result: CrossLanguageSafe + Send + Sync + 'static;
243
244    /// Unique identifier for this foreign task.
245    /// Must return a qualified `TaskId` (with non-empty `language`).
246    fn task_id(&self) -> TaskId;
247
248    /// Per-task configuration (optional override).
249    fn config(&self) -> TaskConfig {
250        TaskConfig::default()
251    }
252}
253
254/// Adapter: wraps a [`ForeignTask`] as a [`DynTask`].
255///
256/// Since foreign tasks have no `run()` implementation, `execute()` returns
257/// an error indicating the task must be processed by a foreign worker.
258///
259/// The adapter caches `task_id` and `config` from the inner [`ForeignTask`]
260/// at construction time so that [`DynTask`] can return references.
261struct ForeignTaskAdapter<F: ForeignTask> {
262    _inner: F,
263    task_id: TaskId,
264    config: TaskConfig,
265}
266
267impl<F: ForeignTask> ForeignTaskAdapter<F> {
268    fn new(task: F) -> Self {
269        let task_id = task.task_id();
270        let config = task.config();
271        Self {
272            _inner: task,
273            task_id,
274            config,
275        }
276    }
277}
278
279impl<F: ForeignTask> DynTask for ForeignTaskAdapter<F> {
280    fn task_id(&self) -> &TaskId {
281        &self.task_id
282    }
283
284    fn config(&self) -> &TaskConfig {
285        &self.config
286    }
287
288    fn execute(&self, _args: &SerializedArguments) -> RustvelloResult<String> {
289        Err(RustvelloError::Configuration {
290            message: format!(
291                "foreign task {} cannot be executed locally — must be processed by a {} worker",
292                self.task_id,
293                self.task_id.language(),
294            ),
295        })
296    }
297}
298
299// ---------------------------------------------------------------------------
300// Legacy untyped TaskFn/TaskDefinition (preserved for backward compatibility)
301// ---------------------------------------------------------------------------
302
303/// A function that can be executed as a task (untyped, legacy).
304///
305/// In Rust, tasks are registered as boxed closures or function pointers.
306/// The input and output are serialized JSON strings to allow heterogeneous
307/// task types in the same registry.
308///
309/// **Prefer using the typed [`Task`] trait for new code.**
310pub type TaskFn = Arc<dyn Fn(String) -> RustvelloResult<String> + Send + Sync>;
311
312/// A registered task definition with its metadata and executable function (legacy).
313///
314/// **Prefer using the typed [`Task`] trait for new code.** This type exists
315/// for backward compatibility with code that uses `TaskFn` closures.
316pub struct TaskDefinition {
317    pub task_id: TaskId,
318    pub config: TaskConfig,
319    pub func: TaskFn,
320}
321
322impl TaskDefinition {
323    pub fn new(task_id: TaskId, config: TaskConfig, func: TaskFn) -> Self {
324        Self {
325            task_id,
326            config,
327            func,
328        }
329    }
330}
331
332impl fmt::Debug for TaskDefinition {
333    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334        f.debug_struct("TaskDefinition")
335            .field("task_id", &self.task_id)
336            .field("config", &self.config)
337            .finish()
338    }
339}
340
341/// Adapter: wraps a legacy [`TaskDefinition`] as a [`DynTask`].
342struct LegacyTaskAdapter {
343    definition: Arc<TaskDefinition>,
344}
345
346impl DynTask for LegacyTaskAdapter {
347    fn task_id(&self) -> &TaskId {
348        &self.definition.task_id
349    }
350
351    fn config(&self) -> &TaskConfig {
352        &self.definition.config
353    }
354
355    fn execute(&self, args: &SerializedArguments) -> RustvelloResult<String> {
356        // Legacy tasks expect the BTreeMap<String, String> serialized as JSON
357        let args_json =
358            serde_json::to_string(&args.0).map_err(|e| RustvelloError::Serialization {
359                message: e.to_string(),
360            })?;
361        (self.definition.func)(args_json)
362    }
363}
364
365// ---------------------------------------------------------------------------
366// TaskRegistry — stores both typed and legacy tasks
367// ---------------------------------------------------------------------------
368
369/// Registry holding all known task definitions for this application.
370///
371/// Tasks must be registered before they can be invoked. Supports both
372/// typed tasks (via [`Task`] trait) and legacy closure-based tasks.
373#[derive(Default)]
374pub struct TaskRegistry {
375    tasks: HashMap<TaskId, Arc<dyn DynTask>>,
376    /// Legacy index for backward-compatible `get_legacy()` access.
377    legacy_tasks: HashMap<TaskId, Arc<TaskDefinition>>,
378}
379
380impl TaskRegistry {
381    pub fn new() -> Self {
382        Self::default()
383    }
384
385    /// Register a typed task. Returns error if the task ID is already registered.
386    pub fn register_typed<T: Task>(&mut self, task: T) -> RustvelloResult<()> {
387        let task_id = task.task_id().clone();
388        if self.tasks.contains_key(&task_id) {
389            return Err(RustvelloError::Configuration {
390                message: format!("task already registered: {}", task_id),
391            });
392        }
393        self.tasks.insert(task_id, Arc::new(task));
394        Ok(())
395    }
396
397    /// Register a foreign task stub. Returns error if the task ID is already registered
398    /// or if the task ID does not have a non-empty language (i.e. is not foreign).
399    pub fn register_foreign<F: ForeignTask>(&mut self, task: F) -> RustvelloResult<()> {
400        let task_id = task.task_id();
401        if !task_id.is_foreign() {
402            return Err(RustvelloError::Configuration {
403                message: format!(
404                    "ForeignTask must have a non-empty language, got: {}",
405                    task_id
406                ),
407            });
408        }
409        if self.tasks.contains_key(&task_id) {
410            return Err(RustvelloError::Configuration {
411                message: format!("task already registered: {}", task_id),
412            });
413        }
414        self.tasks
415            .insert(task_id, Arc::new(ForeignTaskAdapter::new(task)));
416        Ok(())
417    }
418
419    /// Register a legacy task definition. Returns error if already registered.
420    pub fn register(&mut self, definition: TaskDefinition) -> RustvelloResult<()> {
421        let task_id = definition.task_id.clone();
422        if self.tasks.contains_key(&task_id) {
423            return Err(RustvelloError::Configuration {
424                message: format!("task already registered: {}", task_id),
425            });
426        }
427        let def = Arc::new(definition);
428        let adapter = LegacyTaskAdapter {
429            definition: Arc::clone(&def),
430        };
431        self.tasks.insert(task_id.clone(), Arc::new(adapter));
432        self.legacy_tasks.insert(task_id, def);
433        Ok(())
434    }
435
436    /// Get a type-erased task by ID.
437    pub fn get_dyn(&self, task_id: &TaskId) -> Option<Arc<dyn DynTask>> {
438        self.tasks.get(task_id).cloned()
439    }
440
441    /// Get a legacy task definition by ID (backward compatibility).
442    pub fn get(&self, task_id: &TaskId) -> Option<Arc<TaskDefinition>> {
443        self.legacy_tasks.get(task_id).cloned()
444    }
445
446    /// Check if a task is registered.
447    pub fn contains(&self, task_id: &TaskId) -> bool {
448        self.tasks.contains_key(task_id)
449    }
450
451    /// List all registered task IDs.
452    pub fn task_ids(&self) -> Vec<&TaskId> {
453        self.tasks.keys().collect()
454    }
455
456    /// Number of registered tasks.
457    pub fn len(&self) -> usize {
458        self.tasks.len()
459    }
460
461    pub fn is_empty(&self) -> bool {
462        self.tasks.is_empty()
463    }
464}
465
466impl fmt::Debug for TaskRegistry {
467    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
468        f.debug_struct("TaskRegistry")
469            .field("tasks", &self.tasks.keys().collect::<Vec<_>>())
470            .finish()
471    }
472}
473
474// ---------------------------------------------------------------------------
475// TaskModule — grouping of task registrations
476// ---------------------------------------------------------------------------
477
478/// A module that registers one or more tasks with a [`TaskRegistry`].
479///
480/// Inspired by pynenc's plugin system — each module groups related tasks
481/// and registers them at application startup.
482///
483/// # Example
484///
485/// ```rust
486/// use rustvello_core::task::{TaskModule, TaskRegistry, TaskDefinition};
487/// use rustvello_proto::config::TaskConfig;
488/// use rustvello_proto::identifiers::TaskId;
489/// use rustvello_core::error::RustvelloResult;
490/// use std::sync::Arc;
491///
492/// struct MathTasks;
493///
494/// impl TaskModule for MathTasks {
495///     fn name(&self) -> &str { "math" }
496///
497///     fn register(&self, registry: &mut TaskRegistry) -> RustvelloResult<()> {
498///         registry.register(TaskDefinition::new(
499///             TaskId::new("math", "add"),
500///             TaskConfig::default(),
501///             Arc::new(|_| Ok("0".to_string())),
502///         ))
503///     }
504/// }
505/// ```
506pub trait TaskModule: Send + Sync {
507    /// Human-readable name for this module (for logging/diagnostics).
508    fn name(&self) -> &str;
509
510    /// Register all tasks provided by this module.
511    fn register(&self, registry: &mut TaskRegistry) -> RustvelloResult<()>;
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Legacy task body that ignores its input and always succeeds with `"ok"`.
    fn dummy_fn() -> TaskFn {
        Arc::new(|_input| Ok(String::from("ok")))
    }

    /// Builds a legacy definition for `tid` with default config and the dummy body.
    fn dummy_definition(tid: &TaskId) -> TaskDefinition {
        TaskDefinition::new(tid.clone(), TaskConfig::default(), dummy_fn())
    }

    /// The id reported by [`TestForeignTask`].
    fn foreign_tid() -> TaskId {
        TaskId::foreign("python", "analytics.tasks", "train_model")
    }

    // -- Legacy registry tests --

    #[test]
    fn registry_new_is_empty() {
        let registry = TaskRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn register_and_get() {
        let tid = TaskId::new("mod", "func");
        let mut registry = TaskRegistry::new();
        registry.register(dummy_definition(&tid)).unwrap();

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        assert!(registry.contains(&tid));
        assert!(registry.get(&tid).is_some());
        assert_eq!(registry.get(&tid).unwrap().task_id, tid);
    }

    #[test]
    fn register_duplicate_errors() {
        let tid = TaskId::new("mod", "func");
        let mut registry = TaskRegistry::new();
        registry.register(dummy_definition(&tid)).unwrap();

        let second_attempt = registry.register(dummy_definition(&tid));
        assert!(second_attempt.is_err());
    }

    #[test]
    fn get_nonexistent_returns_none() {
        let registry = TaskRegistry::new();
        let missing = TaskId::new("no", "such");
        assert!(!registry.contains(&missing));
        assert!(registry.get(&missing).is_none());
    }

    #[test]
    fn task_ids_lists_all() {
        let first = TaskId::new("mod", "a");
        let second = TaskId::new("mod", "b");
        let mut registry = TaskRegistry::new();
        for tid in [&first, &second] {
            registry.register(dummy_definition(tid)).unwrap();
        }

        let ids = registry.task_ids();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&&first));
        assert!(ids.contains(&&second));
    }

    #[test]
    fn task_definition_debug() {
        let definition = dummy_definition(&TaskId::new("mod", "func"));
        let rendered = format!("{:?}", definition);
        assert!(rendered.contains("mod"));
        assert!(rendered.contains("func"));
    }

    // -- Cross-language tests --

    #[derive(serde::Serialize, serde::Deserialize)]
    struct TestParams {
        value: String,
    }
    impl CrossLanguageSafe for TestParams {}

    struct TestForeignTask;

    impl ForeignTask for TestForeignTask {
        type Params = TestParams;
        type Result = String;

        fn task_id(&self) -> TaskId {
            TaskId::foreign("python", "analytics.tasks", "train_model")
        }
    }

    #[test]
    fn register_foreign_task() {
        let mut registry = TaskRegistry::new();
        registry.register_foreign(TestForeignTask).unwrap();

        let tid = foreign_tid();
        assert!(registry.contains(&tid));
        assert_eq!(registry.len(), 1);

        let stored = registry.get_dyn(&tid).unwrap();
        assert_eq!(stored.task_id(), &tid);
        assert!(stored.task_id().is_foreign());
    }

    #[test]
    fn foreign_task_execute_returns_error() {
        let mut registry = TaskRegistry::new();
        registry.register_foreign(TestForeignTask).unwrap();

        let stored = registry.get_dyn(&foreign_tid()).unwrap();
        let outcome = stored.execute(&SerializedArguments::default());
        assert!(outcome.is_err());

        let rendered = outcome.unwrap_err().to_string();
        assert!(rendered.contains("foreign task"));
        assert!(rendered.contains("python"));
    }

    #[test]
    fn register_foreign_duplicate_errors() {
        let mut registry = TaskRegistry::new();
        registry.register_foreign(TestForeignTask).unwrap();

        let second_attempt = registry.register_foreign(TestForeignTask);
        assert!(second_attempt.is_err());
    }

    #[test]
    fn cross_language_safe_primitives() {
        // Compile-time check only: each call type-checks iff the marker is implemented.
        fn assert_cls<T: CrossLanguageSafe>() {}
        assert_cls::<String>();
        assert_cls::<bool>();
        assert_cls::<i32>();
        assert_cls::<i64>();
        assert_cls::<u32>();
        assert_cls::<u64>();
        assert_cls::<f32>();
        assert_cls::<f64>();
        assert_cls::<Vec<String>>();
        assert_cls::<Option<i64>>();
        assert_cls::<std::collections::BTreeMap<String, i64>>();
        assert_cls::<std::collections::HashMap<String, String>>();
    }
}