smtlib_build_util/
spec.rs

1use std::collections::HashSet;
2
3use heck::ToPascalCase;
4use indexmap::IndexMap;
5use itertools::Itertools;
6use miette::IntoDiagnostic;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Serialize, Deserialize)]
10struct RawSpec {
11    #[serde(flatten)]
12    general: IndexMap<String, RawSyntax>,
13}
14
15#[derive(Debug, Serialize, Deserialize)]
16#[serde(untagged)]
17enum RawSyntax {
18    Just {
19        syntax: String,
20        priority: Option<i32>,
21        separator: Option<String>,
22        response: Option<String>,
23    },
24    Class {
25        response: Option<String>,
26        #[serde(flatten)]
27        cases: IndexMap<String, RawSyntax>,
28    },
29}
30
31impl RawSpec {
32    fn parse(&self) -> Spec {
33        Spec {
34            general: self
35                .general
36                .iter()
37                .map(|(n, s)| (n.clone(), s.parse()))
38                .collect(),
39        }
40    }
41}
42
43impl RawSyntax {
44    fn parse(&self) -> Syntax {
45        match self {
46            RawSyntax::Just {
47                syntax,
48                priority,
49                separator,
50                response,
51            } => Syntax::Rule(Rule {
52                syntax: parse_raw_grammar(syntax),
53                priority: priority.unwrap_or_default(),
54                separator: separator.clone(),
55                response: response.as_ref().map(|s| parse_raw_token(s, 0)),
56            }),
57            RawSyntax::Class { response, cases } => Syntax::Class {
58                response: response.clone(),
59                cases: cases
60                    .iter()
61                    .map(|(n, s)| match s {
62                        RawSyntax::Just {
63                            syntax,
64                            priority,
65                            separator,
66                            response,
67                        } => (
68                            n.clone(),
69                            Rule {
70                                syntax: parse_raw_grammar(syntax),
71                                priority: priority.unwrap_or_default(),
72                                separator: separator.clone(),
73                                response: response.as_ref().map(|r| parse_raw_token(r, 0)),
74                            },
75                        ),
76                        RawSyntax::Class { .. } => todo!(),
77                    })
78                    .collect(),
79            },
80        }
81    }
82}
83
84#[derive(Debug, Clone)]
85struct Spec {
86    general: IndexMap<String, Syntax>,
87}
88
89#[derive(Debug, Clone)]
90struct Rule {
91    syntax: Grammar,
92    priority: i32,
93    separator: Option<String>,
94    response: Option<Token>,
95}
96
97#[derive(Debug, Clone)]
98enum Syntax {
99    Rule(Rule),
100    Class {
101        response: Option<String>,
102        cases: IndexMap<String, Rule>,
103    },
104}
105
106#[derive(Debug, Clone)]
107struct Grammar {
108    tokens: Vec<Token>,
109    fields: Vec<Field>,
110}
111
112#[derive(Debug, Clone)]
113enum Token {
114    LParen,
115    RParen,
116    Underscore,
117    Annotation,
118    Builtin(String),
119    Reserved(String),
120    Keyword(String),
121    Field(usize, Field),
122}
123
124impl Token {
125    pub fn is_concrete(&self) -> bool {
126        use Token::*;
127        match self {
128            LParen | RParen | Underscore | Annotation | Keyword(_) | Builtin(_) | Reserved(_) => {
129                true
130            }
131            Field(_, _) => false,
132        }
133    }
134}
135
/// A `<name>` placeholder in a grammar together with its repetition marker.
///
/// `One` maps to the bare type; the three repeated forms all map to `Vec<T>`.
#[derive(Clone, Debug)]
enum Field {
    /// `<name>`: exactly one occurrence.
    One(String),
    /// `<name>*`: parsed by the runtime `p.any` helper (presumably zero or more).
    Any(String),
    /// `<name>+`: parsed by `p.non_zero` (presumably one or more).
    NonZero(String),
    /// `<name>n+1`: parsed by `p.n_plus_one`.
    NPlusOne(String),
}
143
144impl std::fmt::Display for Token {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        use Token::*;
147
148        match self {
149            LParen => write!(f, "("),
150            RParen => write!(f, ")"),
151            Underscore => write!(f, "_"),
152            Annotation => write!(f, "!"),
153            Builtin(s) => write!(f, "{s}"),
154            Reserved(s) => write!(f, "{s}"),
155            Keyword(k) => write!(f, "{k}"),
156            Field(_, field) => write!(f, "{field}"),
157        }
158    }
159}
160impl std::fmt::Display for Field {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        match self {
163            Field::One(t) => write!(f, "<{t}>"),
164            Field::Any(t) => write!(f, "<{t}>*"),
165            Field::NonZero(t) => write!(f, "<{t}>+"),
166            Field::NPlusOne(t) => write!(f, "<{t}>n+1"),
167        }
168    }
169}
170impl std::fmt::Display for Grammar {
171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        let s = self.tokens.iter().fold("".to_string(), |mut acc, t| {
173            use Token::*;
174            if !acc.ends_with(|c| c == '(') && !acc.is_empty() {
175                acc += match t {
176                    RParen => "",
177                    _ => " ",
178                };
179            }
180            acc += &t.to_string();
181            acc
182        });
183        write!(f, "{s}")
184    }
185}
186
187fn parse_raw_grammar(s: &str) -> Grammar {
188    let mut acc = 0;
189    let p = s
190        .split(' ')
191        .map(|t| {
192            let t = parse_raw_token(t, acc);
193            if let Token::Field(_, _) = t {
194                acc += 1
195            }
196            t
197        })
198        .collect_vec();
199    let fields = p
200        .iter()
201        .filter_map(|t| match t {
202            Token::Field(_, f) => Some(f.clone()),
203            _ => None,
204        })
205        .collect();
206    Grammar { tokens: p, fields }
207}
208
209fn parse_raw_token(s: &str, field_idx: usize) -> Token {
210    match s {
211        "(" => Token::LParen,
212        ")" => Token::RParen,
213        "_" => Token::Underscore,
214        "!" => Token::Annotation,
215        f if f.starts_with(':') => Token::Keyword(f.to_string()),
216        f if f.starts_with('<') && f.ends_with('>') => {
217            Token::Field(field_idx, Field::One(f[1..f.len() - 1].to_string()))
218        }
219        f if f.starts_with('<') && f.ends_with(">*") => {
220            Token::Field(field_idx, Field::Any(f[1..f.len() - 2].to_string()))
221        }
222        f if f.starts_with('<') && f.ends_with(">+") => {
223            Token::Field(field_idx, Field::NonZero(f[1..f.len() - 2].to_string()))
224        }
225        f if f.starts_with('<') && f.ends_with(">n+1") => {
226            Token::Field(field_idx, Field::NPlusOne(f[1..f.len() - 4].to_string()))
227        }
228        f if f.chars().all(|c| c.is_alphabetic() || c == '-') => {
229            if [
230                "_",
231                "!",
232                "as",
233                "BINARY",
234                "DECIMAL",
235                "exists",
236                "forall",
237                "HEXADECIMAL",
238                "let",
239                "match",
240                "NUMERAL",
241                "par",
242                "STRING",
243                "assert",
244                "check-sat",
245                "check-sat-assuming",
246                "declare-const",
247                "declare-datatype",
248                "declare-datatypes",
249                "declare-fun",
250                "declare-sort",
251                "define-fun",
252                "define-fun-rec",
253                "define-sort",
254                "echo",
255                "exit",
256                "get-assertions",
257                "get-assignment",
258                "get-info",
259                "get-model",
260                "get-option",
261                "get-proof",
262                "get-unsat-assumptions",
263                "get-unsat-core",
264                "get-value",
265                "pop",
266                "push",
267                "reset",
268                "reset-assertions",
269                "set-info",
270                "set-logic",
271                "set-option",
272            ]
273            .contains(&f)
274            {
275                Token::Reserved(f.to_string())
276            } else {
277                Token::Builtin(f.to_string())
278            }
279        }
280        _ => todo!("{:?}", s),
281    }
282}
283
284impl Syntax {
285    fn rust_ty_decl_top(&self, name: &str) -> String {
286        let derive = r#"#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]"#;
287        match self {
288            Syntax::Rule(r) => format!(
289                "/// `{}`\n{derive}\npub struct {}({});",
290                r.syntax,
291                name.to_pascal_case(),
292                r.syntax
293                    .tuple_fields(&[name.to_string()].into_iter().collect())
294                    .map(|f| format!("pub {f}"))
295                    .format(",")
296            ),
297            Syntax::Class { cases, .. } => format!(
298                "{derive}pub enum {} {{ {} }}",
299                name.to_pascal_case(),
300                cases
301                    .iter()
302                    .map(|(n, c)| c.rust_ty_decl_child(n, [name.to_string()].into_iter().collect()))
303                    .format(", ")
304            ),
305        }
306    }
307    fn rust_display(&self, name: &str) -> String {
308        match self {
309            Syntax::Rule(r) => format!(
310                r#"
311                impl std::fmt::Display for {} {{
312                    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {{
313                        {}
314                    }}
315                }}
316                "#,
317                name.to_pascal_case(),
318                r.rust_display_impl("self.")
319            ),
320            Syntax::Class { cases, .. } => format!(
321                r#"
322                impl std::fmt::Display for {} {{
323                    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {{
324                        match self {{ {} }}
325                    }}
326                }}
327                "#,
328                name.to_pascal_case(),
329                cases
330                    .iter()
331                    .map(|(n, c)| if c.syntax.fields.is_empty() {
332                        format!(
333                            "Self::{} => {},",
334                            n.to_pascal_case(),
335                            c.rust_display_impl("todo.")
336                        )
337                    } else {
338                        format!(
339                            "Self::{}({}) => {},",
340                            n.to_pascal_case(),
341                            c.syntax
342                                .fields
343                                .iter()
344                                .enumerate()
345                                .map(|(idx, _)| format!("m{idx}"))
346                                .format(","),
347                            c.rust_display_impl("m")
348                        )
349                    })
350                    .format("\n")
351            ),
352        }
353    }
354    fn rust_parse(&self, name: &str) -> String {
355        match self {
356            Syntax::Rule(r) => {
357                format!(
358                    "impl {} {{
359                        pub fn parse(src: &str) -> Result<Self, ParseError> {{
360                            SmtlibParse::parse(&mut Parser::new(src))
361                        }}
362                    }}
363                    impl SmtlibParse for {} {{
364                        fn is_start_of(offset: usize, p: &mut Parser) -> bool {{
365                            {}
366                        }}
367                        fn parse(p: &mut Parser) -> Result<Self, ParseError> {{
368                            {}
369                            {}
370                        }}
371                    }}",
372                    name.to_pascal_case(),
373                    name.to_pascal_case(),
374                    r.rust_start_of_impl(),
375                    r.rust_parse_impl(),
376                    if r.syntax.fields.is_empty() {
377                        "Ok(Self)".to_string()
378                    } else {
379                        format!(
380                            "Ok(Self({}))",
381                            r.syntax
382                                .fields
383                                .iter()
384                                .enumerate()
385                                .map(|(idx, _)| format!("m{idx}.into()"))
386                                .format(", ")
387                        )
388                    }
389                )
390            }
391            Syntax::Class { cases, .. } => {
392                let is_start_of = cases
393                    .iter()
394                    .sorted_by_key(|(_, c)| {
395                        (
396                            c.priority,
397                            c.syntax.tokens.iter().filter(|t| t.is_concrete()).count(),
398                        )
399                    })
400                    .rev()
401                    .map(|(_, c)| format!("({})", c.rust_start_of_check()))
402                    .format(" || ");
403                let parse = cases
404                    .iter()
405                    .sorted_by_key(|(_, c)| {
406                        (
407                            c.priority,
408                            c.syntax.tokens.iter().filter(|t| t.is_concrete()).count(),
409                        )
410                    })
411                    .rev()
412                    .map(|(n, c)| {
413                        let construct = rust_parse_construct_variant("self", n, &c.syntax);
414                        format!(
415                            "if {} {{ {}\n return Ok({construct}); }}",
416                            c.rust_start_of_check(),
417                            c.rust_parse_impl(),
418                        )
419                    })
420                    .format("\n");
421                format!(
422                    "impl {} {{
423                        pub fn parse(src: &str) -> Result<Self, ParseError> {{
424                            SmtlibParse::parse(&mut Parser::new(src))
425                        }}
426                    }}
427                    impl SmtlibParse for {} {{
428                        fn is_start_of(offset: usize, p: &mut Parser) -> bool {{
429                            {is_start_of}
430                        }}
431                        fn parse(p: &mut Parser) -> Result<Self, ParseError> {{
432                            let offset = 0;
433                            {parse}
434                            Err(p.stuck({name:?}))
435                        }}
436                    }}",
437                    name.to_pascal_case(),
438                    name.to_pascal_case(),
439                )
440            }
441        }
442    }
443    fn rust_response(&self, name: &str) -> String {
444        match self {
445            Syntax::Rule(_) | Syntax::Class { response: None, .. } => "".to_string(),
446            Syntax::Class {
447                cases,
448                response: Some(response),
449            } => {
450                let has_response = cases
451                    .iter()
452                    .map(|(n, c)| {
453                        format!(
454                            "Self::{}{} => {},",
455                            n.to_pascal_case(),
456                            if c.syntax.fields.is_empty() {
457                                "".to_string()
458                            } else {
459                                format!("({})", c.syntax.fields.iter().map(|_| "_").format(", "))
460                            },
461                            c.response.is_some(),
462                        )
463                    })
464                    .format("\n");
465                let parse_response = cases
466                    .iter()
467                    .map(|(n, c)| {
468                        format!(
469                            "Self::{}{} => {},",
470                            n.to_pascal_case(),
471                            if c.syntax.fields.is_empty() {
472                                "".to_string()
473                            } else {
474                                format!("({})", c.syntax.fields.iter().map(|_| "_").format(", "))
475                            },
476                            if let Some(res) = &c.response {
477                                let res_ty = match res {
478                                    Token::LParen
479                                    | Token::RParen
480                                    | Token::Underscore
481                                    | Token::Annotation
482                                    | Token::Builtin(_)
483                                    | Token::Reserved(_)
484                                    | Token::Keyword(_) => todo!(),
485                                    Token::Field(_, f) => match f {
486                                        Field::One(t)
487                                        | Field::Any(t)
488                                        | Field::NonZero(t)
489                                        | Field::NPlusOne(t) => t.to_string(),
490                                    },
491                                };
492                                format!(
493                                    "Ok(Some({}::{}({}::parse(response)?)))",
494                                    response.to_pascal_case(),
495                                    res_ty.to_pascal_case(),
496                                    res_ty.to_pascal_case(),
497                                )
498                            } else {
499                                "Ok(None)".to_string()
500                            },
501                        )
502                    })
503                    .format("\n");
504
505                format!(
506                        "
507                        impl {} {{
508                            pub fn has_response(&self) -> bool {{
509                                match self {{
510                                    {}
511                                }}
512                            }}
513                            pub fn parse_response(&self, response: &str) -> Result<std::option::Option<{}>, ParseError> {{
514                                match self {{
515                                    {}
516                                }}
517                            }}
518                        }}
519                        ",
520                        name.to_pascal_case(),
521                        has_response,
522                        response.to_pascal_case(),
523                        parse_response,
524                    )
525            }
526        }
527    }
528}
529
530impl Rule {
531    fn rust_display_impl(&self, scope: &str) -> String {
532        format!(
533            r#"write!(f, "{}" {})"#,
534            self.syntax
535                .tokens
536                .iter()
537                .fold("".to_string(), |mut acc, t| {
538                    use Token::*;
539                    if !acc.ends_with(|c| c == '(') && !acc.is_empty() {
540                        acc += match t {
541                            RParen => "",
542                            _ => " ",
543                        };
544                    }
545                    acc += match t {
546                        LParen => "(",
547                        RParen => ")",
548                        Underscore => "_",
549                        Annotation => "!",
550                        Builtin(s) => s,
551                        Reserved(s) => s,
552                        Keyword(k) => k,
553                        Field(_, _) => "{}",
554                    };
555                    acc
556                }),
557            self.syntax
558                .fields
559                .iter()
560                .enumerate()
561                .map(|(idx, f)| match f {
562                    Field::One(_) => {
563                        format!(", {scope}{idx}")
564                    }
565                    Field::Any(_) | Field::NonZero(_) | Field::NPlusOne(_) => {
566                        format!(
567                            r#", {scope}{idx}.iter().format({:?})"#,
568                            self.separator.as_deref().unwrap_or(" ")
569                        )
570                    }
571                })
572                .format("")
573        )
574    }
575    fn rust_ty_decl_child(&self, name: &str, inside_of: HashSet<String>) -> String {
576        if self.syntax.fields.is_empty() {
577            format!("/// `{}`\n{}", self.syntax, name.to_pascal_case())
578        } else {
579            format!(
580                "/// `{}`\n{}({})",
581                self.syntax,
582                name.to_pascal_case(),
583                self.syntax.tuple_fields(&inside_of).format(",")
584            )
585        }
586    }
587    fn rust_start_of_check(&self) -> String {
588        let is_all_variable = !self.syntax.tokens.iter().any(|t| t.is_concrete());
589
590        if is_all_variable {
591            self.syntax
592                .tokens
593                .iter()
594                .enumerate()
595                .map(|(idx, t)| rust_check_token(idx, t))
596                .format(" && ")
597                .to_string()
598        } else if !self.syntax.tokens[0].is_concrete() {
599            let q = rust_check_token(0, &self.syntax.tokens[0]);
600            assert!(!q.is_empty());
601            q
602        } else {
603            self.syntax
604                .tokens
605                .iter()
606                .take_while(|t| t.is_concrete())
607                .enumerate()
608                .map(|(idx, t)| rust_check_token(idx, t))
609                .format(" && ")
610                .to_string()
611        }
612    }
613    fn rust_start_of_impl(&self) -> String {
614        self.rust_start_of_check()
615    }
616    fn rust_parse_impl(&self) -> String {
617        let stmts = self.syntax.tokens.iter().map(rust_parse_token);
618        stmts.format("\n").to_string()
619    }
620}
621
622fn rust_parse_construct_variant(suffix: &str, name: &str, syntax: &Grammar) -> String {
623    format!(
624        "{}::{}{}",
625        suffix.to_pascal_case(),
626        name.to_pascal_case(),
627        if syntax.fields.is_empty() {
628            "".to_string()
629        } else {
630            format!(
631                "({})",
632                syntax
633                    .fields
634                    .iter()
635                    .enumerate()
636                    .map(|(idx, _)| format!("m{idx}.into()"))
637                    .format(", ")
638            )
639        }
640    )
641}
642
643fn rust_parse_token(t: &Token) -> String {
644    match t {
645        Token::LParen => "p.expect(Token::LParen)?;".to_string(),
646        Token::RParen => "p.expect(Token::RParen)?;".to_string(),
647        Token::Underscore => "p.expect_matches(Token::Reserved, \"_\")?;".to_string(),
648        Token::Annotation => "p.expect_matches(Token::Reserved, \"!\")?;".to_string(),
649        Token::Builtin(b) => format!("p.expect_matches(Token::Symbol, {b:?})?;"),
650        Token::Reserved(b) => format!("p.expect_matches(Token::Reserved, {b:?})?;"),
651        Token::Keyword(kw) => format!("p.expect_matches(Token::Keyword, {kw:?})?;"),
652        Token::Field(idx, f) => match f {
653            Field::One(t) => format!(
654                "let m{idx} = <{} as SmtlibParse>::parse(p)?;",
655                t.to_pascal_case()
656            ),
657            Field::Any(t) => format!("let m{idx} = p.any::<{}>()?;", t.to_pascal_case()),
658            Field::NonZero(t) => {
659                format!("let m{idx} = p.non_zero::<{}>()?;", t.to_pascal_case())
660            }
661            Field::NPlusOne(t) => {
662                format!("let m{idx} = p.n_plus_one::<{}>()?;", t.to_pascal_case())
663            }
664        },
665    }
666}
667
668fn rust_check_token(idx: usize, t: &Token) -> String {
669    match t {
670        Token::LParen => format!("p.nth(offset + {idx}) == Token::LParen"),
671        Token::RParen => format!("p.nth(offset + {idx}) == Token::RParen"),
672        Token::Underscore => format!("p.nth_matches(offset + {idx}, Token::Reserved, \"_\")"),
673        Token::Annotation => format!("p.nth_matches(offset + {idx}, Token::Reserved, \"!\")"),
674        Token::Builtin(b) => format!("p.nth_matches(offset + {idx}, Token::Symbol, {b:?})"),
675        Token::Reserved(b) => format!("p.nth_matches(offset + {idx}, Token::Reserved, {b:?})"),
676        Token::Keyword(kw) => {
677            format!("p.nth_matches(offset + {idx}, Token::Keyword, {kw:?})")
678        }
679        Token::Field(_, f) => match f {
680            Field::One(t) | Field::NonZero(t) | Field::NPlusOne(t) => {
681                format!("{}::is_start_of(offset + {idx}, p)", t.to_pascal_case())
682            }
683            Field::Any(_) => "todo!()".to_string(),
684        },
685    }
686}
687
688impl Grammar {
689    fn tuple_fields<'a>(
690        &'a self,
691        inside_of: &'a HashSet<String>,
692    ) -> impl Iterator<Item = String> + 'a {
693        self.fields.iter().map(|f| match &f {
694            Field::One(t) => {
695                if inside_of.contains(t) {
696                    format!("Box<{}>", t.to_pascal_case())
697                } else {
698                    t.to_pascal_case()
699                }
700            }
701            Field::Any(t) | Field::NonZero(t) | Field::NPlusOne(t) => {
702                format!("Vec<{}>", t.to_pascal_case())
703            }
704        })
705    }
706}
707
708pub fn generate(mut f: impl std::io::Write) -> miette::Result<()> {
709    use std::fmt::Write;
710
711    let mut buf = String::new();
712
713    let raw: RawSpec = toml::from_str(include_str!("./spec.toml")).into_diagnostic()?;
714    let spec = raw.parse();
715
716    writeln!(buf, "// This file is autogenerated! DO NOT EDIT!\n").into_diagnostic()?;
717    writeln!(buf, "use crate::parse::{{Token, Parser, ParseError}};").into_diagnostic()?;
718    writeln!(buf, "use itertools::Itertools; use crate::lexicon::*;\n").into_diagnostic()?;
719
720    for (name, s) in &spec.general {
721        writeln!(buf, "{}", s.rust_ty_decl_top(name)).into_diagnostic()?;
722        writeln!(buf, "{}", s.rust_display(name)).into_diagnostic()?;
723        writeln!(buf, "{}", s.rust_parse(name)).into_diagnostic()?;
724        writeln!(buf, "{}", s.rust_response(name)).into_diagnostic()?;
725    }
726
727    let file = syn::parse_file(&buf).into_diagnostic()?;
728    let pretty = prettyplease::unparse(&file);
729
730    f.write_all(pretty.as_bytes()).into_diagnostic()?;
731
732    Ok(())
733}