1use std::cell::RefCell;
5use std::cmp::Ordering;
6use texlang::token::trace;
7use texlang::traits::*;
8use texlang::*;
9
/// Documentation for the `\else` primitive.
pub const ELSE_DOC: &str = "Start the else branch of a conditional or switch statement";
/// Documentation for the `\ifcase` primitive.
pub const IFCASE_DOC: &str = "Begin a switch statement";
/// Documentation for the `\ifnum` primitive (its operands are integers, not only variables).
pub const IFNUM_DOC: &str = "Compare two integers";
/// Documentation for the `\ifodd` primitive.
pub const IFODD_DOC: &str = "Check if an integer is odd";
/// Documentation for the `\iftrue` primitive.
pub const IFTRUE_DOC: &str = "Evaluate the true branch";
/// Documentation for the `\iffalse` primitive.
pub const IFFALSE_DOC: &str = "Evaluate the false branch";
/// Documentation for the `\fi` primitive.
pub const FI_DOC: &str = "End a conditional or switch statement";
/// Documentation for the `\or` primitive.
pub const OR_DOC: &str = "Begin the next branch of a switch statement";
18
/// State used by the conditional commands (`if*`, `else`, `or`, `fi`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Component {
    /// Stack of conditional branches that have been entered but not yet closed by `fi`.
    ///
    /// Mutated via `borrow_mut` in `push_branch`/`pop_branch`, which only reach the component
    /// through `input.state()`; hence the `RefCell`.
    #[cfg_attr(
        feature = "serde",
        serde(
            serialize_with = "serialize_branches",
            deserialize_with = "deserialize_branches"
        )
    )]
    branches: RefCell<Vec<Branch>>,

    /// Copies of the module's static command tags. Not serialized; serde's `skip` rebuilds
    /// this field with `Tags::default()` when deserializing.
    #[cfg_attr(feature = "serde", serde(skip))]
    tags: Tags,
}
47
/// Serializes the contents of the branch stack as a sequence of branches.
#[cfg(feature = "serde")]
fn serialize_branches<S>(input: &RefCell<Vec<Branch>>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    use serde::Serialize;
    let branches = input.borrow();
    branches.as_slice().serialize(serializer)
}
57
/// Deserializes a sequence of branches and wraps it in a fresh `RefCell`.
#[cfg(feature = "serde")]
fn deserialize_branches<'de, D>(deserializer: D) -> Result<RefCell<Vec<Branch>>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::Deserialize;
    Vec::<Branch>::deserialize(deserializer).map(RefCell::new)
}
67
68struct Tags {
69 if_tag: command::Tag,
70 else_tag: command::Tag,
71 or_tag: command::Tag,
72 fi_tag: command::Tag,
73}
74
75impl Default for Tags {
76 fn default() -> Self {
77 Self {
78 if_tag: IF_TAG.get(),
79 else_tag: ELSE_TAG.get(),
80 or_tag: OR_TAG.get(),
81 fi_tag: FI_TAG.get(),
82 }
83 }
84}
85
/// Which branch of an open conditional is currently being expanded.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum BranchKind {
    /// The true branch of an `if` command whose test succeeded (pushed by `true_case`).
    True,
    /// An `else` branch: of an `if` command whose test failed, or of an `ifcase` command
    /// for which no case was selected before the `else`.
    Else,
    /// A case of an `ifcase` command selected by its value.
    Switch,
}
96
/// An entry on the stack of open conditionals.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Branch {
    /// Token of the conditional command that opened this branch. Not otherwise read in this
    /// file; it appears in the `Debug` output used by error notes.
    _token: token::Token,
    /// Which branch of the conditional is being expanded; validated by `else` and `or`.
    kind: BranchKind,
}
103
104impl Component {
105 pub fn new() -> Component {
106 Component {
107 branches: RefCell::new(Vec::new()),
108 tags: Default::default(),
109 }
110 }
111}
112
113impl Default for Component {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119fn push_branch<S: HasComponent<Component>>(input: &mut vm::ExpansionInput<S>, branch: Branch) {
120 input.state().component().branches.borrow_mut().push(branch)
121}
122
123fn pop_branch<S: HasComponent<Component>>(input: &mut vm::ExpansionInput<S>) -> Option<Branch> {
124 input.state().component().branches.borrow_mut().pop()
125}
126
/// Tag attached to every conditional-opening command (`iftrue`, `iffalse`, `ifnum`, `ifodd`,
/// `ifcase`); skipping code uses it to track nesting depth.
static IF_TAG: command::StaticTag = command::StaticTag::new();
/// Tag attached to the `else` command.
static ELSE_TAG: command::StaticTag = command::StaticTag::new();
/// Tag attached to the `or` command.
static OR_TAG: command::StaticTag = command::StaticTag::new();
/// Tag attached to the `fi` command.
static FI_TAG: command::StaticTag = command::StaticTag::new();
131
132fn true_case<S: HasComponent<Component>>(
134 token: token::Token,
135 input: &mut vm::ExpansionInput<S>,
136) -> Result<Vec<token::Token>, Box<error::Error>> {
137 push_branch(
138 input,
139 Branch {
140 _token: token,
141 kind: BranchKind::True,
142 },
143 );
144 Ok(Vec::new())
145}
146
147fn false_case<S: HasComponent<Component>>(
152 original_token: token::Token,
153 input: &mut vm::ExpansionInput<S>,
154) -> Result<Vec<token::Token>, Box<error::Error>> {
155 let mut depth = 0;
156 while let Some(token) = input.unexpanded().next()? {
157 if let token::Value::ControlSequence(name) = &token.value() {
158 let tag = input.commands_map().get_tag(name);
160 if tag == Some(input.state().component().tags.else_tag) && depth == 0 {
161 push_branch(
162 input,
163 Branch {
164 _token: original_token,
165 kind: BranchKind::Else,
166 },
167 );
168 return Ok(Vec::new());
169 }
170 if tag == Some(input.state().component().tags.if_tag) {
171 depth += 1;
172 }
173 if tag == Some(input.state().component().tags.fi_tag) {
174 depth -= 1;
175 if depth < 0 {
176 return Ok(Vec::new());
177 }
178 }
179 }
180 }
181 let branch = pop_branch(input);
182 Err(FalseBranchEndOfInputError {
183 trace: input.vm().trace_end_of_input(),
184 branch,
185 }
186 .into())
187}
188
189#[derive(Debug)]
190struct FalseBranchEndOfInputError {
191 trace: trace::SourceCodeTrace,
192 branch: Option<Branch>,
193}
194
195impl error::TexError for FalseBranchEndOfInputError {
196 fn kind(&self) -> error::Kind {
197 error::Kind::EndOfInput(&self.trace)
198 }
199
200 fn title(&self) -> String {
201 "unexpected end of input while expanding an `if` command".into()
202 }
203
204 fn notes(&self) -> Vec<error::display::Note> {
205 vec![
206 "each `if` command must be terminated by a `fi` command, with an optional `else` in between".into(),
207 "this `if` command evaluated to false, and the input ended while skipping the true branch".into(),
208 "this is the `if` command involved in the error:".into(),
209 format!["{:?}", self.branch].into(),
210 ]
211 }
212}
213
/// Defines a conditional primitive from a predicate.
///
/// - `$if_fn`: predicate taking `&mut vm::ExpansionInput<S>` and returning
///   `Result<bool, Box<error::Error>>`;
/// - `$if_primitive_fn`: name of the generated expansion function;
/// - `$get_if`: name of the generated public constructor;
/// - `$docs`: documentation string attached to the command.
///
/// The generated command carries the shared `if` tag so that skipping code counts it as a
/// nested conditional.
macro_rules! create_if_primitive {
    ($if_fn: ident, $if_primitive_fn: ident, $get_if: ident, $docs: expr) => {
        fn $if_primitive_fn<S: HasComponent<Component>>(
            token: token::Token,
            input: &mut vm::ExpansionInput<S>,
        ) -> Result<Vec<token::Token>, Box<error::Error>> {
            if $if_fn(input)? {
                true_case(token, input)
            } else {
                false_case(token, input)
            }
        }

        pub fn $get_if<S: HasComponent<Component>>() -> command::BuiltIn<S> {
            let built_in = command::BuiltIn::new_expansion($if_primitive_fn);
            built_in.with_tag(IF_TAG.get()).with_doc($docs)
        }
    };
}
233
234fn if_true<S>(_: &mut vm::ExpansionInput<S>) -> Result<bool, Box<error::Error>> {
235 Ok(true)
236}
237
238fn if_false<S>(_: &mut vm::ExpansionInput<S>) -> Result<bool, Box<error::Error>> {
239 Ok(false)
240}
241
242fn if_num<S: TexlangState>(stream: &mut vm::ExpansionInput<S>) -> Result<bool, Box<error::Error>> {
243 let (a, o, b) = <(i32, Ordering, i32)>::parse(stream)?;
244 Ok(a.cmp(&b) == o)
245}
246
247fn if_odd<S: TexlangState>(stream: &mut vm::ExpansionInput<S>) -> Result<bool, Box<error::Error>> {
248 let n = i32::parse(stream)?;
249 Ok((n % 2) == 1)
250}
251
252create_if_primitive![if_true, if_true_primitive_fn, get_if_true, IFTRUE_DOC];
253create_if_primitive![if_false, if_false_primitive_fn, get_if_false, IFFALSE_DOC];
254create_if_primitive![if_num, if_num_primitive_fn, get_if_num, IFNUM_DOC];
255create_if_primitive![if_odd, if_odd_primitive_fn, get_if_odd, IFODD_DOC];
256
257fn if_case_primitive_fn<S: HasComponent<Component>>(
258 ifcase_token: token::Token,
259 input: &mut vm::ExpansionInput<S>,
260) -> Result<Vec<token::Token>, Box<error::Error>> {
261 let mut cases_to_skip = i32::parse(input)?;
263 if cases_to_skip == 0 {
264 push_branch(
265 input,
266 Branch {
267 _token: ifcase_token,
268 kind: BranchKind::Switch,
269 },
270 );
271 return Ok(Vec::new());
272 }
273 let mut depth = 0;
274 while let Some(token) = input.unexpanded().next()? {
275 if let token::Value::ControlSequence(name) = &token.value() {
276 let tag = input.commands_map().get_tag(name);
278 if tag == Some(input.state().component().tags.or_tag) && depth == 0 {
279 cases_to_skip -= 1;
280 if cases_to_skip == 0 {
281 push_branch(
282 input,
283 Branch {
284 _token: ifcase_token,
285 kind: BranchKind::Switch,
286 },
287 );
288 return Ok(Vec::new());
289 }
290 }
291 if tag == Some(input.state().component().tags.else_tag) && depth == 0 {
292 push_branch(
293 input,
294 Branch {
295 _token: ifcase_token,
296 kind: BranchKind::Else,
297 },
298 );
299 return Ok(Vec::new());
300 }
301 if tag == Some(input.state().component().tags.if_tag) {
302 depth += 1;
303 }
304 if tag == Some(input.state().component().tags.fi_tag) {
305 depth -= 1;
306 if depth < 0 {
307 return Ok(Vec::new());
308 }
309 }
310 }
311 }
312 Err(IfCaseEndOfInputError {
313 trace: input.trace_end_of_input(),
314 }
315 .into())
316}
317
318#[derive(Debug)]
319struct IfCaseEndOfInputError {
320 trace: trace::SourceCodeTrace,
321}
322
323impl error::TexError for IfCaseEndOfInputError {
324 fn kind(&self) -> error::Kind {
325 error::Kind::EndOfInput(&self.trace)
326 }
327
328 fn title(&self) -> String {
329 "unexpected end of input while expanding an `ifcase` command".into()
330 }
331
332 fn notes(&self) -> Vec<error::display::Note> {
333 vec![
334 "each `ifcase` command must be matched by a `or`, `else` or `fi` command".into(),
335 "this `ifcase` case evaluated to %d and we skipped %d cases before the input ran out"
336 .into(),
337 "this is the `ifnum` command involved in the error:".into(),
338 ]
339 }
340}
341
342pub fn get_if_case<S: HasComponent<Component>>() -> command::BuiltIn<S> {
344 command::BuiltIn::new_expansion(if_case_primitive_fn).with_tag(IF_TAG.get())
345}
346
347fn or_primitive_fn<S: HasComponent<Component>>(
348 ifcase_token: token::Token,
349 input: &mut vm::ExpansionInput<S>,
350) -> Result<Vec<token::Token>, Box<error::Error>> {
351 let branch = pop_branch(input);
352 let is_valid = match branch {
354 None => false,
355 Some(branch) => matches!(branch.kind, BranchKind::Switch),
356 };
357 if !is_valid {
358 return Err(error::SimpleTokenError::new(
359 input.vm(),
360 ifcase_token,
361 "unexpected `or` command",
362 )
363 .into());
364 }
365
366 let mut depth = 0;
367 while let Some(token) = input.unexpanded().next()? {
368 if let token::Value::ControlSequence(name) = &token.value() {
369 let tag = input.commands_map().get_tag(name);
370 if tag == Some(input.state().component().tags.if_tag) {
371 depth += 1;
372 }
373 if tag == Some(input.state().component().tags.fi_tag) {
374 depth -= 1;
375 if depth < 0 {
376 return Ok(Vec::new());
377 }
378 }
379 }
380 }
381 Err(OrEndOfInputError {
382 trace: input.vm().trace_end_of_input(),
383 }
384 .into())
385}
386
387#[derive(Debug)]
388struct OrEndOfInputError {
389 trace: trace::SourceCodeTrace,
390}
391
392impl error::TexError for OrEndOfInputError {
393 fn kind(&self) -> error::Kind {
394 error::Kind::EndOfInput(&self.trace)
395 }
396
397 fn title(&self) -> String {
398 "unexpected end of input while expanding an `or` command".into()
399 }
400
401 fn notes(&self) -> Vec<error::display::Note> {
402 vec![
403 "each `or` command must be terminated by a `fi` command".into(),
404 "this `or` corresponds to an `ifcase` command that evaluated to %d, and the input ended while skipping the remaining cases".into(),
405 "this is the `ifcase` command involved in the error:".into(),
406 "this is the `or` command involved in the error:".into(),
407 ]
408 }
409}
410
411pub fn get_or<S: HasComponent<Component>>() -> command::BuiltIn<S> {
413 command::BuiltIn::new_expansion(or_primitive_fn).with_tag(OR_TAG.get())
414}
415
416fn else_primitive_fn<S: HasComponent<Component>>(
417 else_token: token::Token,
418 input: &mut vm::ExpansionInput<S>,
419) -> Result<Vec<token::Token>, Box<error::Error>> {
420 let branch = pop_branch(input);
421 let is_valid = match branch {
423 None => false,
424 Some(branch) => matches!(branch.kind, BranchKind::True | BranchKind::Switch),
425 };
426 if !is_valid {
427 return Err(error::SimpleTokenError::new(
428 input.vm(),
429 else_token,
430 "unexpected `else` command",
431 )
432 .into());
433 }
434
435 let mut depth = 0;
437 while let Some(token) = input.unexpanded().next()? {
438 if let token::Value::ControlSequence(name) = &token.value() {
439 let tag = input.commands_map().get_tag(name);
441 if tag == Some(input.state().component().tags.if_tag) {
442 depth += 1;
443 }
444 if tag == Some(input.state().component().tags.fi_tag) {
445 depth -= 1;
446 if depth < 0 {
447 return Ok(Vec::new());
448 }
449 }
450 }
451 }
452 Err(ElseEndOfInputError {
453 trace: input.vm().trace_end_of_input(),
454 }
455 .into())
456}
457
458#[derive(Debug)]
459struct ElseEndOfInputError {
460 trace: trace::SourceCodeTrace,
461}
462
463impl error::TexError for ElseEndOfInputError {
464 fn kind(&self) -> error::Kind {
465 error::Kind::EndOfInput(&self.trace)
466 }
467
468 fn title(&self) -> String {
469 "unexpected end of input while expanding an `else` command".into()
470 }
471
472 fn notes(&self) -> Vec<error::display::Note> {
473 vec![
474 "each `else` command must be terminated by a `fi` command".into(),
475 "this `else` corresponds to an `if` command that evaluated to true, and the input ended while skipping the false branch".into(),
476 "this is the `if` command involved in the error:".into(),
477 "this is the `else` command involved in the error:".into(),
478 ]
479 }
480}
481
482pub fn get_else<S: HasComponent<Component>>() -> command::BuiltIn<S> {
484 command::BuiltIn::new_expansion(else_primitive_fn).with_tag(ELSE_TAG.get())
485}
486
487fn fi_primitive_fn<S: HasComponent<Component>>(
489 token: token::Token,
490 input: &mut vm::ExpansionInput<S>,
491) -> Result<Vec<token::Token>, Box<error::Error>> {
492 let branch = pop_branch(input);
493 if branch.is_none() {
498 return Err(
499 error::SimpleTokenError::new(input.vm(), token, "unexpected `fi` command").into(),
500 );
501 }
502 Ok(Vec::new())
503}
504
505pub fn get_fi<S: HasComponent<Component>>() -> command::BuiltIn<S> {
506 command::BuiltIn::new_expansion(fi_primitive_fn).with_tag(FI_TAG.get())
507}
508
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::{script, testing::*};
    use texlang::vm::implement_has_component;

    #[derive(Default)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    struct State {
        conditional: Component,
        exec: script::Component,
    }

    impl TexlangState for State {}

    implement_has_component![State, (Component, conditional), (script::Component, exec),];

    fn initial_commands() -> HashMap<&'static str, command::BuiltIn<State>> {
        HashMap::from([
            ("else", get_else()),
            ("fi", get_fi()),
            ("ifcase", get_if_case()),
            ("iffalse", get_if_false()),
            ("ifnum", get_if_num()),
            ("ifodd", get_if_odd()),
            ("iftrue", get_if_true()),
            ("or", get_or()),
        ])
    }

    test_suite![
        expansion_equality_tests(
            (iftrue_base_case, r"\iftrue a\else b\fi c", r"ac"),
            (iftrue_no_else, r"\iftrue a\fi c", r"ac"),
            (
                iftrue_skip_nested_ifs,
                r"\iftrue a\else b\iftrue \else c\fi d\fi e",
                r"ae"
            ),
            (iffalse_base_case, r"\iffalse a\else b\fi c", r"bc"),
            (iffalse_no_else, r"\iffalse a\fi c", r"c"),
            (
                iffalse_skip_nested_ifs,
                r"\iffalse \iftrue a\else b\fi c\else d\fi e",
                r"de"
            ),
            (
                iffalse_and_iftrue_1,
                r"\iffalse a\else b\iftrue c\else d\fi e\fi f",
                r"bcef"
            ),
            (
                iffalse_and_iftrue_2,
                r"\iftrue a\iffalse b\else c\fi d\else e\fi f",
                r"acdf"
            ),
            (ifnum_less_than_true, r"\ifnum 4<5a\else b\fi c", r"ac"),
            (ifnum_less_than_false, r"\ifnum 5<4a\else b\fi c", r"bc"),
            (ifnum_equal_true, r"\ifnum 4=4a\else b\fi c", r"ac"),
            (ifnum_equal_false, r"\ifnum 5=4a\else b\fi c", r"bc"),
            (ifnum_greater_than_true, r"\ifnum 5>4a\else b\fi c", r"ac"),
            (ifnum_greater_than_false, r"\ifnum 4>5a\else b\fi c", r"bc"),
            (ifodd_odd, r"\ifodd 3a\else b\fi c", r"ac"),
            (ifodd_even, r"\ifodd 4a\else b\fi c", r"bc"),
            (ifodd_zero, r"\ifodd 0a\else b\fi c", r"bc"),
            (ifcase_zero_no_ors, r"\ifcase 0 a\else b\fi c", r"ac"),
            (ifcase_zero_one_or, r"\ifcase 0 a\or b\else c\fi d", r"ad"),
            (ifcase_one, r"\ifcase 1 a\or b\else c\fi d", r"bd"),
            (
                ifcase_one_more_cases,
                r"\ifcase 1 a\or b\or c\else d\fi e",
                r"be"
            ),
            (ifcase_else_no_ors, r"\ifcase 1 a\else b\fi c", r"bc"),
            (ifcase_else_one_or, r"\ifcase 2 a\or b\else c\fi d", r"cd"),
            (ifcase_no_matching_case, r"\ifcase 3 a\or b\or c\fi d", r"d"),
            (ifcase_negative, r"\ifcase -1 a\or b\else c\fi d", r"cd"),
            (
                ifcase_skip_nested_conditional,
                r"\ifcase 1 \iftrue a\or b\fi\or c\fi d",
                r"cd"
            ),
            (
                ifcase_nested,
                r"\ifcase 1 a\or b\ifcase 1 c\or d\or e\else f\fi g\or h\fi i",
                r"bdgi"
            ),
        ),
        serde_tests(
            (serde_if, r"\iftrue true ", r"branch \else false branch \fi"),
            (
                serde_ifcase,
                r"\ifcase 2 a\or b\or executed ",
                r"case \or d \fi"
            )
        ),
        failure_tests(
            (iftrue_end_of_input, r"\iftrue a\else b"),
            (iffalse_end_of_input, r"\iffalse a"),
            (ifcase_end_of_input, r"\ifcase 1 a"),
            (or_end_of_input, r"\ifcase 0 a\or b"),
            (else_not_expected, r"a\else"),
            (else_after_else, r"\iffalse a\else b\else c\fi"),
            (fi_not_expected, r"a\fi"),
            (or_not_expected, r"a\or"),
            (or_in_if_branch, r"\iftrue a\or b\fi"),
        ),
    ];
}
608}