use std::collections::HashSet;

use heck::ToPascalCase;
use indexmap::IndexMap;
use itertools::Itertools;
use miette::IntoDiagnostic;
use serde::{Deserialize, Serialize};

/// Deserialized form of `spec.toml`: a map from rule/class name to its raw syntax entry.
#[derive(Debug, Serialize, Deserialize)]
struct RawSpec {
    #[serde(flatten)]
    general: IndexMap<String, RawSyntax>,
}

/// A raw entry is either a single rule (`Just`) or a class made up of named cases (`Class`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum RawSyntax {
    Just {
        syntax: String,
        priority: Option<i32>,
        separator: Option<String>,
        response: Option<String>,
    },
    Class {
        response: Option<String>,
        #[serde(flatten)]
        cases: IndexMap<String, RawSyntax>,
    },
}
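
// For orientation, a sketch of the TOML shape these raw types accept. The entries
// below are illustrative assumptions, not copied from spec.toml:
//
//     # a single rule: deserializes as `RawSyntax::Just`
//     [assert]
//     syntax = "( assert <term> )"
//
//     # a class of cases: deserializes as `RawSyntax::Class`, each sub-table is one case
//     [term.constant]
//     syntax = "<spec_constant>"
//     [term.application]
//     syntax = "( <qual_identifier> <term>+ )"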

impl RawSpec {
    fn parse(&self) -> Spec {
        Spec {
            general: self
                .general
                .iter()
                .map(|(n, s)| (n.clone(), s.parse()))
                .collect(),
        }
    }
}

impl RawSyntax {
    fn parse(&self) -> Syntax {
        match self {
            RawSyntax::Just {
                syntax,
                priority,
                separator,
                response,
            } => Syntax::Rule(Rule {
                syntax: parse_raw_grammar(syntax),
                priority: priority.unwrap_or_default(),
                separator: separator.clone(),
                response: response.as_ref().map(|s| parse_raw_token(s, 0)),
            }),
            RawSyntax::Class { response, cases } => Syntax::Class {
                response: response.clone(),
                cases: cases
                    .iter()
                    .map(|(n, s)| match s {
                        RawSyntax::Just {
                            syntax,
                            priority,
                            separator,
                            response,
                        } => (
                            n.clone(),
                            Rule {
                                syntax: parse_raw_grammar(syntax),
                                priority: priority.unwrap_or_default(),
                                separator: separator.clone(),
                                response: response.as_ref().map(|r| parse_raw_token(r, 0)),
                            },
                        ),
                        // Classes nested inside a class are not supported yet.
                        RawSyntax::Class { .. } => todo!(),
                    })
                    .collect(),
            },
        }
    }
}

#[derive(Debug, Clone)]
struct Spec {
    general: IndexMap<String, Syntax>,
}

#[derive(Debug, Clone)]
struct Rule {
    syntax: Grammar,
    priority: i32,
    separator: Option<String>,
    response: Option<Token>,
}

#[derive(Debug, Clone)]
enum Syntax {
    Rule(Rule),
    Class {
        response: Option<String>,
        cases: IndexMap<String, Rule>,
    },
}

/// A parsed rule grammar: the full token sequence plus the subset of tokens that are fields.
#[derive(Debug, Clone)]
struct Grammar {
    tokens: Vec<Token>,
    fields: Vec<Field>,
}

/// One token of a rule's concrete syntax.
#[derive(Debug, Clone)]
enum Token {
    LParen,
    RParen,
    Underscore,
    Annotation,
    Builtin(String),
    Reserved(String),
    Keyword(String),
    Field(usize, Field),
}

impl Token {
    /// `true` for tokens that match a fixed piece of input; `false` for field placeholders.
    pub fn is_concrete(&self) -> bool {
        use Token::*;
        match self {
            LParen | RParen | Underscore | Annotation | Keyword(_) | Builtin(_) | Reserved(_) => {
                true
            }
            Field(_, _) => false,
        }
    }
}

/// A field placeholder together with its repetition.
#[derive(Debug, Clone)]
enum Field {
    One(String),
    Any(String),
    NonZero(String),
    NPlusOne(String),
}
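
// How the `<...>` placeholders in spec syntax strings map to `Field` variants
// (see `parse_raw_token` below):
//
//     <term>      -> Field::One("term")        exactly one
//     <term>*     -> Field::Any("term")        zero or more
//     <term>+     -> Field::NonZero("term")    one or more
//     <term>n+1   -> Field::NPlusOne("term")   n + 1 occurrences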

impl std::fmt::Display for Token {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use Token::*;

        match self {
            LParen => write!(f, "("),
            RParen => write!(f, ")"),
            Underscore => write!(f, "_"),
            Annotation => write!(f, "!"),
            Builtin(s) => write!(f, "{s}"),
            Reserved(s) => write!(f, "{s}"),
            Keyword(k) => write!(f, "{k}"),
            Field(_, field) => write!(f, "{field}"),
        }
    }
}

impl std::fmt::Display for Field {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Field::One(t) => write!(f, "<{t}>"),
            Field::Any(t) => write!(f, "<{t}>*"),
            Field::NonZero(t) => write!(f, "<{t}>+"),
            Field::NPlusOne(t) => write!(f, "<{t}>n+1"),
        }
    }
}

impl std::fmt::Display for Grammar {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = self.tokens.iter().fold("".to_string(), |mut acc, t| {
            use Token::*;
            // Separate tokens with a space, except right after `(` and right before `)`.
            if !acc.ends_with('(') && !acc.is_empty() {
                acc += match t {
                    RParen => "",
                    _ => " ",
                };
            }
            acc += &t.to_string();
            acc
        });
        write!(f, "{s}")
    }
}
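
// e.g. the grammar parsed from "( assert <term> )" (an illustrative rule) displays
// as `(assert <term>)`: no space after `(` and none before `)`.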

fn parse_raw_grammar(s: &str) -> Grammar {
    // Tokens in the spec are separated by single spaces. Fields are numbered
    // left to right; the generated code refers to them as m0, m1, ...
    let mut acc = 0;
    let p = s
        .split(' ')
        .map(|t| {
            let t = parse_raw_token(t, acc);
            if let Token::Field(_, _) = t {
                acc += 1
            }
            t
        })
        .collect_vec();
    let fields = p
        .iter()
        .filter_map(|t| match t {
            Token::Field(_, f) => Some(f.clone()),
            _ => None,
        })
        .collect();
    Grammar { tokens: p, fields }
}
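
// For example (names illustrative), parse_raw_grammar("( assert <term> )") yields
// tokens [LParen, Reserved("assert"), Field(0, One("term")), RParen] and
// fields [One("term")].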

fn parse_raw_token(s: &str, field_idx: usize) -> Token {
    match s {
        "(" => Token::LParen,
        ")" => Token::RParen,
        "_" => Token::Underscore,
        "!" => Token::Annotation,
        f if f.starts_with(':') => Token::Keyword(f.to_string()),
        f if f.starts_with('<') && f.ends_with('>') => {
            Token::Field(field_idx, Field::One(f[1..f.len() - 1].to_string()))
        }
        f if f.starts_with('<') && f.ends_with(">*") => {
            Token::Field(field_idx, Field::Any(f[1..f.len() - 2].to_string()))
        }
        f if f.starts_with('<') && f.ends_with(">+") => {
            Token::Field(field_idx, Field::NonZero(f[1..f.len() - 2].to_string()))
        }
        f if f.starts_with('<') && f.ends_with(">n+1") => {
            Token::Field(field_idx, Field::NPlusOne(f[1..f.len() - 4].to_string()))
        }
        f if f.chars().all(|c| c.is_alphabetic() || c == '-') => {
            // SMT-LIB reserved words and command names; any other purely alphabetic
            // (plus `-`) word is treated as a builtin symbol.
            if [
                "_",
                "!",
                "as",
                "BINARY",
                "DECIMAL",
                "exists",
                "forall",
                "HEXADECIMAL",
                "let",
                "match",
                "NUMERAL",
                "par",
                "STRING",
                "assert",
                "check-sat",
                "check-sat-assuming",
                "declare-const",
                "declare-datatype",
                "declare-datatypes",
                "declare-fun",
                "declare-sort",
                "define-fun",
                "define-fun-rec",
                "define-sort",
                "echo",
                "exit",
                "get-assertions",
                "get-assignment",
                "get-info",
                "get-model",
                "get-option",
                "get-proof",
                "get-unsat-assumptions",
                "get-unsat-core",
                "get-value",
                "pop",
                "push",
                "reset",
                "reset-assertions",
                "set-info",
                "set-logic",
                "set-option",
            ]
            .contains(&f)
            {
                Token::Reserved(f.to_string())
            } else {
                Token::Builtin(f.to_string())
            }
        }
        _ => todo!("{:?}", s),
    }
}
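
// Classification examples (inputs are illustrative):
//     ":named"   -> Token::Keyword(":named")
//     "assert"   -> Token::Reserved("assert")     (in the list above)
//     "true"     -> Token::Builtin("true")        (alphabetic, not reserved)
//     "<term>+"  -> Token::Field(idx, Field::NonZero("term"))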

impl Syntax {
    /// Emit the Rust type declaration (struct or enum) for this rule or class.
    fn rust_ty_decl_top(&self, name: &str) -> String {
        let derive = r#"#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]"#;
        match self {
            Syntax::Rule(r) => format!(
                "/// `{}`\n{derive}\npub struct {}({});",
                r.syntax,
                name.to_pascal_case(),
                r.syntax
                    .tuple_fields(&[name.to_string()].into_iter().collect())
                    .map(|f| format!("pub {f}"))
                    .format(",")
            ),
            Syntax::Class { cases, .. } => format!(
                "{derive}pub enum {} {{ {} }}",
                name.to_pascal_case(),
                cases
                    .iter()
                    .map(|(n, c)| c.rust_ty_decl_child(n, [name.to_string()].into_iter().collect()))
                    .format(", ")
            ),
        }
    }
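
    // For a rule named "assert" with syntax "( assert <term> )" (an illustrative
    // entry, not necessarily in spec.toml), the method above emits roughly:
    //
    //     /// `(assert <term>)`
    //     #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(...)]
    //     pub struct Assert(pub Term);
    //
    // The output only needs to parse; prettyplease reformats it in `generate`.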
    /// Emit a `std::fmt::Display` impl that prints values back out in SMT-LIB syntax.
    fn rust_display(&self, name: &str) -> String {
        match self {
            Syntax::Rule(r) => format!(
                r#"
                impl std::fmt::Display for {} {{
                    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {{
                        {}
                    }}
                }}
                "#,
                name.to_pascal_case(),
                r.rust_display_impl("self.")
            ),
            Syntax::Class { cases, .. } => format!(
                r#"
                impl std::fmt::Display for {} {{
                    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {{
                        match self {{ {} }}
                    }}
                }}
                "#,
                name.to_pascal_case(),
                cases
                    .iter()
                    .map(|(n, c)| if c.syntax.fields.is_empty() {
                        // The scope prefix is never interpolated when the case has no fields.
                        format!(
                            "Self::{} => {},",
                            n.to_pascal_case(),
                            c.rust_display_impl("todo.")
                        )
                    } else {
                        format!(
                            "Self::{}({}) => {},",
                            n.to_pascal_case(),
                            c.syntax
                                .fields
                                .iter()
                                .enumerate()
                                .map(|(idx, _)| format!("m{idx}"))
                                .format(","),
                            c.rust_display_impl("m")
                        )
                    })
                    .format("\n")
            ),
        }
    }
    /// Emit `parse` plus the `SmtlibParse` impl for this rule or class.
    fn rust_parse(&self, name: &str) -> String {
        match self {
            Syntax::Rule(r) => {
                format!(
                    "impl {} {{
                        pub fn parse(src: &str) -> Result<Self, ParseError> {{
                            SmtlibParse::parse(&mut Parser::new(src))
                        }}
                    }}
                    impl SmtlibParse for {} {{
                        fn is_start_of(offset: usize, p: &mut Parser) -> bool {{
                            {}
                        }}
                        fn parse(p: &mut Parser) -> Result<Self, ParseError> {{
                            {}
                            {}
                        }}
                    }}",
                    name.to_pascal_case(),
                    name.to_pascal_case(),
                    r.rust_start_of_impl(),
                    r.rust_parse_impl(),
                    if r.syntax.fields.is_empty() {
                        "Ok(Self)".to_string()
                    } else {
                        format!(
                            "Ok(Self({}))",
                            r.syntax
                                .fields
                                .iter()
                                .enumerate()
                                .map(|(idx, _)| format!("m{idx}.into()"))
                                .format(", ")
                        )
                    }
                )
            }
            Syntax::Class { cases, .. } => {
                // Cases are tried in descending (priority, number of concrete tokens)
                // order, so more specific alternatives win over more general ones.
                let is_start_of = cases
                    .iter()
                    .sorted_by_key(|(_, c)| {
                        (
                            c.priority,
                            c.syntax.tokens.iter().filter(|t| t.is_concrete()).count(),
                        )
                    })
                    .rev()
                    .map(|(_, c)| format!("({})", c.rust_start_of_check()))
                    .format(" || ");
                let parse = cases
                    .iter()
                    .sorted_by_key(|(_, c)| {
                        (
                            c.priority,
                            c.syntax.tokens.iter().filter(|t| t.is_concrete()).count(),
                        )
                    })
                    .rev()
                    .map(|(n, c)| {
                        let construct = rust_parse_construct_variant("self", n, &c.syntax);
                        format!(
                            "if {} {{ {}\n return Ok({construct}); }}",
                            c.rust_start_of_check(),
                            c.rust_parse_impl(),
                        )
                    })
                    .format("\n");
                format!(
                    "impl {} {{
                        pub fn parse(src: &str) -> Result<Self, ParseError> {{
                            SmtlibParse::parse(&mut Parser::new(src))
                        }}
                    }}
                    impl SmtlibParse for {} {{
                        fn is_start_of(offset: usize, p: &mut Parser) -> bool {{
                            {is_start_of}
                        }}
                        fn parse(p: &mut Parser) -> Result<Self, ParseError> {{
                            let offset = 0;
                            {parse}
                            Err(p.stuck({name:?}))
                        }}
                    }}",
                    name.to_pascal_case(),
                    name.to_pascal_case(),
                )
            }
        }
    }
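
    // The generated class-level `parse` is a chain of
    //     if <case start-of check> { <token statements> return Ok(Self::Case(..)); }
    // blocks, one per case in the order described above, falling through to
    // `Err(p.stuck(..))` when nothing matched.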
    /// For classes that declare a `response` type (e.g. commands), emit
    /// `has_response`/`parse_response` helpers on the generated enum.
    fn rust_response(&self, name: &str) -> String {
        match self {
            Syntax::Rule(_) | Syntax::Class { response: None, .. } => "".to_string(),
            Syntax::Class {
                cases,
                response: Some(response),
            } => {
                let has_response = cases
                    .iter()
                    .map(|(n, c)| {
                        format!(
                            "Self::{}{} => {},",
                            n.to_pascal_case(),
                            if c.syntax.fields.is_empty() {
                                "".to_string()
                            } else {
                                format!("({})", c.syntax.fields.iter().map(|_| "_").format(", "))
                            },
                            c.response.is_some(),
                        )
                    })
                    .format("\n");
                let parse_response = cases
                    .iter()
                    .map(|(n, c)| {
                        format!(
                            "Self::{}{} => {},",
                            n.to_pascal_case(),
                            if c.syntax.fields.is_empty() {
                                "".to_string()
                            } else {
                                format!("({})", c.syntax.fields.iter().map(|_| "_").format(", "))
                            },
                            if let Some(res) = &c.response {
                                // Only field responses are supported; the response token
                                // names the type the solver's answer is parsed into.
                                let res_ty = match res {
                                    Token::LParen
                                    | Token::RParen
                                    | Token::Underscore
                                    | Token::Annotation
                                    | Token::Builtin(_)
                                    | Token::Reserved(_)
                                    | Token::Keyword(_) => todo!(),
                                    Token::Field(_, f) => match f {
                                        Field::One(t)
                                        | Field::Any(t)
                                        | Field::NonZero(t)
                                        | Field::NPlusOne(t) => t.to_string(),
                                    },
                                };
                                format!(
                                    "Ok(Some({}::{}({}::parse(response)?)))",
                                    response.to_pascal_case(),
                                    res_ty.to_pascal_case(),
                                    res_ty.to_pascal_case(),
                                )
                            } else {
                                "Ok(None)".to_string()
                            },
                        )
                    })
                    .format("\n");

                format!(
                    "
                    impl {} {{
                        pub fn has_response(&self) -> bool {{
                            match self {{
                                {}
                            }}
                        }}
                        pub fn parse_response(&self, response: &str) -> Result<std::option::Option<{}>, ParseError> {{
                            match self {{
                                {}
                            }}
                        }}
                    }}
                    ",
                    name.to_pascal_case(),
                    has_response,
                    response.to_pascal_case(),
                    parse_response,
                )
            }
        }
    }
}

impl Rule {
    /// Body of the generated `Display::fmt`: a single `write!` whose format string
    /// mirrors the concrete syntax and whose arguments are the rule's fields.
    fn rust_display_impl(&self, scope: &str) -> String {
        format!(
            r#"write!(f, "{}" {})"#,
            self.syntax
                .tokens
                .iter()
                .fold("".to_string(), |mut acc, t| {
                    use Token::*;
                    if !acc.ends_with('(') && !acc.is_empty() {
                        acc += match t {
                            RParen => "",
                            _ => " ",
                        };
                    }
                    acc += match t {
                        LParen => "(",
                        RParen => ")",
                        Underscore => "_",
                        Annotation => "!",
                        Builtin(s) => s,
                        Reserved(s) => s,
                        Keyword(k) => k,
                        Field(_, _) => "{}",
                    };
                    acc
                }),
            self.syntax
                .fields
                .iter()
                .enumerate()
                .map(|(idx, f)| match f {
                    Field::One(_) => {
                        format!(", {scope}{idx}")
                    }
                    Field::Any(_) | Field::NonZero(_) | Field::NPlusOne(_) => {
                        format!(
                            r#", {scope}{idx}.iter().format({:?})"#,
                            self.separator.as_deref().unwrap_or(" ")
                        )
                    }
                })
                .format("")
        )
    }
    fn rust_ty_decl_child(&self, name: &str, inside_of: HashSet<String>) -> String {
        if self.syntax.fields.is_empty() {
            format!("/// `{}`\n{}", self.syntax, name.to_pascal_case())
        } else {
            format!(
                "/// `{}`\n{}({})",
                self.syntax,
                name.to_pascal_case(),
                self.syntax.tuple_fields(&inside_of).format(",")
            )
        }
    }
    /// The token checks used to decide whether this rule can start at `offset`.
    fn rust_start_of_check(&self) -> String {
        let is_all_variable = !self.syntax.tokens.iter().any(|t| t.is_concrete());

        if is_all_variable {
            self.syntax
                .tokens
                .iter()
                .enumerate()
                .map(|(idx, t)| rust_check_token(idx, t))
                .format(" && ")
                .to_string()
        } else if !self.syntax.tokens[0].is_concrete() {
            let q = rust_check_token(0, &self.syntax.tokens[0]);
            assert!(!q.is_empty());
            q
        } else {
            self.syntax
                .tokens
                .iter()
                .take_while(|t| t.is_concrete())
                .enumerate()
                .map(|(idx, t)| rust_check_token(idx, t))
                .format(" && ")
                .to_string()
        }
    }
    fn rust_start_of_impl(&self) -> String {
        self.rust_start_of_check()
    }
    fn rust_parse_impl(&self) -> String {
        let stmts = self.syntax.tokens.iter().map(rust_parse_token);
        stmts.format("\n").to_string()
    }
}
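
// Continuing the illustrative "( assert <term> )" rule: `rust_display_impl("self.")`
// yields `write!(f, "(assert {})" , self.0)`, and `rust_parse_impl` yields one
// statement per token (see `rust_parse_token` below).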

fn rust_parse_construct_variant(suffix: &str, name: &str, syntax: &Grammar) -> String {
    format!(
        "{}::{}{}",
        suffix.to_pascal_case(),
        name.to_pascal_case(),
        if syntax.fields.is_empty() {
            "".to_string()
        } else {
            format!(
                "({})",
                syntax
                    .fields
                    .iter()
                    .enumerate()
                    .map(|(idx, _)| format!("m{idx}.into()"))
                    .format(", ")
            )
        }
    )
}

fn rust_parse_token(t: &Token) -> String {
    match t {
        Token::LParen => "p.expect(Token::LParen)?;".to_string(),
        Token::RParen => "p.expect(Token::RParen)?;".to_string(),
        Token::Underscore => "p.expect_matches(Token::Reserved, \"_\")?;".to_string(),
        Token::Annotation => "p.expect_matches(Token::Reserved, \"!\")?;".to_string(),
        Token::Builtin(b) => format!("p.expect_matches(Token::Symbol, {b:?})?;"),
        Token::Reserved(b) => format!("p.expect_matches(Token::Reserved, {b:?})?;"),
        Token::Keyword(kw) => format!("p.expect_matches(Token::Keyword, {kw:?})?;"),
        Token::Field(idx, f) => match f {
            Field::One(t) => format!(
                "let m{idx} = <{} as SmtlibParse>::parse(p)?;",
                t.to_pascal_case()
            ),
            Field::Any(t) => format!("let m{idx} = p.any::<{}>()?;", t.to_pascal_case()),
            Field::NonZero(t) => {
                format!("let m{idx} = p.non_zero::<{}>()?;", t.to_pascal_case())
            }
            Field::NPlusOne(t) => {
                format!("let m{idx} = p.n_plus_one::<{}>()?;", t.to_pascal_case())
            }
        },
    }
}

fn rust_check_token(idx: usize, t: &Token) -> String {
    match t {
        Token::LParen => format!("p.nth(offset + {idx}) == Token::LParen"),
        Token::RParen => format!("p.nth(offset + {idx}) == Token::RParen"),
        Token::Underscore => format!("p.nth_matches(offset + {idx}, Token::Reserved, \"_\")"),
        Token::Annotation => format!("p.nth_matches(offset + {idx}, Token::Reserved, \"!\")"),
        Token::Builtin(b) => format!("p.nth_matches(offset + {idx}, Token::Symbol, {b:?})"),
        Token::Reserved(b) => format!("p.nth_matches(offset + {idx}, Token::Reserved, {b:?})"),
        Token::Keyword(kw) => {
            format!("p.nth_matches(offset + {idx}, Token::Keyword, {kw:?})")
        }
        Token::Field(_, f) => match f {
            Field::One(t) | Field::NonZero(t) | Field::NPlusOne(t) => {
                format!("{}::is_start_of(offset + {idx}, p)", t.to_pascal_case())
            }
            Field::Any(_) => "todo!()".to_string(),
        },
    }
}
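
// Examples of emitted snippets (token values are illustrative):
//     rust_parse_token(&Token::Reserved("assert".into()))
//         => p.expect_matches(Token::Reserved, "assert")?;
//     rust_parse_token(&Token::Field(0, Field::One("term".into())))
//         => let m0 = <Term as SmtlibParse>::parse(p)?;
//     rust_check_token(1, &Token::Reserved("assert".into()))
//         => p.nth_matches(offset + 1, Token::Reserved, "assert")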

impl Grammar {
    /// The Rust types of the generated tuple fields: recursive references
    /// (types listed in `inside_of`) are boxed, repeated fields become `Vec`s.
    fn tuple_fields<'a>(
        &'a self,
        inside_of: &'a HashSet<String>,
    ) -> impl Iterator<Item = String> + 'a {
        self.fields.iter().map(|f| match &f {
            Field::One(t) => {
                if inside_of.contains(t) {
                    format!("Box<{}>", t.to_pascal_case())
                } else {
                    t.to_pascal_case()
                }
            }
            Field::Any(t) | Field::NonZero(t) | Field::NPlusOne(t) => {
                format!("Vec<{}>", t.to_pascal_case())
            }
        })
    }
}
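
// For example (an illustrative grammar): inside the `term` class, the fields of
// "( let ( <var_binding>+ ) <term> )" become `Vec<VarBinding>` and `Box<Term>`.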

/// Generate the AST types, `Display` impls, parsers, and response helpers for every
/// entry in `spec.toml`, and write the pretty-printed result to `f`.
pub fn generate(mut f: impl std::io::Write) -> miette::Result<()> {
    use std::fmt::Write;

    let mut buf = String::new();

    let raw: RawSpec = toml::from_str(include_str!("./spec.toml")).into_diagnostic()?;
    let spec = raw.parse();

    writeln!(buf, "// This file is autogenerated! DO NOT EDIT!\n").into_diagnostic()?;
    writeln!(buf, "use crate::parse::{{Token, Parser, ParseError}};").into_diagnostic()?;
    writeln!(buf, "use itertools::Itertools; use crate::lexicon::*;\n").into_diagnostic()?;

    for (name, s) in &spec.general {
        writeln!(buf, "{}", s.rust_ty_decl_top(name)).into_diagnostic()?;
        writeln!(buf, "{}", s.rust_display(name)).into_diagnostic()?;
        writeln!(buf, "{}", s.rust_parse(name)).into_diagnostic()?;
        writeln!(buf, "{}", s.rust_response(name)).into_diagnostic()?;
    }

    // Round-trip through syn/prettyplease to validate and format the generated code.
    let file = syn::parse_file(&buf).into_diagnostic()?;
    let pretty = prettyplease::unparse(&file);

    f.write_all(pretty.as_bytes()).into_diagnostic()?;

    Ok(())
}
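
#[cfg(test)]
mod tests {
    // A minimal smoke-test sketch: `generate` accepts any `std::io::Write`, so an
    // in-memory buffer is enough to exercise codegen and pretty-printing end to end
    // (this assumes `spec.toml` sits next to this file, as `include_str!` requires).
    #[test]
    fn generate_produces_output() {
        let mut buf = Vec::new();
        super::generate(&mut buf).expect("codegen should succeed");
        let src = String::from_utf8(buf).expect("generated code is valid UTF-8");
        assert!(!src.is_empty());
    }
}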