texlang_stdlib/
conditional.rs

1//! Control flow primitives (if, else, switch)
2//!
3
4use std::cell::RefCell;
5use std::cmp::Ordering;
6use texlang::token::trace;
7use texlang::traits::*;
8use texlang::*;
9
// Documentation strings attached to the conditional primitives, grouped by role.

// Predicate-style conditionals.

/// Documentation for `\iftrue`, which always expands its true branch.
pub const IFTRUE_DOC: &str = "Evaluate the true branch";
/// Documentation for `\iffalse`, which always expands its false branch.
pub const IFFALSE_DOC: &str = "Evaluate the false branch";
/// Documentation for `\ifnum`.
pub const IFNUM_DOC: &str = "Compare two variables";
/// Documentation for `\ifodd`.
pub const IFODD_DOC: &str = "Check if a variable is odd";

// Switch statements.

/// Documentation for `\ifcase`.
pub const IFCASE_DOC: &str = "Begin a switch statement";
/// Documentation for `\or`.
pub const OR_DOC: &str = "Begin the next branch of a switch statement";

// Branch delimiters shared by all conditionals.

/// Documentation for `\else`.
pub const ELSE_DOC: &str = "Start the else branch of a conditional or switch statement";
/// Documentation for `\fi`.
pub const FI_DOC: &str = "End a conditional or switch statement";
18
/// A component for keeping track of conditional branches as they are expanded.
///
/// Every conditional primitive in this module requires its state to implement
/// `HasComponent<Component>`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Component {
    // Branches is a stack where each element corresponds to a conditional that is currently
    // expanding. A nested conditional is further up the stack than the conditional it is
    // nested in.
    //
    // This stack is used to verify that \else, \or and \fi tokens are valid; e.g., if a \else
    // is encountered, the innermost conditional must currently be expanding its true branch
    // (or a selected \ifcase case), otherwise the \else is invalid. For correct TeX code the
    // stack is only pushed and popped; its contents matter only when detecting invalid code.
    //
    // Because the conditional commands are expansion commands, they cannot get a mutable reference
    // to the state. We thus wrap the branches in a ref cell to support mutating them through
    // an immutable reference.
    #[cfg_attr(
        feature = "serde",
        serde(
            serialize_with = "serialize_branches",
            deserialize_with = "deserialize_branches"
        )
    )]
    branches: RefCell<Vec<Branch>>,

    // We cache the tag values inside the component for performance reasons.
    // The tags are not serialized; on deserialization they are rebuilt with `Tags::default()`.
    #[cfg_attr(feature = "serde", serde(skip))]
    tags: Tags,
}
47
/// Serializes the branch stack as a sequence of branches.
///
/// The `RefCell` is borrowed for the duration of serialization; this panics if the stack is
/// currently mutably borrowed.
#[cfg(feature = "serde")]
fn serialize_branches<S>(input: &RefCell<Vec<Branch>>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let guard = input.borrow();
    serde::Serialize::serialize(guard.as_slice(), serializer)
}
57
/// Deserializes a sequence of branches and wraps it in a fresh `RefCell`.
#[cfg(feature = "serde")]
fn deserialize_branches<'de, D>(deserializer: D) -> Result<RefCell<Vec<Branch>>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    <Vec<Branch> as serde::Deserialize<'de>>::deserialize(deserializer).map(RefCell::new)
}
67
/// Cached copies of the command tags used to recognize conditional primitives while
/// skipping over unexpanded tokens.
struct Tags {
    // Tag attached to every `if`-style primitive, including `\ifcase`.
    if_tag: command::Tag,
    // Tag attached to `\else`.
    else_tag: command::Tag,
    // Tag attached to `\or`.
    or_tag: command::Tag,
    // Tag attached to `\fi`.
    fi_tag: command::Tag,
}
74
75impl Default for Tags {
76    fn default() -> Self {
77        Self {
78            if_tag: IF_TAG.get(),
79            else_tag: ELSE_TAG.get(),
80            or_tag: OR_TAG.get(),
81            fi_tag: FI_TAG.get(),
82        }
83    }
84}
85
/// The kind of branch that a conditional on the branch stack is currently expanding.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum BranchKind {
    // The true branch of an if conditional.
    True,
    // The false branch of an if conditional, or the default branch of a switch statement.
    Else,
    // A regular case branch of a switch statement.
    Switch,
}
96
/// An entry on the branch stack: a conditional that is currently being expanded.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Branch {
    // The token of the conditional command (e.g. an `if` command or `\ifcase`) that opened
    // this branch. It is never read directly in this file; it appears in `Debug` output
    // (used in error notes) and in serialized state.
    _token: token::Token,
    // Which branch of the conditional is being expanded.
    kind: BranchKind,
}
103
104impl Component {
105    pub fn new() -> Component {
106        Component {
107            branches: RefCell::new(Vec::new()),
108            tags: Default::default(),
109        }
110    }
111}
112
113impl Default for Component {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119fn push_branch<S: HasComponent<Component>>(input: &mut vm::ExpansionInput<S>, branch: Branch) {
120    input.state().component().branches.borrow_mut().push(branch)
121}
122
123fn pop_branch<S: HasComponent<Component>>(input: &mut vm::ExpansionInput<S>) -> Option<Branch> {
124    input.state().component().branches.borrow_mut().pop()
125}
126
// Tags attached to the conditional primitives by the `get_*` functions in this file. While
// skipping tokens, the conditionals look up the tag of each control sequence with
// `commands_map().get_tag` and compare it against these (via the copies cached in `Tags`).
static IF_TAG: command::StaticTag = command::StaticTag::new();
static ELSE_TAG: command::StaticTag = command::StaticTag::new();
static OR_TAG: command::StaticTag = command::StaticTag::new();
static FI_TAG: command::StaticTag = command::StaticTag::new();
131
132// The `true_case` function is executed whenever a conditional evaluates to true.
133fn true_case<S: HasComponent<Component>>(
134    token: token::Token,
135    input: &mut vm::ExpansionInput<S>,
136) -> Result<Vec<token::Token>, Box<error::Error>> {
137    push_branch(
138        input,
139        Branch {
140            _token: token,
141            kind: BranchKind::True,
142        },
143    );
144    Ok(Vec::new())
145}
146
147// The `false_case` function is executed whenever a conditional evaluates to false.
148//
149// The function scans forward in the input stream, discarding all tokens, until it encounters
150// either a \else or \fi command.
151fn false_case<S: HasComponent<Component>>(
152    original_token: token::Token,
153    input: &mut vm::ExpansionInput<S>,
154) -> Result<Vec<token::Token>, Box<error::Error>> {
155    let mut depth = 0;
156    while let Some(token) = input.unexpanded().next()? {
157        if let token::Value::ControlSequence(name) = &token.value() {
158            // TODO: use switch
159            let tag = input.commands_map().get_tag(name);
160            if tag == Some(input.state().component().tags.else_tag) && depth == 0 {
161                push_branch(
162                    input,
163                    Branch {
164                        _token: original_token,
165                        kind: BranchKind::Else,
166                    },
167                );
168                return Ok(Vec::new());
169            }
170            if tag == Some(input.state().component().tags.if_tag) {
171                depth += 1;
172            }
173            if tag == Some(input.state().component().tags.fi_tag) {
174                depth -= 1;
175                if depth < 0 {
176                    return Ok(Vec::new());
177                }
178            }
179        }
180    }
181    let branch = pop_branch(input);
182    Err(FalseBranchEndOfInputError {
183        trace: input.vm().trace_end_of_input(),
184        branch,
185    }
186    .into())
187}
188
/// Error returned when the input ends while `false_case` is skipping the true branch of a
/// conditional that evaluated to false.
#[derive(Debug)]
struct FalseBranchEndOfInputError {
    // Trace produced by `trace_end_of_input`.
    trace: trace::SourceCodeTrace,
    // Branch information, printed with `Debug` in the error notes.
    branch: Option<Branch>,
}
194
195impl error::TexError for FalseBranchEndOfInputError {
196    fn kind(&self) -> error::Kind {
197        error::Kind::EndOfInput(&self.trace)
198    }
199
200    fn title(&self) -> String {
201        "unexpected end of input while expanding an `if` command".into()
202    }
203
204    fn notes(&self) -> Vec<error::display::Note> {
205        vec![
206            "each `if` command must be terminated by a `fi` command, with an optional `else` in between".into(),
207            "this `if` command evaluated to false, and the input ended while skipping the true branch".into(),
208            "this is the `if` command involved in the error:".into(),
209            format!["{:?}", self.branch].into(),
210        ]
211    }
212}
213
// Generates two items for a predicate `$if_fn`:
//
// - `$if_primitive_fn`, an expansion primitive that evaluates the predicate and dispatches to
//   `true_case` or `false_case`;
// - `$get_if`, a public getter returning that primitive as a built-in command tagged with
//   `IF_TAG` and documented with `$docs`.
macro_rules! create_if_primitive {
    ($if_fn: ident, $if_primitive_fn: ident, $get_if: ident, $docs: expr) => {
        fn $if_primitive_fn<S: HasComponent<Component>>(
            token: token::Token,
            input: &mut vm::ExpansionInput<S>,
        ) -> Result<Vec<token::Token>, Box<error::Error>> {
            let condition_holds = $if_fn(input)?;
            if condition_holds {
                true_case(token, input)
            } else {
                false_case(token, input)
            }
        }

        pub fn $get_if<S: HasComponent<Component>>() -> command::BuiltIn<S> {
            let primitive: command::BuiltIn<S> = command::BuiltIn::new_expansion($if_primitive_fn);
            primitive.with_tag(IF_TAG.get()).with_doc($docs)
        }
    };
}
233
234fn if_true<S>(_: &mut vm::ExpansionInput<S>) -> Result<bool, Box<error::Error>> {
235    Ok(true)
236}
237
238fn if_false<S>(_: &mut vm::ExpansionInput<S>) -> Result<bool, Box<error::Error>> {
239    Ok(false)
240}
241
242fn if_num<S: TexlangState>(stream: &mut vm::ExpansionInput<S>) -> Result<bool, Box<error::Error>> {
243    let (a, o, b) = <(i32, Ordering, i32)>::parse(stream)?;
244    Ok(a.cmp(&b) == o)
245}
246
247fn if_odd<S: TexlangState>(stream: &mut vm::ExpansionInput<S>) -> Result<bool, Box<error::Error>> {
248    let n = i32::parse(stream)?;
249    Ok((n % 2) == 1)
250}
251
252create_if_primitive![if_true, if_true_primitive_fn, get_if_true, IFTRUE_DOC];
253create_if_primitive![if_false, if_false_primitive_fn, get_if_false, IFFALSE_DOC];
254create_if_primitive![if_num, if_num_primitive_fn, get_if_num, IFNUM_DOC];
255create_if_primitive![if_odd, if_odd_primitive_fn, get_if_odd, IFODD_DOC];
256
257fn if_case_primitive_fn<S: HasComponent<Component>>(
258    ifcase_token: token::Token,
259    input: &mut vm::ExpansionInput<S>,
260) -> Result<Vec<token::Token>, Box<error::Error>> {
261    // TODO: should we reading the number from the unexpanded stream? Probably!
262    let mut cases_to_skip = i32::parse(input)?;
263    if cases_to_skip == 0 {
264        push_branch(
265            input,
266            Branch {
267                _token: ifcase_token,
268                kind: BranchKind::Switch,
269            },
270        );
271        return Ok(Vec::new());
272    }
273    let mut depth = 0;
274    while let Some(token) = input.unexpanded().next()? {
275        if let token::Value::ControlSequence(name) = &token.value() {
276            // TODO: switch
277            let tag = input.commands_map().get_tag(name);
278            if tag == Some(input.state().component().tags.or_tag) && depth == 0 {
279                cases_to_skip -= 1;
280                if cases_to_skip == 0 {
281                    push_branch(
282                        input,
283                        Branch {
284                            _token: ifcase_token,
285                            kind: BranchKind::Switch,
286                        },
287                    );
288                    return Ok(Vec::new());
289                }
290            }
291            if tag == Some(input.state().component().tags.else_tag) && depth == 0 {
292                push_branch(
293                    input,
294                    Branch {
295                        _token: ifcase_token,
296                        kind: BranchKind::Else,
297                    },
298                );
299                return Ok(Vec::new());
300            }
301            if tag == Some(input.state().component().tags.if_tag) {
302                depth += 1;
303            }
304            if tag == Some(input.state().component().tags.fi_tag) {
305                depth -= 1;
306                if depth < 0 {
307                    return Ok(Vec::new());
308                }
309            }
310        }
311    }
312    Err(IfCaseEndOfInputError {
313        trace: input.trace_end_of_input(),
314    }
315    .into())
316}
317
/// Error returned by `\ifcase` when the input ends while cases are being skipped.
#[derive(Debug)]
struct IfCaseEndOfInputError {
    // Trace produced by `trace_end_of_input`.
    trace: trace::SourceCodeTrace,
}
322
323impl error::TexError for IfCaseEndOfInputError {
324    fn kind(&self) -> error::Kind {
325        error::Kind::EndOfInput(&self.trace)
326    }
327
328    fn title(&self) -> String {
329        "unexpected end of input while expanding an `ifcase` command".into()
330    }
331
332    fn notes(&self) -> Vec<error::display::Note> {
333        vec![
334            "each `ifcase` command must be matched by a `or`, `else` or `fi` command".into(),
335            "this `ifcase` case evaluated to %d and we skipped %d cases before the input ran out"
336                .into(),
337            "this is the `ifnum` command involved in the error:".into(),
338        ]
339    }
340}
341
342/// Get the `\ifcase` primitive.
343pub fn get_if_case<S: HasComponent<Component>>() -> command::BuiltIn<S> {
344    command::BuiltIn::new_expansion(if_case_primitive_fn).with_tag(IF_TAG.get())
345}
346
347fn or_primitive_fn<S: HasComponent<Component>>(
348    ifcase_token: token::Token,
349    input: &mut vm::ExpansionInput<S>,
350) -> Result<Vec<token::Token>, Box<error::Error>> {
351    let branch = pop_branch(input);
352    // For an or command to be valid, we must be in a switch statement
353    let is_valid = match branch {
354        None => false,
355        Some(branch) => matches!(branch.kind, BranchKind::Switch),
356    };
357    if !is_valid {
358        return Err(error::SimpleTokenError::new(
359            input.vm(),
360            ifcase_token,
361            "unexpected `or` command",
362        )
363        .into());
364    }
365
366    let mut depth = 0;
367    while let Some(token) = input.unexpanded().next()? {
368        if let token::Value::ControlSequence(name) = &token.value() {
369            let tag = input.commands_map().get_tag(name);
370            if tag == Some(input.state().component().tags.if_tag) {
371                depth += 1;
372            }
373            if tag == Some(input.state().component().tags.fi_tag) {
374                depth -= 1;
375                if depth < 0 {
376                    return Ok(Vec::new());
377                }
378            }
379        }
380    }
381    Err(OrEndOfInputError {
382        trace: input.vm().trace_end_of_input(),
383    }
384    .into())
385}
386
/// Error returned by `\or` when the input ends while the remaining cases of an `\ifcase`
/// are being skipped.
#[derive(Debug)]
struct OrEndOfInputError {
    // Trace produced by `trace_end_of_input`.
    trace: trace::SourceCodeTrace,
}
391
392impl error::TexError for OrEndOfInputError {
393    fn kind(&self) -> error::Kind {
394        error::Kind::EndOfInput(&self.trace)
395    }
396
397    fn title(&self) -> String {
398        "unexpected end of input while expanding an `or` command".into()
399    }
400
401    fn notes(&self) -> Vec<error::display::Note> {
402        vec![
403        "each `or` command must be terminated by a `fi` command".into(),
404        "this `or` corresponds to an `ifcase` command that evaluated to %d, and the input ended while skipping the remaining cases".into(),
405        "this is the `ifcase` command involved in the error:".into(),
406        "this is the `or` command involved in the error:".into(),
407        ]
408    }
409}
410
411/// Get the `\or` primitive.
412pub fn get_or<S: HasComponent<Component>>() -> command::BuiltIn<S> {
413    command::BuiltIn::new_expansion(or_primitive_fn).with_tag(OR_TAG.get())
414}
415
416fn else_primitive_fn<S: HasComponent<Component>>(
417    else_token: token::Token,
418    input: &mut vm::ExpansionInput<S>,
419) -> Result<Vec<token::Token>, Box<error::Error>> {
420    let branch = pop_branch(input);
421    // For else token to be valid, we must be in the true branch of a conditional
422    let is_valid = match branch {
423        None => false,
424        Some(branch) => matches!(branch.kind, BranchKind::True | BranchKind::Switch),
425    };
426    if !is_valid {
427        return Err(error::SimpleTokenError::new(
428            input.vm(),
429            else_token,
430            "unexpected `else` command",
431        )
432        .into());
433    }
434
435    // Now consume all of the tokens until the next \fi
436    let mut depth = 0;
437    while let Some(token) = input.unexpanded().next()? {
438        if let token::Value::ControlSequence(name) = &token.value() {
439            // TODO: switch
440            let tag = input.commands_map().get_tag(name);
441            if tag == Some(input.state().component().tags.if_tag) {
442                depth += 1;
443            }
444            if tag == Some(input.state().component().tags.fi_tag) {
445                depth -= 1;
446                if depth < 0 {
447                    return Ok(Vec::new());
448                }
449            }
450        }
451    }
452    Err(ElseEndOfInputError {
453        trace: input.vm().trace_end_of_input(),
454    }
455    .into())
456}
457
/// Error returned by `\else` when the input ends while skipping to the matching `\fi`.
#[derive(Debug)]
struct ElseEndOfInputError {
    // Trace produced by `trace_end_of_input`.
    trace: trace::SourceCodeTrace,
}
462
463impl error::TexError for ElseEndOfInputError {
464    fn kind(&self) -> error::Kind {
465        error::Kind::EndOfInput(&self.trace)
466    }
467
468    fn title(&self) -> String {
469        "unexpected end of input while expanding an `else` command".into()
470    }
471
472    fn notes(&self) -> Vec<error::display::Note> {
473        vec![
474            "each `else` command must be terminated by a `fi` command".into(),
475            "this `else` corresponds to an `if` command that evaluated to true, and the input ended while skipping the false branch".into(),
476            "this is the `if` command involved in the error:".into(),
477            "this is the `else` command involved in the error:".into(),
478        ]
479    }
480}
481
482/// Get the `\else` primitive.
483pub fn get_else<S: HasComponent<Component>>() -> command::BuiltIn<S> {
484    command::BuiltIn::new_expansion(else_primitive_fn).with_tag(ELSE_TAG.get())
485}
486
487/// Get the `\fi` primitive.
488fn fi_primitive_fn<S: HasComponent<Component>>(
489    token: token::Token,
490    input: &mut vm::ExpansionInput<S>,
491) -> Result<Vec<token::Token>, Box<error::Error>> {
492    let branch = pop_branch(input);
493    // For a \fi primitive to be valid, we must be in a conditional.
494    // Note that we could be in the false branch: \iftrue\else\fi
495    // Or in the true branch: \iftrue\fi
496    // Or in a switch statement.
497    if branch.is_none() {
498        return Err(
499            error::SimpleTokenError::new(input.vm(), token, "unexpected `fi` command").into(),
500        );
501    }
502    Ok(Vec::new())
503}
504
505pub fn get_fi<S: HasComponent<Component>>() -> command::BuiltIn<S> {
506    command::BuiltIn::new_expansion(fi_primitive_fn).with_tag(FI_TAG.get())
507}
508
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::{script, testing::*};
    use texlang::vm::implement_has_component;

    #[derive(Default)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    struct State {
        conditional: Component,
        exec: script::Component,
    }

    impl TexlangState for State {}

    implement_has_component![State, (Component, conditional), (script::Component, exec),];

    fn initial_commands() -> HashMap<&'static str, command::BuiltIn<State>> {
        HashMap::from([
            ("else", get_else()),
            ("fi", get_fi()),
            ("ifcase", get_if_case()),
            ("iffalse", get_if_false()),
            ("ifnum", get_if_num()),
            ("ifodd", get_if_odd()),
            ("iftrue", get_if_true()),
            ("or", get_or()),
        ])
    }

    test_suite![
        expansion_equality_tests(
            (iftrue_base_case, r"\iftrue a\else b\fi c", r"ac"),
            (iftrue_no_else, r"\iftrue a\fi c", r"ac"),
            (
                iftrue_skip_nested_ifs,
                r"\iftrue a\else b\iftrue \else c\fi d\fi e",
                r"ae"
            ),
            (iffalse_base_case, r"\iffalse a\else b\fi c", r"bc"),
            (iffalse_no_else, r"\iffalse a\fi c", r"c"),
            (
                iffalse_skip_nested_ifs,
                r"\iffalse \iftrue a\else b\fi c\else d\fi e",
                r"de"
            ),
            // A `\fi` expanded inside the false branch closes the conditional.
            (iffalse_empty_else, r"\iffalse a\else\fi b", r"b"),
            (
                iffalse_and_iftrue_1,
                r"\iffalse a\else b\iftrue c\else d\fi e\fi f",
                r"bcef"
            ),
            (
                iffalse_and_iftrue_2,
                r"\iftrue a\iffalse b\else c\fi d\else e\fi f",
                r"acdf"
            ),
            (ifnum_less_than_true, r"\ifnum 4<5a\else b\fi c", r"ac"),
            (ifnum_less_than_false, r"\ifnum 5<4a\else b\fi c", r"bc"),
            (ifnum_equal_true, r"\ifnum 4=4a\else b\fi c", r"ac"),
            (ifnum_equal_false, r"\ifnum 5=4a\else b\fi c", r"bc"),
            (ifnum_greater_than_true, r"\ifnum 5>4a\else b\fi c", r"ac"),
            (ifnum_greater_than_false, r"\ifnum 4>5a\else b\fi c", r"bc"),
            (ifodd_odd, r"\ifodd 3a\else b\fi c", r"ac"),
            (ifodd_even, r"\ifodd 4a\else b\fi c", r"bc"),
            (ifodd_negative_even, r"\ifodd -4a\else b\fi c", r"bc"),
            (ifcase_zero_no_ors, r"\ifcase 0 a\else b\fi c", r"ac"),
            (ifcase_zero_one_or, r"\ifcase 0 a\or b\else c\fi d", r"ad"),
            (ifcase_one, r"\ifcase 1 a\or b\else c\fi d", r"bd"),
            (
                ifcase_one_more_cases,
                r"\ifcase 1 a\or b\or c\else d\fi e",
                r"be"
            ),
            (ifcase_else_no_ors, r"\ifcase 1 a\else b\fi c", r"bc"),
            (ifcase_else_one_or, r"\ifcase 2 a\or b\else c\fi d", r"cd"),
            (ifcase_no_matching_case, r"\ifcase 3 a\or b\or c\fi d", r"d"),
            // Negative case numbers never match any case.
            (ifcase_negative_else, r"\ifcase -1 a\or b\else c\fi d", r"cd"),
            (ifcase_negative_no_else, r"\ifcase -2 a\or b\or c\fi d", r"d"),
            (
                ifcase_nested,
                r"\ifcase 1 a\or b\ifcase 1 c\or d\or e\else f\fi g\or h\fi i",
                r"bdgi"
            ),
        ),
        serde_tests(
            (serde_if, r"\iftrue true ", r"branch \else false branch \fi"),
            (
                serde_ifcase,
                r"\ifcase 2 a\or b\or executed ",
                r"case \or d \fi"
            )
        ),
        failure_tests(
            (iftrue_end_of_input, r"\iftrue a\else b"),
            (iffalse_end_of_input, r"\iffalse a"),
            (ifcase_end_of_input, r"\ifcase 1 a"),
            (or_end_of_input, r"\ifcase 0 a\or b"),
            (else_not_expected, r"a\else"),
            (fi_not_expected, r"a\fi"),
            (or_not_expected, r"a\or"),
            // An `\or` is not allowed once the `\else` branch of an `\ifcase` was selected.
            (or_after_else, r"\ifcase 1 a\else b\or c\fi"),
        ),
    ];
}