Skip to main content

telltale_lean_bridge/
import.rs

1//! Import Lean JSON into Rust types.
2//!
3//! This module provides functions to parse JSON (from Lean output)
4//! back into GlobalType and LocalTypeR.
5
6use serde_json::Value;
7use telltale_types::{GlobalType, Label, LocalTypeR, PayloadSort, ValType};
8use thiserror::Error;
9
/// Errors that can occur during JSON import.
///
/// Every variant carries a `String` payload that is interpolated into the
/// `Display` message generated by `thiserror`.
#[derive(Debug, Error)]
pub enum ImportError {
    /// A required key was not found. The payload names the key, sometimes with
    /// extra context (e.g. `"label in branch"`).
    #[error("Missing field: {0}")]
    MissingField(String),

    /// The `kind` discriminator held a string that is not a recognized kind.
    /// The payload is the offending string.
    #[error("Invalid kind: {0}")]
    InvalidKind(String),

    /// A payload sort could not be interpreted. The payload is either the
    /// offending value or a short explanation.
    #[error("Invalid sort: {0}")]
    InvalidSort(String),

    /// A field that must be a JSON string held some other JSON value.
    /// The payload is the JSON rendering of that value.
    #[error("Expected string, got: {0}")]
    ExpectedString(String),

    /// A field that must be a JSON array held some other JSON value.
    /// The payload is the JSON rendering of that value.
    #[error("Expected array, got: {0}")]
    ExpectedArray(String),

    /// Returned by the `parse_*_from_str` helpers in this module when the
    /// input text cannot be parsed as JSON.
    #[error("Expected object, got: {0}")]
    ExpectedObject(String),
}
31
32fn required_field<'a>(json: &'a Value, field: &str) -> Result<&'a Value, ImportError> {
33    json.get(field)
34        .ok_or_else(|| ImportError::MissingField(field.to_string()))
35}
36
37fn required_str(json: &Value, field: &str) -> Result<String, ImportError> {
38    let value = required_field(json, field)?;
39    value
40        .as_str()
41        .map(ToString::to_string)
42        .ok_or_else(|| ImportError::ExpectedString(value.to_string()))
43}
44
45fn required_array<'a>(json: &'a Value, field: &str) -> Result<&'a [Value], ImportError> {
46    let value = required_field(json, field)?;
47    value
48        .as_array()
49        .map(Vec::as_slice)
50        .ok_or_else(|| ImportError::ExpectedArray(value.to_string()))
51}
52
53fn parse_global_comm(json: &Value) -> Result<GlobalType, ImportError> {
54    let sender = required_str(json, "sender")?;
55    let receiver = required_str(json, "receiver")?;
56    let branches = parse_global_branches(required_array(json, "branches")?)?;
57    Ok(GlobalType::Comm {
58        sender,
59        receiver,
60        branches,
61    })
62}
63
64fn parse_global_branches(branches: &[Value]) -> Result<Vec<(Label, GlobalType)>, ImportError> {
65    let mut parsed = Vec::with_capacity(branches.len());
66    for branch in branches {
67        let label = parse_label(
68            branch
69                .get("label")
70                .ok_or_else(|| ImportError::MissingField("label in branch".to_string()))?,
71        )?;
72        let cont = json_to_global(
73            branch
74                .get("continuation")
75                .ok_or_else(|| ImportError::MissingField("continuation in branch".to_string()))?,
76        )?;
77        parsed.push((label, cont));
78    }
79    Ok(parsed)
80}
81
82fn parse_global_rec(json: &Value) -> Result<GlobalType, ImportError> {
83    let var = required_str(json, "var")?;
84    let body = json_to_global(required_field(json, "body")?)?;
85    Ok(GlobalType::mu(var, body))
86}
87
88fn parse_global_var(json: &Value) -> Result<GlobalType, ImportError> {
89    Ok(GlobalType::var(required_str(json, "name")?))
90}
91
92/// Parse JSON into a GlobalType.
93///
94/// # Example
95///
96/// ```
97/// use telltale_lean_bridge::import::json_to_global;
98/// use serde_json::json;
99///
100/// let json = json!({ "kind": "end" });
101/// let g = json_to_global(&json).unwrap();
102/// assert!(matches!(g, telltale_types::GlobalType::End));
103/// ```
104pub fn json_to_global(json: &Value) -> Result<GlobalType, ImportError> {
105    let kind = json
106        .get("kind")
107        .and_then(|v| v.as_str())
108        .ok_or_else(|| ImportError::MissingField("kind".to_string()))?;
109
110    match kind {
111        "end" => Ok(GlobalType::End),
112        "comm" => parse_global_comm(json),
113        "rec" => parse_global_rec(json),
114        "var" => parse_global_var(json),
115        other => Err(ImportError::InvalidKind(other.to_string())),
116    }
117}
118
119fn parse_local_branches(
120    branches: &[Value],
121) -> Result<Vec<(Label, Option<ValType>, LocalTypeR)>, ImportError> {
122    let mut parsed = Vec::with_capacity(branches.len());
123    for branch in branches {
124        let label = parse_label(
125            branch
126                .get("label")
127                .ok_or_else(|| ImportError::MissingField("label in branch".to_string()))?,
128        )?;
129        let cont = json_to_local(
130            branch
131                .get("continuation")
132                .ok_or_else(|| ImportError::MissingField("continuation in branch".to_string()))?,
133        )?;
134        parsed.push((label, None, cont));
135    }
136    Ok(parsed)
137}
138
139fn parse_local_send(json: &Value) -> Result<LocalTypeR, ImportError> {
140    let partner = required_str(json, "partner")?;
141    let branches = parse_local_branches(required_array(json, "branches")?)?;
142    Ok(LocalTypeR::Send { partner, branches })
143}
144
145fn parse_local_recv(json: &Value) -> Result<LocalTypeR, ImportError> {
146    let partner = required_str(json, "partner")?;
147    let branches = parse_local_branches(required_array(json, "branches")?)?;
148    Ok(LocalTypeR::Recv { partner, branches })
149}
150
151fn parse_local_rec(json: &Value) -> Result<LocalTypeR, ImportError> {
152    let var = required_str(json, "var")?;
153    let body = json_to_local(required_field(json, "body")?)?;
154    Ok(LocalTypeR::mu(var, body))
155}
156
157fn parse_local_var(json: &Value) -> Result<LocalTypeR, ImportError> {
158    Ok(LocalTypeR::var(required_str(json, "name")?))
159}
160
161/// Parse JSON into a LocalTypeR.
162///
163/// # Example
164///
165/// ```
166/// use telltale_lean_bridge::import::json_to_local;
167/// use serde_json::json;
168///
169/// let json = json!({ "kind": "end" });
170/// let lt = json_to_local(&json).unwrap();
171/// assert!(matches!(lt, telltale_types::LocalTypeR::End));
172/// ```
173pub fn json_to_local(json: &Value) -> Result<LocalTypeR, ImportError> {
174    let kind = json
175        .get("kind")
176        .and_then(|v| v.as_str())
177        .ok_or_else(|| ImportError::MissingField("kind".to_string()))?;
178
179    match kind {
180        "end" => Ok(LocalTypeR::End),
181        "send" => parse_local_send(json),
182        "recv" => parse_local_recv(json),
183        "rec" => parse_local_rec(json),
184        "var" => parse_local_var(json),
185        other => Err(ImportError::InvalidKind(other.to_string())),
186    }
187}
188
189/// Parse a Label from JSON.
190fn parse_label(json: &Value) -> Result<Label, ImportError> {
191    let name = json
192        .get("name")
193        .and_then(|v| v.as_str())
194        .ok_or_else(|| ImportError::MissingField("name in label".to_string()))?
195        .to_string();
196
197    let sort = json
198        .get("sort")
199        .map(parse_sort)
200        .transpose()?
201        .unwrap_or(PayloadSort::Unit);
202
203    Ok(Label { name, sort })
204}
205
206/// Parse a PayloadSort from JSON.
207fn parse_sort(json: &Value) -> Result<PayloadSort, ImportError> {
208    if let Some(s) = json.as_str() {
209        match s {
210            "unit" => Ok(PayloadSort::Unit),
211            "nat" => Ok(PayloadSort::Nat),
212            "bool" => Ok(PayloadSort::Bool),
213            "string" => Ok(PayloadSort::String),
214            "real" => Ok(PayloadSort::Real),
215            other => Err(ImportError::InvalidSort(other.to_string())),
216        }
217    } else if let Some(obj) = json.as_object() {
218        if let Some(arr) = obj.get("prod").and_then(|v| v.as_array()) {
219            if arr.len() == 2 {
220                let left = parse_sort(&arr[0])?;
221                let right = parse_sort(&arr[1])?;
222                Ok(PayloadSort::Prod(Box::new(left), Box::new(right)))
223            } else {
224                Err(ImportError::InvalidSort(
225                    "prod must have 2 elements".to_string(),
226                ))
227            }
228        } else if let Some(n) = obj.get("vector").and_then(|v| v.as_u64()) {
229            let size = usize::try_from(n)
230                .map_err(|_| ImportError::InvalidSort("vector size exceeds usize".to_string()))?;
231            Ok(PayloadSort::Vector(size))
232        } else {
233            Err(ImportError::InvalidSort(format!("{:?}", obj)))
234        }
235    } else {
236        Err(ImportError::InvalidSort(format!("{}", json)))
237    }
238}
239
240/// Parse a GlobalType from a JSON string.
241pub fn parse_global_from_str(s: &str) -> Result<GlobalType, ImportError> {
242    let json: Value = serde_json::from_str(s)
243        .map_err(|_| ImportError::ExpectedObject("invalid JSON".to_string()))?;
244    json_to_global(&json)
245}
246
247/// Parse a LocalTypeR from a JSON string.
248pub fn parse_local_from_str(s: &str) -> Result<LocalTypeR, ImportError> {
249    let json: Value = serde_json::from_str(s)
250        .map_err(|_| ImportError::ExpectedObject("invalid JSON".to_string()))?;
251    json_to_local(&json)
252}
253
// Unit tests for the JSON import functions: one test per supported kind/sort,
// a roundtrip through the exporter, and checks on the reported error variants.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `{"kind": "end"}` parses to `GlobalType::End`.
    #[test]
    fn test_parse_global_end() {
        let json = json!({ "kind": "end" });
        let g = json_to_global(&json).unwrap();
        assert!(matches!(g, GlobalType::End));
    }

    /// A `comm` object yields a `Comm` with the given sender, receiver and
    /// one branch.
    #[test]
    fn test_parse_global_comm() {
        let json = json!({
            "kind": "comm",
            "sender": "A",
            "receiver": "B",
            "branches": [{
                "label": { "name": "msg", "sort": "unit" },
                "continuation": { "kind": "end" }
            }]
        });
        let g = json_to_global(&json).unwrap();

        match g {
            GlobalType::Comm {
                sender,
                receiver,
                branches,
            } => {
                assert_eq!(sender, "A");
                assert_eq!(receiver, "B");
                assert_eq!(branches.len(), 1);
            }
            _ => panic!("Expected Comm"),
        }
    }

    /// A `rec` object yields a `Mu` binding the given variable name.
    #[test]
    fn test_parse_global_rec() {
        let json = json!({
            "kind": "rec",
            "var": "X",
            "body": {
                "kind": "var",
                "name": "X"
            }
        });
        let g = json_to_global(&json).unwrap();

        match g {
            GlobalType::Mu { var, .. } => assert_eq!(var, "X"),
            _ => panic!("Expected Mu"),
        }
    }

    /// `{"kind": "end"}` parses to `LocalTypeR::End`.
    #[test]
    fn test_parse_local_end() {
        let json = json!({ "kind": "end" });
        let lt = json_to_local(&json).unwrap();
        assert!(matches!(lt, LocalTypeR::End));
    }

    /// A `send` object yields a `Send` whose branch label carries the
    /// `string` sort.
    #[test]
    fn test_parse_local_send() {
        let json = json!({
            "kind": "send",
            "partner": "B",
            "branches": [{
                "label": { "name": "hello", "sort": "string" },
                "continuation": { "kind": "end" }
            }]
        });
        let lt = json_to_local(&json).unwrap();

        match lt {
            LocalTypeR::Send { partner, branches } => {
                assert_eq!(partner, "B");
                assert_eq!(branches.len(), 1);
                assert_eq!(branches[0].0.sort, PayloadSort::String);
            }
            _ => panic!("Expected Send"),
        }
    }

    /// The object form `{"prod": [a, b]}` parses to `PayloadSort::Prod`.
    #[test]
    fn test_parse_prod_sort() {
        let json = json!({
            "kind": "send",
            "partner": "B",
            "branches": [{
                "label": {
                    "name": "pair",
                    "sort": { "prod": ["nat", "bool"] }
                },
                "continuation": { "kind": "end" }
            }]
        });
        let lt = json_to_local(&json).unwrap();

        match lt {
            LocalTypeR::Send { branches, .. } => {
                let sort = &branches[0].0.sort;
                assert!(matches!(sort, PayloadSort::Prod(..)));
            }
            _ => panic!("Expected Send"),
        }
    }

    /// The string `"real"` parses to `PayloadSort::Real`.
    #[test]
    fn test_parse_real_sort() {
        let json = json!({
            "kind": "send",
            "partner": "B",
            "branches": [{
                "label": { "name": "value", "sort": "real" },
                "continuation": { "kind": "end" }
            }]
        });
        let lt = json_to_local(&json).unwrap();
        match lt {
            LocalTypeR::Send { branches, .. } => {
                assert_eq!(branches[0].0.sort, PayloadSort::Real);
            }
            _ => panic!("Expected Send"),
        }
    }

    /// The object form `{"vector": 4}` parses to `PayloadSort::Vector(4)`.
    #[test]
    fn test_parse_vector_sort() {
        let json = json!({
            "kind": "send",
            "partner": "B",
            "branches": [{
                "label": { "name": "config", "sort": { "vector": 4 } },
                "continuation": { "kind": "end" }
            }]
        });
        let lt = json_to_local(&json).unwrap();
        match lt {
            LocalTypeR::Send { branches, .. } => {
                assert_eq!(branches[0].0.sort, PayloadSort::Vector(4));
            }
            _ => panic!("Expected Send"),
        }
    }

    /// Exporting a global type and importing the result gives back a `Comm`
    /// with the same sender, receiver and number of branches.
    #[test]
    fn test_roundtrip() {
        use crate::export::global_to_json;

        let original = GlobalType::comm(
            "Client",
            "Server",
            vec![
                (Label::new("request"), GlobalType::End),
                (Label::with_sort("data", PayloadSort::Nat), GlobalType::End),
            ],
        );

        let json = global_to_json(&original);
        let parsed = json_to_global(&json).unwrap();

        // Compare structure
        // (only sender, receiver and branch count are checked here, not the
        // individual labels or continuations)
        match (&original, &parsed) {
            (
                GlobalType::Comm {
                    sender: s1,
                    receiver: r1,
                    branches: b1,
                },
                GlobalType::Comm {
                    sender: s2,
                    receiver: r2,
                    branches: b2,
                },
            ) => {
                assert_eq!(s1, s2);
                assert_eq!(r1, r2);
                assert_eq!(b1.len(), b2.len());
            }
            _ => panic!("Structure mismatch"),
        }
    }

    /// A numeric `sender` is reported as `ExpectedString`, not `MissingField`.
    #[test]
    fn test_parse_global_reports_expected_string_for_sender() {
        let json = json!({
            "kind": "comm",
            "sender": 1,
            "receiver": "B",
            "branches": []
        });

        let err = json_to_global(&json).expect_err("sender should require string");
        assert!(matches!(err, ImportError::ExpectedString(_)));
    }

    /// An object-valued `branches` is reported as `ExpectedArray`.
    #[test]
    fn test_parse_global_reports_expected_array_for_branches() {
        let json = json!({
            "kind": "comm",
            "sender": "A",
            "receiver": "B",
            "branches": { "bad": true }
        });

        let err = json_to_global(&json).expect_err("branches should require array");
        assert!(matches!(err, ImportError::ExpectedArray(_)));
    }

    /// Only compiled on 32-bit targets, where 4294967296 (2^32) does not fit
    /// in `usize` and the vector size conversion must fail.
    #[cfg(target_pointer_width = "32")]
    #[test]
    fn test_parse_vector_sort_rejects_overflow() {
        let json = json!({
            "kind": "send",
            "partner": "B",
            "branches": [{
                "label": { "name": "config", "sort": { "vector": 4294967296u64 } },
                "continuation": { "kind": "end" }
            }]
        });

        let err = json_to_local(&json).expect_err("vector size should overflow on 32-bit");
        assert!(matches!(err, ImportError::InvalidSort(_)));
    }
}