1use std::collections::BTreeMap;
13use std::marker::PhantomData;
14
15use rustvello_proto::call::{CallDTO, SerializedArguments};
16use rustvello_proto::identifiers::{CallId, TaskId};
17use rustvello_proto::status::ConcurrencyControlType;
18
19use crate::error::{RustvelloError, RustvelloResult};
20use crate::task::Task;
21
/// A single invocation of task `T`, pairing a borrowed task with its typed parameters.
///
/// The parameters stay strongly typed here; the methods on `Call` turn them into the
/// JSON-based forms (`SerializedArguments`, `CallId`, `CallDTO`) used elsewhere.
pub struct Call<'a, T: Task> {
    /// The task this call targets (borrowed for the lifetime `'a`, not owned).
    task: &'a T,
    /// Typed parameters supplied when the call was created.
    params: T::Params,
    // Zero-sized marker naming the task's result type; stores no data.
    _marker: PhantomData<T::Result>,
}
61
62impl<'a, T: Task> Call<'a, T> {
63 pub fn new(task: &'a T, params: T::Params) -> Self {
65 Self {
66 task,
67 params,
68 _marker: PhantomData,
69 }
70 }
71
72 pub fn task(&self) -> &T {
74 self.task
75 }
76
77 pub fn params(&self) -> &T::Params {
79 &self.params
80 }
81
82 pub fn into_params(self) -> T::Params {
84 self.params
85 }
86
87 pub fn serialize_params(&self) -> RustvelloResult<String> {
89 serde_json::to_string(&self.params).map_err(|e| RustvelloError::Serialization {
90 message: e.to_string(),
91 })
92 }
93
94 pub fn serialized_arguments(&self) -> RustvelloResult<SerializedArguments> {
100 let value =
101 serde_json::to_value(&self.params).map_err(|e| RustvelloError::Serialization {
102 message: e.to_string(),
103 })?;
104
105 let mut args = SerializedArguments::new();
106 match value {
107 serde_json::Value::Object(map) => {
108 for (k, v) in map {
109 let v_str =
110 serde_json::to_string(&v).map_err(|e| RustvelloError::Serialization {
111 message: e.to_string(),
112 })?;
113 args.insert(k, v_str);
114 }
115 }
116 other => {
117 let v_str =
118 serde_json::to_string(&other).map_err(|e| RustvelloError::Serialization {
119 message: e.to_string(),
120 })?;
121 args.insert("__args__", v_str);
122 }
123 }
124 Ok(args)
125 }
126
127 pub fn call_id(&self) -> RustvelloResult<CallId> {
129 let args = self.serialized_arguments()?;
130 let args_id = args.compute_args_id();
131 Ok(CallId::new(self.task.task_id().clone(), args_id))
132 }
133
134 pub fn to_dto(&self) -> RustvelloResult<CallDTO> {
136 let args = self.serialized_arguments()?;
137 Ok(CallDTO::new(self.task.task_id().clone(), args))
138 }
139
140 pub fn serialized_args_for_concurrency_check(
149 &self,
150 ) -> RustvelloResult<Option<SerializedArguments>> {
151 let config = self.task.config();
152 match config.concurrency_control {
153 ConcurrencyControlType::Unlimited => Ok(None),
154 ConcurrencyControlType::Task => Ok(Some(SerializedArguments::new())),
155 ConcurrencyControlType::Argument => {
156 let all_args = self.serialized_arguments()?;
157 if config.key_arguments.is_empty() {
158 Ok(Some(all_args))
159 } else {
160 let mut filtered = SerializedArguments::new();
161 for key in &config.key_arguments {
162 if let Some(val) = all_args.0.get(key) {
163 filtered.insert(key, val.clone());
164 }
165 }
166 Ok(Some(filtered))
167 }
168 }
169 ConcurrencyControlType::None => {
170 let all_args = self.serialized_arguments()?;
171 Ok(Some(all_args))
172 }
173 _ => {
176 let all_args = self.serialized_arguments()?;
177 Ok(Some(all_args))
178 }
179 }
180 }
181}
182
183pub fn call_dto_from_parts(task_id: TaskId, serialized_args: BTreeMap<String, String>) -> CallDTO {
185 let mut args = SerializedArguments::new();
186 for (k, v) in serialized_args {
187 args.insert(k, v);
188 }
189 CallDTO::new(task_id, args)
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::RustvelloResult;
    use rustvello_proto::config::TaskConfig;
    use serde::{Deserialize, Serialize};

    /// Parameters for [`AddTask`]; they serialize to the JSON object `{"x":..,"y":..}`.
    #[derive(Serialize, Deserialize)]
    struct AddParams {
        x: i32,
        y: i32,
    }

    /// Test task that adds two integers.
    ///
    /// The task id and configuration are chosen per test, so this single `Task` impl serves
    /// both the default case and every concurrency-control mode.
    struct AddTask {
        task_id: TaskId,
        config: TaskConfig,
    }

    impl AddTask {
        /// `test.add` with the default [`TaskConfig`].
        fn new() -> Self {
            Self::with_config("add", TaskConfig::default())
        }

        /// `test.<name>` with an explicit configuration.
        fn with_config(name: &str, config: TaskConfig) -> Self {
            Self {
                task_id: TaskId::new("test", name),
                config,
            }
        }

        /// `test.<name>` using `control`, otherwise default configuration.
        fn with_concurrency(name: &str, control: ConcurrencyControlType) -> Self {
            let mut config = TaskConfig::default();
            config.concurrency_control = control;
            Self::with_config(name, config)
        }

        /// `test.<name>` using argument-level control restricted to the given key names.
        fn with_key_arguments(name: &str, keys: &[&str]) -> Self {
            let mut config = TaskConfig::default();
            config.concurrency_control = ConcurrencyControlType::Argument;
            config.key_arguments = keys.iter().map(|&key| String::from(key)).collect();
            Self::with_config(name, config)
        }
    }

    impl Task for AddTask {
        type Params = AddParams;
        type Result = i32;
        fn task_id(&self) -> &TaskId {
            &self.task_id
        }
        fn config(&self) -> &TaskConfig {
            &self.config
        }
        fn run(&self, p: AddParams) -> RustvelloResult<i32> {
            Ok(p.x + p.y)
        }
    }

    /// Test task whose parameters are a bare integer rather than a struct.
    struct DoubleTask {
        task_id: TaskId,
        config: TaskConfig,
    }

    impl DoubleTask {
        fn new() -> Self {
            Self {
                task_id: TaskId::new("test", "double"),
                config: TaskConfig::default(),
            }
        }
    }

    impl Task for DoubleTask {
        type Params = i32;
        type Result = i32;
        fn task_id(&self) -> &TaskId {
            &self.task_id
        }
        fn config(&self) -> &TaskConfig {
            &self.config
        }
        fn run(&self, x: i32) -> RustvelloResult<i32> {
            Ok(x * 2)
        }
    }

    /// Concurrency-check arguments for an `x = 1, y = 2` call of `task`, unwrapping both the
    /// `Result` and the `Option`.
    fn cc_args(task: &AddTask) -> SerializedArguments {
        Call::new(task, AddParams { x: 1, y: 2 })
            .serialized_args_for_concurrency_check()
            .unwrap()
            .expect("concurrency control should yield arguments")
    }

    #[test]
    fn call_accessors() {
        let task = AddTask::new();
        let call = Call::new(&task, AddParams { x: 3, y: 4 });
        assert_eq!(call.task().task_id(), &TaskId::new("test", "add"));
        assert_eq!(call.params().x, 3);
        let params = call.into_params();
        assert_eq!((params.x, params.y), (3, 4));
    }

    #[test]
    fn call_serialize_params_is_compact_json() {
        let task = AddTask::new();
        let call = Call::new(&task, AddParams { x: 1, y: 2 });
        assert_eq!(call.serialize_params().unwrap(), r#"{"x":1,"y":2}"#);
    }

    #[test]
    fn call_serialized_arguments_struct() {
        let task = AddTask::new();
        let call = Call::new(&task, AddParams { x: 1, y: 2 });
        let args = call.serialized_arguments().unwrap();
        assert!(args.0.contains_key("x"));
        assert!(args.0.contains_key("y"));
        assert_eq!(args.0["x"], "1");
        assert_eq!(args.0["y"], "2");
    }

    #[test]
    fn call_serialized_arguments_primitive() {
        let task = DoubleTask::new();
        let call = Call::new(&task, 42);
        let args = call.serialized_arguments().unwrap();
        assert!(args.0.contains_key("__args__"));
        assert_eq!(args.0["__args__"], "42");
    }

    #[test]
    fn call_id_deterministic() {
        let task1 = AddTask::new();
        let call1 = Call::new(&task1, AddParams { x: 1, y: 2 });
        let task2 = AddTask::new();
        let call2 = Call::new(&task2, AddParams { x: 1, y: 2 });
        assert_eq!(call1.call_id().unwrap(), call2.call_id().unwrap());
    }

    #[test]
    fn call_id_different_args() {
        let task1 = AddTask::new();
        let call1 = Call::new(&task1, AddParams { x: 1, y: 2 });
        let task2 = AddTask::new();
        let call2 = Call::new(&task2, AddParams { x: 3, y: 4 });
        assert_ne!(call1.call_id().unwrap(), call2.call_id().unwrap());
    }

    #[test]
    fn call_id_depends_on_task() {
        let task1 = AddTask::new();
        let call1 = Call::new(&task1, AddParams { x: 1, y: 2 });
        let task2 = AddTask::with_config("add_other", TaskConfig::default());
        let call2 = Call::new(&task2, AddParams { x: 1, y: 2 });
        assert_ne!(call1.call_id().unwrap(), call2.call_id().unwrap());
    }

    #[test]
    fn call_to_dto() {
        let task = AddTask::new();
        let call = Call::new(&task, AddParams { x: 10, y: 20 });
        let dto = call.to_dto().unwrap();
        assert_eq!(dto.task_id, TaskId::new("test", "add"));
        assert_eq!(dto.serialized_arguments.0["x"], "10");
        assert_eq!(dto.serialized_arguments.0["y"], "20");
    }

    #[test]
    fn call_dto_from_parts_works() {
        let mut map = BTreeMap::new();
        map.insert("a".to_string(), "1".to_string());
        let dto = call_dto_from_parts(TaskId::new("m", "f"), map);
        assert_eq!(dto.task_id, TaskId::new("m", "f"));
        assert_eq!(dto.serialized_arguments.0["a"], "1");
    }

    #[test]
    fn cc_args_unlimited_returns_none() {
        let task = AddTask::new();
        let call = Call::new(&task, AddParams { x: 1, y: 2 });
        assert!(call
            .serialized_args_for_concurrency_check()
            .unwrap()
            .is_none());
    }

    #[test]
    fn cc_args_task_returns_empty() {
        let task = AddTask::with_concurrency("cc_task", ConcurrencyControlType::Task);
        assert!(cc_args(&task).0.is_empty());
    }

    #[test]
    fn cc_args_argument_returns_all() {
        let task = AddTask::with_concurrency("cc_arg", ConcurrencyControlType::Argument);
        let args = cc_args(&task);
        assert_eq!(args.0.len(), 2);
        assert_eq!(args.0["x"], "1");
        assert_eq!(args.0["y"], "2");
    }

    #[test]
    fn cc_args_argument_with_key_args_returns_subset() {
        let task = AddTask::with_key_arguments("cc_key", &["x"]);
        let args = cc_args(&task);
        assert_eq!(args.0.len(), 1);
        assert_eq!(args.0["x"], "1");
        assert!(!args.0.contains_key("y"));
    }

    #[test]
    fn cc_args_key_args_skip_unknown_names() {
        let task = AddTask::with_key_arguments("cc_key_unknown", &["x", "missing"]);
        let args = cc_args(&task);
        assert_eq!(args.0.len(), 1);
        assert_eq!(args.0["x"], "1");
        assert!(!args.0.contains_key("missing"));
    }

    #[test]
    fn cc_args_none_returns_all() {
        let task = AddTask::with_concurrency("cc_none", ConcurrencyControlType::None);
        assert_eq!(cc_args(&task).0.len(), 2);
    }
}