Skip to main content

tlauc/
lib.rs

1mod strmeasure;
2use crate::strmeasure::*;
3
4use serde::{Deserialize, Deserializer};
5use std::ops::Range;
6use streaming_iterator::StreamingIterator;
7use tree_sitter::{Node, Parser, Query, QueryCursor, Tree, TreeCursor};
8
/// Direction of translation performed by [`rewrite`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Replace ASCII symbols (e.g. `/\`) with their Unicode counterparts (e.g. `∧`).
    AsciiToUnicode,
    /// Replace Unicode symbols with their canonical ASCII spelling.
    UnicodeToAscii,
}
13
14#[derive(Debug)]
15pub enum TlaError {
16    InputFileParseError {
17        parse_tree: Tree,
18        error_lines: Vec<usize>,
19    },
20    OutputFileParseError {
21        output_tree: Tree,
22        output: String,
23    },
24    InvalidTranslationError {
25        input_tree: Tree,
26        output_tree: Tree,
27        output: String,
28        first_diff: String,
29    },
30}
31
32pub fn rewrite(
33    input: &str,
34    mode: &Mode,
35    force: bool,
36    should_translate: impl Fn(&SymbolMapping) -> bool,
37) -> Result<String, TlaError> {
38    let mut parser = Parser::new();
39    parser
40        .set_language(&tree_sitter_tlaplus::LANGUAGE.into())
41        .expect("Error loading TLA⁺ grammar");
42    let mut cursor = QueryCursor::new();
43
44    // Parse input TLA⁺ file and construct data structures to hold information about it
45    let input_tree = parser.parse(input, None).unwrap();
46    if !force && input_tree.root_node().has_error() {
47        let error_lines = find_error_lines(&input_tree);
48        return Err(TlaError::InputFileParseError {
49            parse_tree: input_tree,
50            error_lines,
51        });
52    }
53
54    let mut tla_lines = TlaLine::construct_from(input);
55
56    // Identify & replace symbols
57    mark_jlists(&input_tree, &mut cursor, &mut tla_lines);
58    mark_symbols(
59        &input_tree,
60        &mut cursor,
61        &mut tla_lines,
62        mode,
63        should_translate,
64    );
65    //println!("{:#?}", tla_lines);
66    replace_symbols(&mut tla_lines);
67
68    // if the input ends with '\n', we should put the '\n' back to output
69    let extra_newline = input
70        .chars()
71        .last()
72        .map_or("", |x| if x == '\n' { "\n" } else { "" });
73
74    // Ensure output parse tree is identical to input parse tree
75    let output = TlaLine::output_from_lines(&tla_lines, &extra_newline);
76
77    let output_tree = parser.parse(&output, None).unwrap();
78    if !force {
79        if output_tree.root_node().has_error() {
80            return Err(TlaError::OutputFileParseError {
81                output_tree,
82                output,
83            });
84        }
85        if let Err(first_diff) = compare_parse_trees(&input_tree, &output_tree) {
86            return Err(TlaError::InvalidTranslationError {
87                input_tree,
88                output_tree,
89                output,
90                first_diff,
91            });
92        }
93    }
94
95    Ok(output)
96}
97
98fn find_error_lines(tree: &Tree) -> Vec<usize> {
99    let mut error_lines: Vec<usize> = vec![];
100    traverse_parse_tree(tree, |n| {
101        if n.is_error() || n.is_missing() {
102            error_lines.push(n.start_position().row + 1);
103        }
104    });
105    error_lines
106}
107
108fn traverse_parse_tree<F>(tree: &Tree, mut visit: F)
109where
110    F: FnMut(Node),
111{
112    let mut cursor: TreeCursor = tree.walk();
113    loop {
114        // Every time a new node is found the control flow passes here
115        visit(cursor.node());
116        // Descend as far as possible
117        if !cursor.goto_first_child() {
118            loop {
119                // Attempt to go to sibling
120                if cursor.goto_next_sibling() {
121                    // If sibling exists, break out into descent loop
122                    break;
123                } else {
124                    // If sibling does not exist, go to parent, then
125                    // parent's sibling in next loop iteration
126                    if !cursor.goto_parent() {
127                        // If parent does not exist, we are done
128                        return;
129                    }
130                }
131            }
132        }
133    }
134}
135
136fn compare_parse_trees(input_tree: &Tree, output_tree: &Tree) -> Result<(), String> {
137    let mut input_cursor: TreeCursor = input_tree.walk();
138    let mut output_cursor: TreeCursor = output_tree.walk();
139
140    loop {
141        check_node_equality(&input_cursor, &output_cursor)?;
142        if !simultaneous_step(&mut input_cursor, &mut output_cursor, |c| {
143            c.goto_first_child()
144        })? {
145            loop {
146                if !simultaneous_step(&mut input_cursor, &mut output_cursor, |c| {
147                    c.goto_next_sibling()
148                })? {
149                    if !simultaneous_step(&mut input_cursor, &mut output_cursor, |c| {
150                        c.goto_parent()
151                    })? {
152                        return Ok(());
153                    }
154                } else {
155                    break;
156                }
157            }
158        }
159    }
160}
161
162fn simultaneous_step(
163    input_cursor: &mut TreeCursor,
164    output_cursor: &mut TreeCursor,
165    step: fn(&mut TreeCursor) -> bool,
166) -> Result<bool, String> {
167    let (input_next, output_next) = (step(input_cursor), step(output_cursor));
168    if input_next != output_next {
169        return Err(format!(
170            "First diff: Input {:?} Output {:?}",
171            input_cursor.node(),
172            output_cursor.node()
173        ));
174    }
175
176    Ok(input_next)
177}
178
179fn check_node_equality(
180    input_cursor: &TreeCursor,
181    output_cursor: &TreeCursor,
182) -> Result<(), String> {
183    if (input_cursor.node().is_named() || output_cursor.node().is_named())
184        && input_cursor.node().kind() != output_cursor.node().kind()
185    {
186        return Err(format!(
187            "First diff: Input {:?} Output {:?}",
188            input_cursor.node(),
189            output_cursor.node()
190        ));
191    }
192
193    Ok(())
194}
195
196#[derive(Debug, Deserialize)]
197pub struct SymbolMapping {
198    #[serde(rename = "Name")]
199    pub name: String,
200    #[serde(
201        rename = "ASCII",
202        deserialize_with = "vec_from_semicolon_separated_str"
203    )]
204    pub ascii: Vec<String>,
205    #[serde(rename = "Unicode")]
206    pub unicode: String,
207}
208
209impl SymbolMapping {
210    pub fn canonical_ascii(&self) -> &str {
211        self.ascii.first().unwrap()
212    }
213
214    pub fn ascii_query(&self) -> String {
215        let query = self
216            .ascii
217            .iter()
218            .map(|a| a.replace('\\', "\\\\"))
219            .map(|a| format!("\"{}\"", a))
220            .reduce(|a, b| a + " " + &b)
221            .unwrap();
222        let name = &self.name;
223        format!("({name} [{query}] @{name})")
224    }
225
226    pub fn unicode_query(&self) -> String {
227        let name = &self.name;
228        let unicode = &self.unicode;
229        format!("({name} \"{unicode}\" @{name})")
230    }
231
232    fn target_symbol(&self, mode: &Mode) -> &str {
233        match mode {
234            Mode::AsciiToUnicode => &self.unicode,
235            Mode::UnicodeToAscii => self.canonical_ascii(),
236        }
237    }
238
239    fn source_query(&self, mode: &Mode) -> String {
240        match mode {
241            Mode::AsciiToUnicode => self.ascii_query(),
242            Mode::UnicodeToAscii => self.unicode_query(),
243        }
244    }
245
246    fn chars_added(&self, mode: &Mode, src_symbol: &str) -> CharDiff {
247        match mode {
248            Mode::AsciiToUnicode => {
249                CharQuantity(self.unicode.chars().count())
250                    - CharQuantity(src_symbol.chars().count())
251            }
252            Mode::UnicodeToAscii => {
253                CharQuantity(self.canonical_ascii().chars().count())
254                    - CharQuantity(self.unicode.chars().count())
255            }
256        }
257    }
258}
259
260fn vec_from_semicolon_separated_str<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
261where
262    D: Deserializer<'de>,
263{
264    let s: &str = Deserialize::deserialize(deserializer)?;
265    Ok(s.split(';').map(|s| s.to_string()).collect())
266}
267
268pub fn get_unicode_mappings() -> Vec<SymbolMapping> {
269    let csv = include_str!("../resources/tla-unicode.csv");
270    let mut reader = csv::Reader::from_reader(csv.as_bytes());
271    reader.deserialize().map(|result| result.unwrap()).collect()
272}
273
274#[derive(Debug)]
275struct TlaLine {
276    text: String,
277    jlists: Vec<JList>,
278    symbols: Vec<Symbol>,
279}
280
281impl TlaLine {
282    fn construct_from(input: &str) -> Vec<Self> {
283        input
284            .lines()
285            .map(|line| TlaLine {
286                jlists: Vec::new(),
287                symbols: Vec::new(),
288                text: line.to_string(),
289            })
290            .collect()
291    }
292
293    // same as join("\n") + extra,
294    // but to avoid unnecessary the reallocation,
295    // ref: https://doc.rust-lang.org/src/alloc/slice.rs.html#787
296    fn output_from_lines(tla_lines: &Vec<Self>, extra: &str) -> String {
297        let mut iter = tla_lines.iter();
298        let first = match iter.next() {
299            Some(first) => first,
300            None => return extra.to_string(),
301        };
302        let text_size = tla_lines.iter().map(|v| v.text.len()).sum::<usize>();
303        // Note: tla_lines.len() > 0 is always true
304        let size = text_size + tla_lines.len() - 1 + extra.len();
305        let mut result = String::with_capacity(size);
306        result.push_str(&first.text);
307        for v in iter {
308            result.push('\n');
309            result.push_str(&v.text);
310        }
311        result.push_str(extra);
312        result
313    }
314
315    fn shift_jlists(&mut self, &diff: &CharDiff, &start_index: &CharQuantity) {
316        for jlist in &mut self.jlists {
317            if jlist.column > start_index {
318                jlist.column = jlist.column + diff;
319            }
320        }
321    }
322
323    fn shift_symbols(&mut self, diff: &StrElementDiff, start_index: &StrElementQuantity) {
324        for symbol in &mut self.symbols {
325            if symbol.src_range.start.byte >= start_index.byte {
326                symbol.src_range.start.byte = symbol.src_range.start.byte + diff.byte;
327                symbol.src_range.end.byte = symbol.src_range.end.byte + diff.byte;
328            }
329            if symbol.src_range.start.char >= start_index.char {
330                symbol.src_range.start.char = symbol.src_range.start.char + diff.char;
331                symbol.src_range.end.char = symbol.src_range.end.char + diff.char;
332            }
333        }
334    }
335}
336
337#[derive(Debug)]
338struct JList {
339    column: CharQuantity,
340    bullet_line_offsets: Vec<usize>,
341    terminating_infix_op_offset: Option<InfixOp>,
342}
343
344#[derive(Debug)]
345struct InfixOp {
346    line_offset: usize,
347    column: CharQuantity,
348}
349
350impl JList {
351    fn query() -> Query {
352        Query::new(
353            &tree_sitter_tlaplus::LANGUAGE.into(),
354            "[(conj_list) (disj_list)] @jlist",
355        )
356        .unwrap()
357    }
358
359    fn terminating_infix_op_query() -> Query {
360        Query::new(
361            &tree_sitter_tlaplus::LANGUAGE.into(),
362            "(bound_infix_op lhs: [(conj_list) (disj_list)]) @capture",
363        )
364        .unwrap()
365    }
366
367    fn is_jlist_item_node(cursor: &TreeCursor) -> bool {
368        "conj_item" == cursor.node().kind() || "disj_item" == cursor.node().kind()
369    }
370}
371
372fn mark_jlists(tree: &Tree, query_cursor: &mut QueryCursor, tla_lines: &mut [TlaLine]) {
373    let mut tree_cursor: TreeCursor = tree.walk();
374    let query = JList::query();
375    let mut captures = query_cursor.matches(&query, tree.root_node(), "".as_bytes());
376    while let Some(capture) = captures.next() {
377        let node = capture.captures[0].node;
378        let start_line = node.start_position().row;
379        let line = &mut tla_lines[start_line];
380        let column =
381            CharQuantity::from_byte_index(&ByteQuantity(node.start_position().column), &line.text);
382        let mut jlist = JList {
383            column,
384            bullet_line_offsets: Vec::new(),
385            terminating_infix_op_offset: None,
386        };
387        tree_cursor.reset(node);
388        tree_cursor.goto_first_child();
389        while {
390            if JList::is_jlist_item_node(&tree_cursor) {
391                jlist
392                    .bullet_line_offsets
393                    .push(tree_cursor.node().start_position().row - start_line);
394            }
395
396            tree_cursor.goto_next_sibling()
397        } {}
398
399        line.jlists.push(jlist);
400    }
401
402    let query = JList::terminating_infix_op_query();
403    let mut captures = query_cursor.matches(&query, tree.root_node(), "".as_bytes());
404    while let Some(capture) = captures.next() {
405        let infix_op_node = capture.captures[0].node;
406        let jlist_node = infix_op_node.child_by_field_name("lhs").unwrap();
407        let jlist_start_line_index = jlist_node.start_position().row;
408        let (prefix, suffix) = tla_lines.split_at_mut(jlist_start_line_index + 1);
409        let jlist_start_line = &mut prefix[jlist_start_line_index];
410        let jlist_column = CharQuantity::from_byte_index(
411            &ByteQuantity(jlist_node.start_position().column),
412            &jlist_start_line.text,
413        );
414        let jlist = jlist_start_line
415            .jlists
416            .iter_mut()
417            .find(|j| j.column == jlist_column)
418            .unwrap();
419        let symbol_node = infix_op_node.child_by_field_name("symbol").unwrap();
420        let symbol_line_offset = symbol_node.start_position().row - jlist_start_line_index;
421        let symbol_line = &suffix[symbol_line_offset - 1];
422        let symbol_column = ByteQuantity(symbol_node.start_position().column);
423        jlist.terminating_infix_op_offset = Some(InfixOp {
424            line_offset: symbol_line_offset,
425            column: CharQuantity::from_byte_index(&symbol_column, &symbol_line.text),
426        });
427    }
428}
429
430#[derive(Debug)]
431struct Symbol {
432    diff: CharDiff,
433    src_range: Range<StrElementQuantity>,
434    target: String,
435}
436
437fn mark_symbols(
438    tree: &Tree,
439    cursor: &mut QueryCursor,
440    tla_lines: &mut [TlaLine],
441    mode: &Mode,
442    should_translate: impl Fn(&SymbolMapping) -> bool,
443) {
444    let mappings: Vec<SymbolMapping> = get_unicode_mappings()
445        .into_iter()
446        .filter(should_translate)
447        .collect();
448    let queries = &mappings
449        .iter()
450        .map(|s| s.source_query(mode))
451        .collect::<Vec<String>>()
452        .join("");
453    let query = Query::new(&tree_sitter_tlaplus::LANGUAGE.into(), queries).unwrap();
454
455    let mut captures = cursor.matches(&query, tree.root_node(), "".as_bytes());
456    while let Some(capture) = captures.next() {
457        let capture = capture.captures[0];
458        let mapping = &mappings[capture.index as usize];
459        let start_position = capture.node.start_position();
460        let end_position = capture.node.end_position();
461        assert!(start_position.row == end_position.row);
462        let line = &mut tla_lines[start_position.row];
463        let src_range =
464            StrElementQuantity::from_byte_index(&ByteQuantity(start_position.column), &line.text)
465                ..StrElementQuantity::from_byte_index(
466                    &ByteQuantity(end_position.column),
467                    &line.text,
468                );
469        let src_symbol = &line.text[StrElementQuantity::as_byte_range(&src_range)];
470        let target = mapping.target_symbol(mode).to_string();
471        line.symbols.push(Symbol {
472            diff: mapping.chars_added(mode, src_symbol),
473            src_range,
474            target,
475        });
476    }
477}
478
479fn replace_symbols(tla_lines: &mut [TlaLine]) {
480    for line_number in 0..tla_lines.len().saturating_add_signed(-1) {
481        let (prefix, suffix) = tla_lines.split_at_mut(line_number + 1);
482        let line = &mut prefix[line_number];
483        while let Some(symbol) = line.symbols.pop() {
484            line.text.replace_range(
485                StrElementQuantity::as_byte_range(&symbol.src_range),
486                &symbol.target,
487            );
488            line.shift_jlists(&symbol.diff, &symbol.src_range.start.char);
489            fix_alignment(line, suffix, &symbol.diff, &symbol.src_range.start);
490        }
491    }
492}
493
494fn fix_alignment(
495    line: &mut TlaLine,
496    suffix: &mut [TlaLine],
497    &diff: &CharDiff,
498    symbol_start_index: &StrElementQuantity,
499) {
500    // If there was no net change in character count, there is no need to fix alignment
501    if diff == CharDiff(0) {
502        return;
503    }
504
505    // Recursively fix alignment of all jlist bullets
506    for jlist in &mut line.jlists {
507        // Ignore jlists starting before the index of modification in this line
508        if jlist.column <= symbol_start_index.char {
509            continue;
510        }
511
512        // Add or remove spaces from the start of the line for each bullet in this jlist
513        let mod_index = StrElementQuantity {
514            char: CharQuantity(0),
515            byte: ByteQuantity(0),
516        };
517        for &line_offset in &jlist.bullet_line_offsets {
518            // Alignment of first element of jlist was already changed by original modification
519            if line_offset == 0 {
520                continue;
521            }
522
523            let (suffix_prefix, suffix_suffix) = suffix.split_at_mut(line_offset);
524            let bullet_line = &mut suffix_prefix[line_offset - 1];
525            let bullet_column = jlist.column - diff;
526            pad(bullet_line, &diff, &mod_index, &bullet_column);
527
528            // Recursively fix alignment of any jlists starting on this line
529            fix_alignment(bullet_line, suffix_suffix, &diff, &mod_index);
530        }
531
532        // Fix alignment of terminating infix op for this jlist, if it exists
533        if let Some(infix_op_offset) = &mut jlist.terminating_infix_op_offset {
534            let (suffix_prefix, suffix_suffix) = suffix.split_at_mut(infix_op_offset.line_offset);
535            let infix_op_line = &mut suffix_prefix[infix_op_offset.line_offset - 1];
536            let diff = pad(infix_op_line, &diff, &mod_index, &infix_op_offset.column);
537            infix_op_offset.column = infix_op_offset.column + diff;
538            fix_alignment(infix_op_line, suffix_suffix, &diff, &mod_index);
539        }
540    }
541}
542
543fn pad(
544    line: &mut TlaLine,
545    &diff: &CharDiff,
546    mod_index: &StrElementQuantity,
547    &first_symbol_index: &CharQuantity,
548) -> CharDiff {
549    if diff < CharDiff(0) {
550        // Calculate min to ensure we don't move a symbol to before the end of the line
551        let spaces_to_remove = CharQuantity::min(diff.magnitude(), first_symbol_index);
552        let bytes_to_remove = ByteQuantity::from_char_index(&spaces_to_remove, &line.text);
553        line.text.drain(bytes_to_remove.range_to());
554        let diff = StrElementDiff {
555            char: mod_index.char - spaces_to_remove,
556            byte: mod_index.byte - bytes_to_remove,
557        };
558        line.shift_jlists(&diff.char, &mod_index.char);
559        line.shift_symbols(&diff, &mod_index);
560        diff.char
561    } else {
562        let spaces_to_add = diff.magnitude();
563        line.text.insert_str(0, &spaces_to_add.repeat(" "));
564        let spaces_added_in_bytes = ByteQuantity::from_char_index(&spaces_to_add, &line.text);
565        let diff = StrElementDiff {
566            char: diff,
567            byte: spaces_added_in_bytes - mod_index.byte,
568        };
569        line.shift_jlists(&diff.char, &mod_index.char);
570        line.shift_symbols(&diff, &mod_index);
571        diff.char
572    }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use std::iter::zip;

    /// Asserts that `text` parses without errors and that no ASCII spelling of any mapped
    /// symbol remains in it.
    fn check_ascii_replaced(text: &str) {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_tlaplus::LANGUAGE.into())
            .unwrap();
        let tree = parser.parse(text, None).unwrap();
        assert!(!tree.root_node().has_error());

        let queries: String = get_unicode_mappings()
            .iter()
            .map(SymbolMapping::ascii_query)
            .collect();
        let query = Query::new(&tree_sitter_tlaplus::LANGUAGE.into(), &queries).unwrap();
        let mut cursor = QueryCursor::new();
        let mut leftovers = cursor.matches(&query, tree.root_node(), "".as_bytes());
        // A streaming iterator must be advanced before `is_done()` means anything, so pull the
        // first match explicitly: any match means an ASCII symbol survived translation.
        if let Some(leftover) = leftovers.next() {
            panic!(
                "untranslated ASCII symbol {:?} remains in:\n{}",
                leftover.captures[0].node, text
            );
        }
    }

    /// Unwraps a translation result, panicking with a diagnostic for each error kind.
    fn unwrap_conversion(input: Result<String, TlaError>) -> String {
        match input {
            Ok(converted) => converted,
            Err(TlaError::InputFileParseError {
                parse_tree,
                error_lines,
            }) => panic!("{:?}\n{}", error_lines, parse_tree.root_node().to_sexp()),
            Err(TlaError::OutputFileParseError {
                output_tree,
                output,
            }) => panic!("{}\n{}", output, output_tree.root_node().to_sexp()),
            Err(TlaError::InvalidTranslationError { first_diff, .. }) => {
                panic!("{}", first_diff)
            }
        }
    }

    /// Translates `input` to Unicode with every mapping enabled and all checks on.
    fn to_unicode(input: &str) -> String {
        unwrap_conversion(rewrite(input, &Mode::AsciiToUnicode, false, |_| true))
    }

    /// Translates `input` to ASCII with every mapping enabled and all checks on.
    fn to_ascii(input: &str) -> String {
        unwrap_conversion(rewrite(input, &Mode::UnicodeToAscii, false, |_| true))
    }

    /// Translates `input` in `mode` with `force` set, enabling only the mappings in `names`.
    fn translate_only(input: &str, mode: &Mode, names: &[&str]) -> String {
        rewrite(input, mode, true, |s| names.contains(&s.name.as_str())).unwrap()
    }

    /// Asserts that ASCII -> Unicode -> ASCII reproduces `expected` exactly, and that the
    /// intermediate Unicode text has no ASCII symbols left.
    fn run_roundtrip_test(expected: &str) {
        let intermediate = to_unicode(expected);
        check_ascii_replaced(&intermediate);
        let actual = to_ascii(&intermediate);
        assert_eq!(
            expected, actual,
            "\nExpected:\n{}\nActual:\n{}",
            expected, actual
        );
    }

    #[test]
    fn basic_roundtrip() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
op == \A n \in Nat: n >= 0
===="#,
        );
    }

    #[test]
    fn all_canonical_symbols_roundtrip() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
op == \A n \in Nat : \E r \in Real : ~(n = r)
op == {x \in R : TRUE}
op == INSTANCE Module WITH x <- y
op == [n \in Nat |-> n]
op == [Nat -> Real]
op == <<1,2,3>>
op == <<<>F>>_vars
op == CASE A -> B [] C -> D [] OTHER -> E
op == label :: []P => Q
op == A -+-> B \equiv C <=> D ~> E /\ F \/ G
op == A := B ::= C /= D <= E >= F \approx G
op == A |- B |= C -| D =| E \asymp F \cong G
op == A \doteq B \gg C \ll D \in E \notin F \prec G
op == A \succ B \preceq C \succeq D \propto E \sim F \simeq G
op == A \sqsubset B \sqsupset C \sqsubseteq D \sqsupseteq E
op == A \subset B \supset C \subseteq D \supseteq E
op == A \intersect B \union C .. D ... E (+) F (-) G
op == A || B (.) C (/) D (\X) E \bigcirc F \bullet G
op == A \div B \o C \star D !! E ?? F \sqcap G
op == A \sqcup B \uplus C \X D \wr E \cdot F ^+
===="#,
        );
    }

    #[test]
    fn all_non_canonical_symbols_roundtrip() {
        let expected = r#"
---- MODULE Test ----
op == \forall n \in Nat : TRUE
op == \exists r \in Real : TRUE
op == \neg P
op == P \land Q
op == P \lor Q
op == x # y
op == x =< y
op == x \leq y
op == x \geq y
op == P \cap Q
op == P \cup Q
op == x \oplus y
op == x \ominus y
op == x \odot y
op == x \oslash y
op == x \otimes y
op == x \circ y
op == P \times Q
===="#;
        let intermediate = to_unicode(expected);
        check_ascii_replaced(&intermediate);
        let actual = to_ascii(&intermediate);
        // Non-canonical spellings come back canonical, so only the blank first line, the
        // module header and the closing line should be unchanged.
        let last = expected.lines().count() - 1;
        for (i, (expected_line, actual_line)) in zip(expected.lines(), actual.lines()).enumerate() {
            if i <= 1 || i == last {
                assert_eq!(expected_line, actual_line);
            } else {
                assert_ne!(expected_line, actual_line);
            }
        }
    }

    #[test]
    fn test_basic_jlist() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
op == /\ A
      /\ B
      /\ C
      /\ D
===="#,
        );
    }

    #[test]
    fn test_nested_jlist() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
op == /\ A
      /\ \/ B 
         \/ C
      /\ D
===="#,
        );
    }

    #[test]
    fn test_full_binary_tree_jlist() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
op == /\ \/ /\ \/ /\ A
                  /\ B
               \/ /\ C
                  /\ D
            /\ \/ /\ E
                  /\ F
               \/ /\ G
                  /\ H
         \/ /\ \/ /\ I
                  /\ J
               \/ /\ K
                  /\ L
            /\ \/ /\ M
                  /\ N
               \/ /\ O
                  /\ P
      /\ \/ /\ \/ /\ Q
                  /\ R
               \/ /\ S
                  /\ T
            /\ \/ /\ U
                  /\ V
               \/ /\ W
                  /\ X
         \/ /\ \/ /\ Y
                  /\ Z
               \/ /\ A
                  /\ B
            /\ \/ /\ C
                  /\ D
               \/ /\ E
                  /\ F
===="#,
        );
    }

    #[test]
    fn jlist_with_comments() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
op == /\ A
      /\ \/ B 
\* This is a comment
         \/ C
(* This is another comment *)
      /\ D
===="#,
        );
    }

    #[test]
    fn test_aligned_trailing_infix_op() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
op == /\ A
      /\ B
      => C
===="#,
        );
    }

    #[test]
    fn test_trailing_infix_op_at_line_start() {
        let expected = r#"
---- MODULE Test ----
op == /\ A
      /\ B
=> C
===="#;
        let intermediate = to_unicode(expected);
        check_ascii_replaced(&intermediate);
        to_ascii(&intermediate);
    }

    #[test]
    fn test_nested_trailing_infix_op() {
        let expected = r#"
---- MODULE Test ----
op == /\ A
      /\ B
=> /\ C
   /\ \/ D
      \/ E
      => /\ F
         /\ G
 => H
op == A <=> /\ B
            /\ C
 => D
===="#;
        let intermediate = to_unicode(expected);
        check_ascii_replaced(&intermediate);
        to_ascii(&intermediate);
    }

    #[test]
    fn test_misaligned_jlist() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
op == /\ A
     /\ B
     /\ C
===="#,
        );
    }

    // See https://github.com/tlaplus-community/tlauc/issues/11
    // Test translation of number sets in their three forms:
    //  1. As an expression
    //  2. As the left-hand-side of an operator definition
    //  3. As a reference to an imported module
    #[test]
    fn test_translate_number_set() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
Nat == Nat \union A!B!Nat
Int == Int \union A!B!Int
Real == Real \union A!B!Real
===="#,
        );
    }

    // https://github.com/tlaplus-community/tlauc/issues/1
    #[ignore]
    #[test]
    fn test_infix_op_jlist_from_unicode() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
op ≜ ∧ A
     ∧ B
      = C
     ∧ D
      = E
===="#,
        );
    }

    // https://github.com/tlaplus-community/tlauc/issues/2
    #[ignore]
    #[test]
    fn test_block_comments_prefixing_jlist_items() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
op == /\ A
(***) /\ \/ B
(******) \/ C
(***) => D
===="#,
        );
    }

    // Tests that file ends with newline (or without newline)
    #[test]
    fn test_empty_input() {
        let input = "";
        for mode in [Mode::UnicodeToAscii, Mode::AsciiToUnicode] {
            let output = rewrite(input, &mode, true, |_| true);
            assert_eq!(input, output.unwrap());
        }
    }

    #[test]
    fn test_single_newline() {
        let input = "\n";
        for mode in [Mode::UnicodeToAscii, Mode::AsciiToUnicode] {
            let output = rewrite(input, &mode, true, |_| true);
            assert_eq!(input, output.unwrap());
        }
    }

    #[test]
    fn test_normal_input_without_newline() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
op == 1
===="#,
        );
    }

    #[test]
    fn test_normal_input_with_newline() {
        run_roundtrip_test(
            r#"
---- MODULE Test ----
op == 1
====
"#,
        );
    }

    #[test]
    fn test_translate_to_unicode_with_skip_filter() {
        let input = r#"
---- MODULE Test ----
op ==
  /\ x <= y
  /\ Nat
  /\ \/ Int
     \/ Real
===="#;
        assert_eq!(
            r#"
---- MODULE Test ----
op ≜
  /\ x <= y
  /\ Nat
  /\ \/ Int
     \/ Real
===="#,
            translate_only(input, &Mode::AsciiToUnicode, &["def_eq"])
        );
        assert_eq!(
            r#"
---- MODULE Test ----
op ==
  /\ x ≤ y
  /\ Nat
  /\ \/ Int
     \/ Real
===="#,
            translate_only(input, &Mode::AsciiToUnicode, &["leq"])
        );
        assert_eq!(
            r#"
---- MODULE Test ----
op ==
  /\ x <= y
  /\ ℕ
  /\ \/ ℤ
     \/ ℝ
===="#,
            translate_only(
                input,
                &Mode::AsciiToUnicode,
                &["nat_number_set", "int_number_set", "real_number_set"]
            )
        );
        assert_eq!(
            r#"
---- MODULE Test ----
op ≜
  ∧ x ≤ y
  ∧ Nat
  ∧ ∨ Int
    ∨ Real
===="#,
            translate_only(
                input,
                &Mode::AsciiToUnicode,
                &["def_eq", "leq", "bullet_conj", "bullet_disj"]
            )
        );
    }

    #[test]
    fn test_translate_to_ascii_with_skip_filter() {
        let input = r#"
---- MODULE Test ----
op ≜
  ∧ x ≤ y
  ∧ ℕ
  ∧ ∨ ℤ
    ∨ ℝ
===="#;
        assert_eq!(
            r#"
---- MODULE Test ----
op ==
  ∧ x ≤ y
  ∧ ℕ
  ∧ ∨ ℤ
    ∨ ℝ
===="#,
            translate_only(input, &Mode::UnicodeToAscii, &["def_eq"])
        );
        assert_eq!(
            r#"
---- MODULE Test ----
op ==
  /\ x <= y
  /\ ℕ
  /\ ∨ ℤ
     ∨ ℝ
===="#,
            translate_only(
                input,
                &Mode::UnicodeToAscii,
                &["def_eq", "leq", "bullet_conj"]
            )
        );
        assert_eq!(
            r#"
---- MODULE Test ----
op ≜
  ∧ x ≤ y
  ∧ Nat
  ∧ ∨ Int
    ∨ Real
===="#,
            translate_only(
                input,
                &Mode::UnicodeToAscii,
                &["nat_number_set", "int_number_set", "real_number_set"]
            )
        );
    }
}