Skip to main content

rustvello_core/
call.rs

1//! Typed call representation bridging task parameters to serialized form.
2//!
3//! A [`Call`] represents a specific invocation of a [`Task`] with concrete
4//! parameters. It lazily computes the serialized form and deterministic
5//! [`CallId`] exactly like pynenc's `Call` class.
6//!
7//! The call hierarchy mirrors pynenc:
8//! - [`Call`] holds typed params + task reference → computes [`CallDTO`]
9//! - [`CallId`] = `TaskId` + SHA256(serialized args) → deterministic identity
10//! - [`CallDTO`] is the serialized form suitable for persistence
11
12use std::collections::BTreeMap;
13use std::marker::PhantomData;
14
15use rustvello_proto::call::{CallDTO, SerializedArguments};
16use rustvello_proto::identifiers::{CallId, TaskId};
17use rustvello_proto::status::ConcurrencyControlType;
18
19use crate::error::{RustvelloError, RustvelloResult};
20use crate::task::Task;
21
22/// A typed task call with concrete parameters.
23///
24/// Mirrors pynenc's `Call` class. Holds a reference to the task and the
25/// typed parameters. Lazily serializes the parameters and computes the
26/// deterministic [`CallId`] on demand.
27///
28/// # Example
29///
30/// ```rust
31/// use rustvello_core::call::Call;
32/// use rustvello_core::task::Task;
33/// use rustvello_proto::config::TaskConfig;
34/// use rustvello_proto::identifiers::TaskId;
35/// use rustvello_core::error::RustvelloResult;
36///
37/// struct DoubleTask { task_id: TaskId, config: TaskConfig }
38/// impl DoubleTask {
39///     fn new() -> Self {
40///         Self { task_id: TaskId::new("example", "double"), config: TaskConfig::default() }
41///     }
42/// }
43/// impl Task for DoubleTask {
44///     type Params = i32;
45///     type Result = i32;
46///     fn task_id(&self) -> &TaskId { &self.task_id }
47///     fn config(&self) -> &TaskConfig { &self.config }
48///     fn run(&self, x: i32) -> RustvelloResult<i32> { Ok(x * 2) }
49/// }
50///
51/// let task = DoubleTask::new();
52/// let call = Call::new(&task, 21);
53/// let dto = call.to_dto().unwrap();
54/// assert_eq!(dto.task_id, TaskId::new("example", "double"));
55/// ```
pub struct Call<'a, T: Task> {
    /// Borrowed task definition this call targets.
    task: &'a T,
    /// Concrete, still-typed parameters for this invocation.
    params: T::Params,
    /// Zero-sized marker naming the task's result type; no `T::Result` value is stored.
    _marker: PhantomData<T::Result>,
}
61
62impl<'a, T: Task> Call<'a, T> {
63    /// Create a new call with the given task and parameters.
64    pub fn new(task: &'a T, params: T::Params) -> Self {
65        Self {
66            task,
67            params,
68            _marker: PhantomData,
69        }
70    }
71
72    /// Get a reference to the task.
73    pub fn task(&self) -> &T {
74        self.task
75    }
76
77    /// Get a reference to the parameters.
78    pub fn params(&self) -> &T::Params {
79        &self.params
80    }
81
82    /// Consume the call and return the parameters.
83    pub fn into_params(self) -> T::Params {
84        self.params
85    }
86
87    /// Serialize the parameters to a JSON string.
88    pub fn serialize_params(&self) -> RustvelloResult<String> {
89        serde_json::to_string(&self.params).map_err(|e| RustvelloError::Serialization {
90            message: e.to_string(),
91        })
92    }
93
94    /// Compute the serialized arguments as a [`SerializedArguments`].
95    ///
96    /// For struct-like params, each field becomes a key-value pair.
97    /// For other types (primitives, tuples), the entire value is stored
98    /// under a single `"__args__"` key.
99    pub fn serialized_arguments(&self) -> RustvelloResult<SerializedArguments> {
100        let value =
101            serde_json::to_value(&self.params).map_err(|e| RustvelloError::Serialization {
102                message: e.to_string(),
103            })?;
104
105        let mut args = SerializedArguments::new();
106        match value {
107            serde_json::Value::Object(map) => {
108                for (k, v) in map {
109                    let v_str =
110                        serde_json::to_string(&v).map_err(|e| RustvelloError::Serialization {
111                            message: e.to_string(),
112                        })?;
113                    args.insert(k, v_str);
114                }
115            }
116            other => {
117                let v_str =
118                    serde_json::to_string(&other).map_err(|e| RustvelloError::Serialization {
119                        message: e.to_string(),
120                    })?;
121                args.insert("__args__", v_str);
122            }
123        }
124        Ok(args)
125    }
126
127    /// Compute the deterministic [`CallId`] for this call.
128    pub fn call_id(&self) -> RustvelloResult<CallId> {
129        let args = self.serialized_arguments()?;
130        let args_id = args.compute_args_id();
131        Ok(CallId::new(self.task.task_id().clone(), args_id))
132    }
133
134    /// Convert to a [`CallDTO`] suitable for persistence.
135    pub fn to_dto(&self) -> RustvelloResult<CallDTO> {
136        let args = self.serialized_arguments()?;
137        Ok(CallDTO::new(self.task.task_id().clone(), args))
138    }
139
140    /// Returns the serialized arguments relevant for concurrency checking.
141    ///
142    /// Mirrors pynenc's `Call.serialized_args_for_concurrency_check`.
143    /// The result depends on the task's concurrency control configuration:
144    /// - `Unlimited` → `None` (no concurrency check needed)
145    /// - `Task` → `Some(empty)` (task-level only, no args)
146    /// - `Argument` → `Some(all args)` or `Some(key_arguments subset)` if key_arguments is set
147    /// - `None` → `Some(all args)` (strictest: full dedup)
148    pub fn serialized_args_for_concurrency_check(
149        &self,
150    ) -> RustvelloResult<Option<SerializedArguments>> {
151        let config = self.task.config();
152        match config.concurrency_control {
153            ConcurrencyControlType::Unlimited => Ok(None),
154            ConcurrencyControlType::Task => Ok(Some(SerializedArguments::new())),
155            ConcurrencyControlType::Argument => {
156                let all_args = self.serialized_arguments()?;
157                if config.key_arguments.is_empty() {
158                    Ok(Some(all_args))
159                } else {
160                    let mut filtered = SerializedArguments::new();
161                    for key in &config.key_arguments {
162                        if let Some(val) = all_args.0.get(key) {
163                            filtered.insert(key, val.clone());
164                        }
165                    }
166                    Ok(Some(filtered))
167                }
168            }
169            ConcurrencyControlType::None => {
170                let all_args = self.serialized_arguments()?;
171                Ok(Some(all_args))
172            }
173            // #[non_exhaustive] requires a fallback — treat unknown variants
174            // conservatively by applying full-argument dedup.
175            _ => {
176                let all_args = self.serialized_arguments()?;
177                Ok(Some(all_args))
178            }
179        }
180    }
181}
182
183/// Create a [`CallDTO`] from raw parts (for use in runners when executing).
184pub fn call_dto_from_parts(task_id: TaskId, serialized_args: BTreeMap<String, String>) -> CallDTO {
185    let mut args = SerializedArguments::new();
186    for (k, v) in serialized_args {
187        args.insert(k, v);
188    }
189    CallDTO::new(task_id, args)
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::RustvelloResult;
    use rustvello_proto::config::TaskConfig;
    use serde::{Deserialize, Serialize};

    // -- Helper tasks for tests --

    /// Struct-like params: each field becomes its own serialized argument.
    #[derive(Serialize, Deserialize)]
    struct AddParams {
        x: i32,
        y: i32,
    }

    /// Adds its two params. The config is injectable so a single type covers
    /// every concurrency-control scenario instead of one copy per mode.
    struct AddTask {
        task_id: TaskId,
        config: TaskConfig,
    }
    impl AddTask {
        /// `test.add` with the default config.
        fn new() -> Self {
            Self::with_config("add", TaskConfig::default())
        }

        /// `test.<name>` with an explicit config.
        fn with_config(name: &'static str, config: TaskConfig) -> Self {
            Self {
                task_id: TaskId::new("test", name),
                config,
            }
        }

        /// `test.<name>` with the default config except for the concurrency mode.
        fn with_concurrency(name: &'static str, control: ConcurrencyControlType) -> Self {
            let mut config = TaskConfig::default();
            config.concurrency_control = control;
            Self::with_config(name, config)
        }

        /// `test.<name>` in `Argument` mode restricted to the given key arguments.
        fn with_key_arguments(name: &'static str, keys: Vec<String>) -> Self {
            let mut task = Self::with_concurrency(name, ConcurrencyControlType::Argument);
            task.config.key_arguments = keys;
            task
        }
    }
    impl Task for AddTask {
        type Params = AddParams;
        type Result = i32;
        fn task_id(&self) -> &TaskId {
            &self.task_id
        }
        fn config(&self) -> &TaskConfig {
            &self.config
        }
        fn run(&self, p: AddParams) -> RustvelloResult<i32> {
            Ok(p.x + p.y)
        }
    }

    /// Task with a primitive (non-struct) parameter.
    struct DoubleTask {
        task_id: TaskId,
        config: TaskConfig,
    }
    impl DoubleTask {
        fn new() -> Self {
            Self {
                task_id: TaskId::new("test", "double"),
                config: TaskConfig::default(),
            }
        }
    }
    impl Task for DoubleTask {
        type Params = i32;
        type Result = i32;
        fn task_id(&self) -> &TaskId {
            &self.task_id
        }
        fn config(&self) -> &TaskConfig {
            &self.config
        }
        fn run(&self, x: i32) -> RustvelloResult<i32> {
            Ok(x * 2)
        }
    }

    /// Concurrency-check arguments of `task` called with `x = 1, y = 2`.
    fn cc_args(task: &AddTask) -> Option<SerializedArguments> {
        Call::new(task, AddParams { x: 1, y: 2 })
            .serialized_args_for_concurrency_check()
            .unwrap()
    }

    // -- Accessors and serialization --

    #[test]
    fn call_accessors_expose_inputs() {
        let task = AddTask::new();
        let call = Call::new(&task, AddParams { x: 1, y: 2 });
        assert_eq!(call.task().task_id(), &TaskId::new("test", "add"));
        assert_eq!(call.params().x, 1);
        assert_eq!(call.into_params().y, 2);
    }

    #[test]
    fn call_serialize_params() {
        let add = AddTask::new();
        let call = Call::new(&add, AddParams { x: 1, y: 2 });
        assert_eq!(call.serialize_params().unwrap(), r#"{"x":1,"y":2}"#);

        let double = DoubleTask::new();
        let call = Call::new(&double, 42);
        assert_eq!(call.serialize_params().unwrap(), "42");
    }

    #[test]
    fn call_serialized_arguments_struct() {
        let task = AddTask::new();
        let call = Call::new(&task, AddParams { x: 1, y: 2 });
        let args = call.serialized_arguments().unwrap();
        // Struct params become individual keys
        assert!(args.0.contains_key("x"));
        assert!(args.0.contains_key("y"));
        assert_eq!(args.0["x"], "1");
        assert_eq!(args.0["y"], "2");
    }

    #[test]
    fn call_serialized_arguments_primitive() {
        let task = DoubleTask::new();
        let call = Call::new(&task, 42);
        let args = call.serialized_arguments().unwrap();
        // Non-struct params go under __args__
        assert!(args.0.contains_key("__args__"));
        assert_eq!(args.0["__args__"], "42");
    }

    #[test]
    fn call_id_deterministic() {
        let task1 = AddTask::new();
        let call1 = Call::new(&task1, AddParams { x: 1, y: 2 });
        let task2 = AddTask::new();
        let call2 = Call::new(&task2, AddParams { x: 1, y: 2 });
        assert_eq!(call1.call_id().unwrap(), call2.call_id().unwrap());
    }

    #[test]
    fn call_id_different_args() {
        let task1 = AddTask::new();
        let call1 = Call::new(&task1, AddParams { x: 1, y: 2 });
        let task2 = AddTask::new();
        let call2 = Call::new(&task2, AddParams { x: 3, y: 4 });
        assert_ne!(call1.call_id().unwrap(), call2.call_id().unwrap());
    }

    #[test]
    fn call_to_dto() {
        let task = AddTask::new();
        let call = Call::new(&task, AddParams { x: 10, y: 20 });
        let dto = call.to_dto().unwrap();
        assert_eq!(dto.task_id, TaskId::new("test", "add"));
        assert_eq!(dto.serialized_arguments.0["x"], "10");
        assert_eq!(dto.serialized_arguments.0["y"], "20");
    }

    #[test]
    fn call_dto_from_parts_works() {
        let mut map = BTreeMap::new();
        map.insert("a".to_string(), "1".to_string());
        let dto = call_dto_from_parts(TaskId::new("m", "f"), map);
        assert_eq!(dto.task_id, TaskId::new("m", "f"));
        assert_eq!(dto.serialized_arguments.0["a"], "1");
    }

    // -- Concurrency control args tests --

    #[test]
    fn cc_args_unlimited_returns_none() {
        // The default config is expected to leave concurrency unlimited.
        assert!(cc_args(&AddTask::new()).is_none());
    }

    #[test]
    fn cc_args_task_returns_empty() {
        let task = AddTask::with_concurrency("cc_task", ConcurrencyControlType::Task);
        let args = cc_args(&task).unwrap();
        assert!(args.0.is_empty());
    }

    #[test]
    fn cc_args_argument_returns_all() {
        let task = AddTask::with_concurrency("cc_arg", ConcurrencyControlType::Argument);
        let args = cc_args(&task).unwrap();
        assert_eq!(args.0.len(), 2);
        assert_eq!(args.0["x"], "1");
        assert_eq!(args.0["y"], "2");
    }

    #[test]
    fn cc_args_argument_with_key_args_returns_subset() {
        let task = AddTask::with_key_arguments("cc_key", vec!["x".to_string()]);
        let args = cc_args(&task).unwrap();
        assert_eq!(args.0.len(), 1);
        assert_eq!(args.0["x"], "1");
        assert!(!args.0.contains_key("y"));
    }

    #[test]
    fn cc_args_argument_with_unknown_key_arg_is_skipped() {
        // A key argument that names no serialized argument contributes nothing.
        let task = AddTask::with_key_arguments("cc_key_missing", vec!["z".to_string()]);
        let args = cc_args(&task).unwrap();
        assert!(args.0.is_empty());
    }

    #[test]
    fn cc_args_none_returns_all() {
        let task = AddTask::with_concurrency("cc_none", ConcurrencyControlType::None);
        let args = cc_args(&task).unwrap();
        assert_eq!(args.0.len(), 2);
    }
}