Skip to main content

rustvello_proto/
call.rs

1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3
4use crate::config::ArgumentPrintMode;
5use crate::identifiers::{CallId, TaskId};
6
7/// Serialized arguments for a task call.
8///
9/// Arguments are stored as a sorted map of key-value pairs where
10/// values are JSON-serialized strings. Sorting ensures deterministic
11/// hashing for deduplication.
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
13pub struct SerializedArguments(pub BTreeMap<String, String>);
14
15impl SerializedArguments {
16    pub fn new() -> Self {
17        Self(BTreeMap::new())
18    }
19
20    pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) {
21        self.0.insert(key.into(), value.into());
22    }
23
24    /// Compute a deterministic hash of the serialized arguments.
25    /// Used to generate the `args_id` component of `CallId`.
26    /// Returns `"no_args"` for empty argument maps (matches pynenc convention).
27    pub fn compute_args_id(&self) -> String {
28        if self.0.is_empty() {
29            return "no_args".to_string();
30        }
31        use sha2::{Digest, Sha256};
32        let mut hasher = Sha256::new();
33        for (k, v) in &self.0 {
34            // Use JSON-escaped keys and values to prevent delimiter collisions
35            let ek = serde_json::to_string(k).unwrap_or_else(|_| k.clone());
36            let ev = serde_json::to_string(v).unwrap_or_else(|_| v.clone());
37            hasher.update(ek.as_bytes());
38            hasher.update(b"=");
39            hasher.update(ev.as_bytes());
40            hasher.update(b";");
41        }
42        format!("{:x}", hasher.finalize())
43    }
44}
45
46impl Default for SerializedArguments {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl SerializedArguments {
53    /// Convert arguments into individual `(key, value)` pairs for per-pair
54    /// concurrency-control indexing.
55    ///
56    /// When the argument map is present but empty, a sentinel `("", "")` pair
57    /// is returned so that `Some(empty_args)` is distinguishable from `None`
58    /// (which means CC is disabled for this invocation).
59    ///
60    /// Backend implementations should always call this method instead of
61    /// manually iterating `self.0` to ensure the sentinel convention is
62    /// applied consistently.
63    pub fn cc_arg_pairs(&self) -> Vec<(String, String)> {
64        if self.0.is_empty() {
65            vec![(String::new(), String::new())]
66        } else {
67            self.0.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
68        }
69    }
70}
71
72impl SerializedArguments {
73    /// Compute a concurrency control key from optional arguments.
74    ///
75    /// Returns an empty string when `args` is `None` or empty,
76    /// otherwise returns the deterministic args hash.
77    pub fn cc_key(args: Option<&Self>) -> String {
78        match args {
79            None => String::new(),
80            Some(a) if a.0.is_empty() => String::new(),
81            Some(a) => a.compute_args_id(),
82        }
83    }
84
85    /// Format arguments for display using the given print mode.
86    pub fn display(&self, mode: ArgumentPrintMode, truncate_length: usize) -> String {
87        if self.0.is_empty() {
88            return "<no_args>".to_string();
89        }
90        match mode {
91            ArgumentPrintMode::Hidden => "<arguments hidden>".to_string(),
92            ArgumentPrintMode::Keys => {
93                let keys: Vec<&str> = self.0.keys().map(std::string::String::as_str).collect();
94                format!("{{{}}}", keys.join(", "))
95            }
96            ArgumentPrintMode::Full => {
97                let pairs: Vec<String> = self.0.iter().map(|(k, v)| format!("{k}: {v}")).collect();
98                format!("{{{}}}", pairs.join(", "))
99            }
100            ArgumentPrintMode::Truncated => {
101                let pairs: Vec<String> = self
102                    .0
103                    .iter()
104                    .map(|(k, v)| {
105                        if v.len() > truncate_length {
106                            // Safe UTF-8 truncation: find the last char boundary
107                            let end = v
108                                .char_indices()
109                                .nth(truncate_length)
110                                .map_or(v.len(), |(i, _)| i);
111                            format!("{k}: {}...", &v[..end])
112                        } else {
113                            format!("{k}: {v}")
114                        }
115                    })
116                    .collect();
117                format!("{{{}}}", pairs.join(", "))
118            }
119        }
120    }
121}
122
123/// A call represents a task with specific arguments, ready to be invoked.
124///
125/// This is the DTO form suitable for persistence and wire transfer.
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct CallDTO {
128    pub call_id: CallId,
129    pub task_id: TaskId,
130    pub serialized_arguments: SerializedArguments,
131}
132
133impl CallDTO {
134    pub fn new(task_id: TaskId, args: SerializedArguments) -> Self {
135        let args_id = args.compute_args_id();
136        let call_id = CallId::new(task_id.clone(), args_id);
137        Self {
138            call_id,
139            task_id,
140            serialized_arguments: args,
141        }
142    }
143}
144
145impl std::fmt::Display for CallDTO {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        let keys: Vec<&str> = self
148            .serialized_arguments
149            .0
150            .keys()
151            .map(std::string::String::as_str)
152            .collect();
153        write!(
154            f,
155            "Call(task={}, arguments=[{}])",
156            self.task_id,
157            keys.join(", ")
158        )
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an argument set from borrowed `(key, value)` pairs.
    fn args_of(pairs: &[(&str, &str)]) -> SerializedArguments {
        let mut args = SerializedArguments::new();
        for (k, v) in pairs {
            args.insert(*k, *v);
        }
        args
    }

    #[test]
    fn args_id_is_deterministic() {
        let mut args1 = SerializedArguments::new();
        args1.insert("x", "42");
        args1.insert("y", "hello");

        let mut args2 = SerializedArguments::new();
        args2.insert("y", "hello");
        args2.insert("x", "42");

        // BTreeMap ensures sorted order, so same args_id regardless of insert order
        assert_eq!(args1.compute_args_id(), args2.compute_args_id());
    }

    #[test]
    fn different_args_different_id() {
        let mut args1 = SerializedArguments::new();
        args1.insert("x", "42");

        let mut args2 = SerializedArguments::new();
        args2.insert("x", "43");

        assert_ne!(args1.compute_args_id(), args2.compute_args_id());
    }

    #[test]
    fn empty_args_id() {
        let args = SerializedArguments::new();
        let id = args.compute_args_id();
        // Pin the documented pynenc convention, not just "non-empty".
        assert_eq!(id, "no_args");
        // Empty args should be deterministic
        let args2 = SerializedArguments::default();
        assert_eq!(id, args2.compute_args_id());
    }

    #[test]
    fn non_empty_args_id_is_lowercase_sha256_hex() {
        let id = args_of(&[("a", "1")]).compute_args_id();
        assert_eq!(id.len(), 64);
        assert!(id
            .chars()
            .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
    }

    #[test]
    fn call_dto_new() {
        let task_id = TaskId::new("mod", "func");
        let mut args = SerializedArguments::new();
        args.insert("a", "1");
        let call = CallDTO::new(task_id.clone(), args.clone());

        assert_eq!(call.task_id, task_id);
        assert_eq!(call.call_id.task_id, task_id);
        assert_eq!(&*call.call_id.args_id, args.compute_args_id());
        assert_eq!(call.serialized_arguments, args);
    }

    #[test]
    fn serde_round_trip_call_dto() {
        let task_id = TaskId::new("mod", "func");
        let mut args = SerializedArguments::new();
        args.insert("key", "val");
        let call = CallDTO::new(task_id, args);

        let json = serde_json::to_string(&call).unwrap();
        let back: CallDTO = serde_json::from_str(&json).unwrap();
        assert_eq!(back.call_id, call.call_id);
        assert_eq!(back.task_id, call.task_id);
        assert_eq!(back.serialized_arguments, call.serialized_arguments);
    }

    #[test]
    fn args_id_no_delimiter_collision() {
        // {"a": "b;c=d"} must differ from {"a": "b", "c": "d"}
        let mut args1 = SerializedArguments::new();
        args1.insert("a", "b;c=d");

        let mut args2 = SerializedArguments::new();
        args2.insert("a", "b");
        args2.insert("c", "d");

        assert_ne!(args1.compute_args_id(), args2.compute_args_id());
    }

    #[test]
    fn cc_arg_pairs_uses_sentinel_for_empty_args() {
        assert_eq!(
            SerializedArguments::new().cc_arg_pairs(),
            vec![(String::new(), String::new())]
        );
        assert_eq!(
            args_of(&[("b", "2"), ("a", "1")]).cc_arg_pairs(),
            vec![
                ("a".to_string(), "1".to_string()),
                ("b".to_string(), "2".to_string()),
            ]
        );
    }

    #[test]
    fn cc_key_is_empty_for_missing_or_empty_args() {
        assert_eq!(SerializedArguments::cc_key(None), "");
        assert_eq!(
            SerializedArguments::cc_key(Some(&SerializedArguments::new())),
            ""
        );
        let args = args_of(&[("a", "1")]);
        assert_eq!(
            SerializedArguments::cc_key(Some(&args)),
            args.compute_args_id()
        );
    }

    #[test]
    fn display_modes() {
        let args = args_of(&[("b", "2"), ("a", "1")]);
        assert_eq!(args.display(ArgumentPrintMode::Hidden, 0), "<arguments hidden>");
        assert_eq!(args.display(ArgumentPrintMode::Keys, 0), "{a, b}");
        assert_eq!(args.display(ArgumentPrintMode::Full, 0), "{a: 1, b: 2}");
        assert_eq!(
            SerializedArguments::new().display(ArgumentPrintMode::Full, 0),
            "<no_args>"
        );
    }

    #[test]
    fn truncated_display_ascii() {
        let args = args_of(&[("s", "abcdef"), ("t", "ab")]);
        assert_eq!(
            args.display(ArgumentPrintMode::Truncated, 3),
            "{s: abc..., t: ab}"
        );
    }

    #[test]
    fn truncated_display_safe_on_multibyte_utf8() {
        let mut args = SerializedArguments::new();
        // Each Japanese char is 3 bytes in UTF-8
        args.insert("x", "日本語テスト");
        // Truncate at 2 chars — should not panic
        let result = args.display(ArgumentPrintMode::Truncated, 2);
        assert!(result.contains("日本..."));
    }

    #[test]
    fn call_dto_display_lists_sorted_keys_only() {
        let call = CallDTO::new(
            TaskId::new("mod", "func"),
            args_of(&[("y", "secret"), ("x", "1")]),
        );
        let rendered = call.to_string();
        assert!(rendered.starts_with("Call(task="));
        assert!(rendered.ends_with(", arguments=[x, y])"));
        assert!(!rendered.contains("secret"));
    }
}