scheme_rs/
expand.rs

1use crate::{
2    ast::Literal,
3    continuation::Continuation,
4    error::RuntimeError,
5    gc::{Gc, Trace},
6    syntax::{Identifier, Span, Syntax},
7    value::Value,
8};
9use scheme_rs_macros::builtin;
10use std::{
11    collections::{HashMap, HashSet},
12    sync::Arc,
13};
14
15#[derive(Clone, Trace)]
16pub struct Transformer {
17    pub rules: Vec<SyntaxRule>,
18    pub is_variable_transformer: bool,
19}
20
21impl Transformer {
22    pub fn expand(&self, expr: &Syntax) -> Option<Syntax> {
23        for rule in &self.rules {
24            if let expansion @ Some(_) = rule.expand(expr) {
25                return expansion;
26            }
27        }
28        None
29    }
30}
31
32#[derive(Clone, Debug, Trace)]
33pub struct SyntaxRule {
34    pub pattern: Pattern,
35    pub template: Template,
36}
37
38impl SyntaxRule {
39    pub fn compile(keywords: &HashSet<String>, pattern: &Syntax, template: &Syntax) -> Self {
40        let mut variables = HashSet::new();
41        let pattern = Pattern::compile(pattern, keywords, &mut variables);
42        let template = Template::compile(template, &variables);
43        Self { pattern, template }
44    }
45
46    fn expand(&self, expr: &Syntax) -> Option<Syntax> {
47        let mut top_expansion_level = ExpansionLevel::default();
48        let curr_span = expr.span().clone();
49        self.pattern
50            .matches(expr, &mut top_expansion_level)
51            .then(|| {
52                let binds = Binds::new_top(&top_expansion_level);
53                self.template.execute(&binds, curr_span).unwrap()
54            })
55    }
56}
57
58#[derive(Clone, Debug, Trace)]
59pub enum Pattern {
60    Null,
61    Underscore,
62    Ellipsis(Box<Pattern>),
63    List(Vec<Pattern>),
64    Vector(Vec<Pattern>),
65    Variable(String),
66    Keyword(String),
67    Literal(Literal),
68}
69
70impl Pattern {
71    pub fn compile(
72        expr: &Syntax,
73        keywords: &HashSet<String>,
74        variables: &mut HashSet<String>,
75    ) -> Self {
76        match expr {
77            Syntax::Null { .. } => Self::Null,
78            Syntax::Identifier { ident, .. } if ident.name == "_" => Self::Underscore,
79            Syntax::Identifier { ident, .. } if keywords.contains(&ident.name) => {
80                Self::Keyword(ident.name.clone())
81            }
82            Syntax::Identifier { ident, .. } => {
83                variables.insert(ident.name.clone());
84                Self::Variable(ident.name.clone())
85            }
86            Syntax::List { list, .. } => Self::List(Self::compile_slice(list, keywords, variables)),
87            Syntax::Vector { vector, .. } => {
88                Self::Vector(Self::compile_slice(vector, keywords, variables))
89            }
90            Syntax::Literal { literal, .. } => Self::Literal(literal.clone()),
91        }
92    }
93
94    fn compile_slice(
95        mut expr: &[Syntax],
96        keywords: &HashSet<String>,
97        variables: &mut HashSet<String>,
98    ) -> Vec<Self> {
99        let mut output = Vec::new();
100        loop {
101            match expr {
102                [] => break,
103                [pattern, Syntax::Identifier {
104                    ident: ellipsis, ..
105                }, tail @ ..]
106                    if ellipsis.name == "..." =>
107                {
108                    output.push(Self::Ellipsis(Box::new(Pattern::compile(
109                        pattern, keywords, variables,
110                    ))));
111                    expr = tail;
112                }
113                [head, tail @ ..] => {
114                    output.push(Self::compile(head, keywords, variables));
115                    expr = tail;
116                }
117            }
118        }
119        output
120    }
121
122    fn matches(&self, expr: &Syntax, expansion_level: &mut ExpansionLevel) -> bool {
123        match self {
124            Self::Underscore => !expr.is_null(),
125            Self::Variable(ref name) => {
126                assert!(expansion_level
127                    .binds
128                    .insert(name.clone(), expr.clone())
129                    .is_none());
130                true
131            }
132            Self::Literal(ref lhs) => {
133                if let Syntax::Literal { literal: rhs, .. } = expr {
134                    lhs == rhs
135                } else {
136                    false
137                }
138            }
139            Self::Keyword(ref lhs) => {
140                matches!(expr, Syntax::Identifier { ident: rhs, bound: false, .. } if lhs == &rhs.name)
141            }
142            Self::List(list) => match_slice(list, expr, expansion_level),
143            Self::Vector(vec) => match_slice(vec, expr, expansion_level),
144            // We shouldn't ever see this outside of lists
145            Self::Null => expr.is_null(),
146            Self::Ellipsis(_) => unreachable!(),
147        }
148    }
149}
150
151fn match_ellipsis(
152    patterns: &[Pattern],
153    exprs: &[Syntax],
154    expansion_level: &mut ExpansionLevel,
155) -> bool {
156    // The ellipsis gets to consume any extra items, thus the difference:
157    let Some(extra_items) = (exprs.len() + 1).checked_sub(patterns.len()) else {
158        return false;
159    };
160
161    let mut expr_iter = exprs.iter();
162    for pattern in patterns.iter() {
163        if let Pattern::Ellipsis(muncher) = pattern {
164            // Gobble up the extra items:
165            for i in 0..extra_items {
166                if expansion_level.expansions.len() <= i {
167                    expansion_level.expansions.push(ExpansionLevel::default());
168                }
169                let expr = expr_iter.next().unwrap();
170                if !muncher.matches(expr, &mut expansion_level.expansions[i]) {
171                    return false;
172                }
173            }
174        } else {
175            // Otherwise, match the pattern normally
176            let expr = expr_iter.next().unwrap();
177            if !pattern.matches(expr, expansion_level) {
178                return false;
179            }
180        }
181    }
182
183    assert!(expr_iter.next().is_none());
184
185    true
186}
187
188fn match_slice(patterns: &[Pattern], expr: &Syntax, expansion_level: &mut ExpansionLevel) -> bool {
189    assert!(!patterns.is_empty());
190
191    let exprs = match expr {
192        Syntax::List { list, .. } => list,
193        Syntax::Null { .. } => return true,
194        _ => return false,
195    };
196
197    let contains_ellipsis = patterns.iter().any(|p| matches!(p, Pattern::Ellipsis(_)));
198
199    match (patterns.split_last().unwrap(), contains_ellipsis) {
200        ((Pattern::Null, _), false) => {
201            // Proper list, no ellipsis. Match everything in order
202            for (pattern, expr) in patterns.iter().zip(exprs.iter()) {
203                if !pattern.matches(expr, expansion_level) {
204                    return false;
205                }
206            }
207            true
208        }
209        ((cdr, head), false) => {
210            // The pattern is an improper list that contains no ellipsis.
211            // Math in order until the last pattern, then match that to the nth
212            // cdr.
213            let mut exprs = exprs.iter();
214            for pattern in head.iter() {
215                let Some(expr) = exprs.next() else {
216                    continue;
217                };
218                if !pattern.matches(expr, expansion_level) {
219                    return false;
220                }
221            }
222            // Match the cdr:
223            let exprs: Vec<_> = exprs.cloned().collect();
224            match exprs.as_slice() {
225                [] => false,
226                [x] => cdr.matches(x, expansion_level),
227                _ => cdr.matches(
228                    &Syntax::new_list(exprs, expr.span().clone()),
229                    expansion_level,
230                ),
231            }
232        }
233        (_, true) => match_ellipsis(patterns, exprs, expansion_level),
234    }
235}
236
237#[derive(Debug, Default)]
238pub struct ExpansionLevel {
239    binds: HashMap<String, Syntax>,
240    expansions: Vec<ExpansionLevel>,
241}
242
243#[derive(Clone, Debug, Trace)]
244pub enum Template {
245    Null,
246    Ellipsis(Box<Template>),
247    List(Vec<Template>),
248    Vector(Vec<Template>),
249    Identifier(Identifier),
250    Variable(Identifier),
251    Literal(Literal),
252}
253
254impl Template {
255    pub fn compile(expr: &Syntax, variables: &HashSet<String>) -> Self {
256        match expr {
257            Syntax::Null { .. } => Self::Null,
258            Syntax::List { list, .. } => Self::List(Self::compile_slice(list, variables)),
259            Syntax::Vector { vector, .. } => Self::Vector(Self::compile_slice(vector, variables)),
260            Syntax::Literal { literal, .. } => Self::Literal(literal.clone()),
261            Syntax::Identifier { ident, .. } if variables.contains(&ident.name) => {
262                Self::Variable(ident.clone())
263            }
264            Syntax::Identifier { ident, .. } => Self::Identifier(ident.clone()),
265        }
266    }
267
268    fn compile_slice(mut expr: &[Syntax], variables: &HashSet<String>) -> Vec<Self> {
269        let mut output = Vec::new();
270        loop {
271            match expr {
272                [] => break,
273                [template, Syntax::Identifier {
274                    ident: ellipsis, ..
275                }, tail @ ..]
276                    if ellipsis.name == "..." =>
277                {
278                    output.push(Self::Ellipsis(Box::new(Template::compile(
279                        template, variables,
280                    ))));
281                    expr = tail;
282                }
283                [head, tail @ ..] => {
284                    output.push(Self::compile(head, variables));
285                    expr = tail;
286                }
287            }
288        }
289        output
290    }
291
292    fn execute(&self, binds: &Binds<'_>, curr_span: Span) -> Option<Syntax> {
293        let syn = match self {
294            Self::Null => Syntax::new_null(curr_span),
295            Self::List(list) => {
296                let executed = execute_slice(list, binds, curr_span.clone())?;
297                Syntax::new_list(executed, curr_span).normalize()
298            }
299            Self::Vector(vec) => {
300                Syntax::new_vector(execute_slice(vec, binds, curr_span.clone())?, curr_span)
301            }
302            Self::Identifier(ident) => Syntax::Identifier {
303                ident: ident.clone(),
304                span: curr_span,
305                bound: false,
306            },
307            Self::Variable(ident) => binds.get_bind(&ident.name)?,
308            Self::Literal(literal) => Syntax::new_literal(literal.clone(), curr_span),
309            _ => unreachable!(),
310        };
311        Some(syn)
312    }
313}
314
315fn execute_slice(items: &[Template], binds: &Binds<'_>, curr_span: Span) -> Option<Vec<Syntax>> {
316    let mut output = Vec::new();
317    for item in items {
318        match item {
319            Template::Ellipsis(template) => {
320                for expansion in &binds.curr_expansion_level.expansions {
321                    let new_level = binds.new_level(expansion);
322                    let Some(result) = template.execute(&new_level, curr_span.clone()) else {
323                        break;
324                    };
325                    output.push(result);
326                }
327            }
328            Template::Null => {
329                if let Some(Syntax::Null { .. }) = output.last() {
330                    continue;
331                } else {
332                    output.push(Syntax::new_null(curr_span.clone()));
333                }
334            }
335            _ => output.push(item.execute(binds, curr_span.clone())?),
336        }
337    }
338    Some(output)
339}
340
341pub struct Binds<'a> {
342    curr_expansion_level: &'a ExpansionLevel,
343    parent_expansion_level: Option<&'a Binds<'a>>,
344}
345
346impl<'a> Binds<'a> {
347    fn new_top(top_expansion_level: &'a ExpansionLevel) -> Self {
348        Self {
349            curr_expansion_level: top_expansion_level,
350            parent_expansion_level: None,
351        }
352    }
353
354    fn new_level<'b: 'a>(&'b self, next_expansion_level: &'b ExpansionLevel) -> Binds<'b> {
355        Binds {
356            curr_expansion_level: next_expansion_level,
357            parent_expansion_level: Some(self),
358        }
359    }
360
361    fn get_bind(&self, name: &str) -> Option<Syntax> {
362        if let bind @ Some(_) = self.curr_expansion_level.binds.get(name) {
363            bind.cloned()
364        } else if let Some(up) = self.parent_expansion_level {
365            up.get_bind(name)
366        } else {
367            None
368        }
369    }
370}
371
372#[builtin("make-variable-transformer")]
373pub async fn make_variable_transformer(
374    _cont: &Option<Arc<Continuation>>,
375    proc: &Gc<Value>,
376) -> Result<Vec<Gc<Value>>, RuntimeError> {
377    let proc = proc.read().await;
378    match &*proc {
379        Value::Procedure(proc) => {
380            let mut proc = proc.clone();
381            proc.is_variable_transformer = true;
382            Ok(vec![Gc::new(Value::Procedure(proc))])
383        }
384        Value::Transformer(transformer) => {
385            let mut transformer = transformer.clone();
386            transformer.is_variable_transformer = true;
387            Ok(vec![Gc::new(Value::Transformer(transformer))])
388        }
389        _ => todo!(),
390    }
391}