tex2typst_rs/
command_registry.rs

1use crate::definitions::{TexToken, TexTokenType};
2use crate::tex_tokenizer::tokenize;
3use std::collections::HashMap;
4
/// Built-in command names (written without the leading backslash) that
/// `CommandRegistry::get_command_type` classifies as `CommandType::Unary`.
#[rustfmt::skip]
pub const UNARY_COMMANDS: &[&str] = &[
    "text",
    "bar", "bold", "boldsymbol", "ddot", "dot", "hat",
    "mathbb", "mathbf", "mathcal", "mathfrak", "mathit", "mathrm", "mathscr", "mathsf", "mathtt",
    "operatorname", "overbrace", "overline", "pmb", "rm", "tilde", "underbrace", "underline", "vec",
    "overrightarrow", "widehat", "widetilde",
    "floor", // This is a custom macro
];

/// Built-in command names classified as `CommandType::Binary`.
pub const BINARY_COMMANDS: &[&str] = &[
    "frac",
    "tfrac",
    "binom",
    "dbinom",
    "dfrac",
    "tbinom",
    "overset",
];

/// Built-in command names with a single, optional argument. Currently empty.
pub const OPTION_UNARY_COMMANDS: &[&str] = &[];

/// Built-in command names classified as `CommandType::OptionalBinary`.
pub const OPTION_BINARY_COMMANDS: &[&str] = &["sqrt"];
42
/// Outcome of expanding one command: on success, the replacement tokens paired
/// with the index of the first token after the command and its arguments; on
/// failure, a human-readable error message.
pub type ExpandResult = Result<(Vec<TexToken>, usize), String>;
44
/// Argument shape of a command, i.e. which arguments must be collected after
/// the command name before it can be expanded.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CommandType {
    /// Takes no arguments.
    Symbol,
    /// Takes one mandatory `{...}` argument.
    Unary,
    /// Takes two mandatory `{...}` arguments.
    Binary,
    /// Takes a single argument that may be given as `[...]` or omitted.
    OptionalUnary,
    /// Takes an optional `[...]` argument followed by a mandatory `{...}` one.
    OptionalBinary,
}
53
/// A user-defined macro that can be expanded by a `CommandRegistry`.
pub struct CustomMacro {
    /// Command name as it appears in a command token's `value`; expansion
    /// matches it against the token value verbatim (e.g. `\mycommand`).
    pub name: String,
    /// Argument shape; decides which arguments are collected before
    /// `implementation` is invoked.
    pub command_type: CommandType,
    /// Produces the replacement tokens from the collected arguments, one
    /// `Vec<TexToken>` per argument in source order, or an error message.
    pub implementation: Box<dyn Fn(&Vec<Vec<TexToken>>) -> Result<Vec<TexToken>, String>>,
}
59
/// Holds the custom macros known to the converter and answers questions about
/// command argument shapes.
#[derive(Default)]
pub struct CommandRegistry {
    /// Registered macros in registration order; expansion uses the first
    /// entry whose name matches.
    pub custom_macros: Vec<CustomMacro>,
    /// Name -> argument shape of every registered macro; a later registration
    /// with the same name overwrites the earlier entry.
    pub custom_macro_names: HashMap<String, CommandType>,
}
65
66impl CommandRegistry {
67    pub fn new() -> CommandRegistry {
68        Self::default()
69    }
70
71    pub fn register_custom_macro(
72        &mut self,
73        name: &str,
74        command_type: CommandType,
75        implementation: Box<dyn Fn(&Vec<Vec<TexToken>>) -> Result<Vec<TexToken>, String>>,
76    ) {
77        self.custom_macros.push(CustomMacro {
78            name: name.to_string(),
79            command_type,
80            implementation,
81        });
82        self.custom_macro_names.insert(name.to_string(), command_type);
83    }
84
85    pub fn register_custom_macros(&mut self, custom_macros: Vec<CustomMacro>) {
86        for custom_macro in custom_macros {
87            self.custom_macro_names
88                .insert(custom_macro.name.clone(), custom_macro.command_type);
89            self.custom_macros.push(custom_macro);
90        }
91    }
92
93    pub fn get_command_type(&self, command_name: &str) -> Option<CommandType> {
94        if UNARY_COMMANDS.contains(&command_name) {
95            Some(CommandType::Unary)
96        } else if BINARY_COMMANDS.contains(&command_name) {
97            Some(CommandType::Binary)
98        } else if OPTION_BINARY_COMMANDS.contains(&command_name) {
99            Some(CommandType::OptionalBinary)
100        } else if self.custom_macro_names.contains_key(command_name) {
101            self.custom_macro_names.get(command_name).copied()
102        } else {
103            // fallback to symbol (no arguments)
104            Some(CommandType::Symbol)
105        }
106    }
107
108    pub fn expand_macros(&self, tokens: &[TexToken]) -> Result<Vec<TexToken>, String> {
109        let mut expanded_tokens: Vec<TexToken> = Vec::new();
110        let mut pos: usize = 0;
111
112        while pos < tokens.len() {
113            let token = &tokens[pos];
114            if token.token_type == TexTokenType::Command {
115                if let Some(custom_macro) = self.custom_macros.iter().find(|macro_| macro_.name == token.value) {
116                    let (expanded_command, new_pos) = self.expand_command(tokens, custom_macro, pos)?;
117                    expanded_tokens.extend(expanded_command);
118                    pos = new_pos;
119                } else {
120                    expanded_tokens.push(token.clone());
121                    pos += 1;
122                }
123            } else {
124                expanded_tokens.push(token.clone());
125                pos += 1;
126            }
127        }
128        Ok(expanded_tokens)
129    }
130
131    // this will get called recursively
132    fn expand_command(&self, tokens: &[TexToken], custom_macro: &CustomMacro, start: usize) -> ExpandResult {
133        let command_name = &tokens[start].value; // starts with \
134        assert_eq!(command_name, &custom_macro.name);
135        let command_type = custom_macro.command_type;
136        let mut pos = start + 1; // come to what comes after the command
137        let mut arguments: Vec<Vec<TexToken>> = Vec::new();
138
139        match command_type {
140            CommandType::Symbol => {
141                // no arguments, don't move the pos
142            }
143            CommandType::Unary => {
144                if !tokens[pos].value.eq("{") {
145                    return Err(format!("Expecting one argument for command {}", command_name));
146                }
147                pos += 1;
148                if let Some(right_curly_bracket_pos) = find_matching_right_curly_bracket_token(tokens, pos) {
149                    let argument: &[TexToken] = &tokens[pos..right_curly_bracket_pos];
150                    arguments.push(self.expand_macros(argument)?);
151                    pos = right_curly_bracket_pos + 1;
152                } else {
153                    return Err(format!("Unmatched curly brackets for command {}", command_name));
154                }
155            }
156            CommandType::Binary => {
157                if !tokens[pos].value.eq("{") {
158                    return Err(format!("No argument provided for command {}", command_name));
159                }
160                pos += 1;
161                if let Some(right_curly_bracket_pos) = find_matching_right_curly_bracket_token(tokens, pos) {
162                    let first_argument: &[TexToken] = &tokens[pos..right_curly_bracket_pos];
163                    arguments.push(self.expand_macros(first_argument)?);
164                    pos = right_curly_bracket_pos;
165                } else {
166                    return Err(format!("Unmatched curly brackets for command {}", command_name));
167                }
168                pos += 1;
169
170                if !tokens[pos].value.eq("{") {
171                    return Err(format!("Expecting two arguments for command {}", command_name));
172                }
173                pos += 1;
174                if let Some(right_curly_bracket_pos) = find_matching_right_curly_bracket_token(tokens, pos) {
175                    let second_argument: &[TexToken] = &tokens[pos..right_curly_bracket_pos];
176                    arguments.push(self.expand_macros(second_argument)?);
177                    pos = right_curly_bracket_pos;
178                } else {
179                    return Err(format!("Unmatched curly brackets for command {}", command_name));
180                }
181                pos += 1;
182            }
183            CommandType::OptionalUnary => {
184                let s;
185                match tokens.get(pos) {
186                    None => {
187                        return Err(format!("Expecting optional argument for command {}", command_name));
188                    }
189                    Some(token) => {
190                        s = token.value.as_str();
191                    }
192                }
193                match s {
194                    "[" => {
195                        // one optional argument
196                        pos += 1;
197                        if let Some(right_square_bracket) = find_matching_right_square_bracket_token(tokens, pos) {
198                            let optional_argument: &[TexToken] = &tokens[pos..right_square_bracket];
199                            arguments.push(self.expand_macros(optional_argument)?);
200                            pos = right_square_bracket + 1;
201                        } else {
202                            return Err(format!("Unmatched right square brackets for command {}", command_name));
203                        }
204                    }
205                    _ => {
206                        // no given optional argument, will use the default value
207                    }
208                };
209            }
210            CommandType::OptionalBinary => {
211                let s;
212                match tokens.get(pos) {
213                    None => {
214                        return Err(format!("Expecting optional argument for command {}", command_name));
215                    }
216                    Some(token) => {
217                        s = token.value.as_str();
218                    }
219                }
220                match tokens[pos].value.as_str() {
221                    "[" => {
222                        // one optional argument, one mandatory argument
223                        pos += 1;
224                        if let Some(right_square_bracket) = find_matching_right_square_bracket_token(tokens, pos) {
225                            let optional_argument: &[TexToken] = &tokens[pos..right_square_bracket];
226                            arguments.push(self.expand_macros(optional_argument)?);
227                            pos = right_square_bracket;
228                            pos += 1;
229                        } else {
230                            return Err(format!("Unmatched square brackets for command {}", command_name));
231                        }
232
233                        if tokens.get(pos).map(|token| token.value.as_str()) != Some("{") {
234                            return Err(format!(
235                                "Expecting the mandatory argument after the optional argument for command {}",
236                                command_name
237                            ));
238                        }
239                        pos += 1;
240                        if let Some(right_curly_bracket_pos) = find_matching_right_curly_bracket_token(tokens, pos) {
241                            let mandatory_argument: &[TexToken] = &tokens[pos..right_curly_bracket_pos];
242                            arguments.push(self.expand_macros(mandatory_argument)?);
243                            pos = right_curly_bracket_pos + 1;
244                        } else {
245                            return Err(format!("Unmatched curly brackets for command {}", command_name));
246                        }
247                    }
248                    "{" => {
249                        // no optional argument, one mandatory argument
250                        pos += 1;
251                        if let Some(right_curly_bracket_pos) = find_matching_right_curly_bracket_token(tokens, pos) {
252                            let mandatory_argument: &[TexToken] = &tokens[pos..right_curly_bracket_pos];
253                            arguments.push(self.expand_macros(mandatory_argument)?);
254                            pos = right_curly_bracket_pos + 1;
255                        } else {
256                            return Err(format!("Unmatched curly brackets for command {}", command_name));
257                        }
258                    }
259                    _ => {
260                        return Err(format!(
261                            "Expecting optional or mandatory argument for command {}",
262                            command_name
263                        ));
264                    }
265                };
266            }
267        }
268
269        let expanded_tokens = (custom_macro.implementation)(&arguments)?;
270        Ok((expanded_tokens, pos))
271    }
272}
273
274fn find_matching_right_curly_bracket_token(tokens: &[TexToken], start: usize) -> Option<usize> {
275    let mut count = 1;
276    let mut pos = start + 1;
277
278    while count > 0 {
279        if pos >= tokens.len() {
280            return None;
281        }
282        if pos + 1 < tokens.len() && tokens[pos].value == "\\" && tokens[pos + 1].value == "}" {
283            pos += 2;
284            continue;
285        }
286        match tokens[pos].value.as_str() {
287            "{" => count += 1,
288            "}" => count -= 1,
289            _ => {}
290        }
291        pos += 1;
292    }
293
294    Some(pos - 1)
295}
296
297fn find_matching_right_square_bracket_token(tokens: &[TexToken], start: usize) -> Option<usize> {
298    let mut count = 1;
299    let mut pos = start;
300
301    while count > 0 {
302        if pos >= tokens.len() {
303            return None;
304        }
305        if pos + 1 < tokens.len() && tokens[pos].value == "\\" && tokens[pos + 1].value == "]" {
306            pos += 2;
307            continue;
308        }
309        match tokens[pos].value.as_str() {
310            "[" => count += 1,
311            "]" => count -= 1,
312            _ => {}
313        }
314        pos += 1;
315    }
316
317    Some(pos - 1)
318}
319
/// Finds the `}` character that closes a group whose opening `{` has already
/// been consumed.
///
/// `start` is the index of the first character *inside* the group, so
/// scanning begins at `start` itself: an empty group (`{}`) or a group that
/// starts with a nested `{` is matched correctly. The two-character sequence
/// `\}` is treated as an escaped brace and skipped.
///
/// Returns the index of the matching `}`, or `None` if the group is never
/// closed.
fn find_matching_right_curly_bracket_char(latex: &Vec<char>, start: usize) -> Option<usize> {
    let mut depth = 1;
    let mut pos = start;

    while pos < latex.len() {
        if pos + 1 < latex.len() && latex[pos] == '\\' && latex[pos + 1] == '}' {
            pos += 2;
            continue;
        }
        match latex[pos] {
            '{' => depth += 1,
            '}' => {
                depth -= 1;
                if depth == 0 {
                    return Some(pos);
                }
            }
            _ => {}
        }
        pos += 1;
    }

    None
}
342
/// Finds the `]` character that closes a bracket group whose opening `[` has
/// already been consumed.
///
/// Scanning begins at `start`, the first character inside the group. Nested
/// `[`/`]` pairs are balanced and the two-character sequence `\]` is skipped
/// as an escaped bracket. Returns the index of the matching `]`, or `None` if
/// the input ends first.
fn find_matching_right_square_bracket_char(latex: &Vec<char>, start: usize) -> Option<usize> {
    let mut depth = 1;
    let mut cursor = start;

    loop {
        // Running off the end means the group was never closed.
        let current = *latex.get(cursor)?;

        if current == '\\' && latex.get(cursor + 1) == Some(&']') {
            cursor += 2;
            continue;
        }

        if current == '[' {
            depth += 1;
        } else if current == ']' {
            depth -= 1;
        }

        if depth == 0 {
            return Some(cursor);
        }
        cursor += 1;
    }
}
365
366pub fn parse_custom_macros(latex: &str) -> Result<Vec<CustomMacro>, String> {
367    let latex: Vec<char> = latex.chars().collect();
368    let pattern: Vec<char> = "\\newcommand".chars().collect();
369    let pattern_len = pattern.len();
370    let mut pos = 0;
371    let mut custom_macros: Vec<CustomMacro> = Vec::new();
372
373    while pos < latex.len().saturating_sub(pattern_len) {
374        if latex[pos..pos + pattern_len] == pattern[..] {
375            pos += pattern_len;
376            // extract the new command name
377            let new_command_name: String;
378            if latex.get(pos) != Some(&'{') {
379                return Err("Expecting { after \\newcommand".to_string());
380            }
381            pos += 1;
382            if latex.get(pos) != Some(&'\\') {
383                return Err("Expecting backslash for command name after {".to_string());
384            }
385            if let Some(right_curly_bracket_pos) = find_matching_right_curly_bracket_char(&latex, pos) {
386                new_command_name = latex[pos..right_curly_bracket_pos].iter().collect();
387                pos = right_curly_bracket_pos;
388            } else {
389                return Err("Unmatched curly brackets".to_string());
390            }
391
392            // check if there is a specification of number of arguments
393            let num_of_args: usize;
394            pos += 1;
395            if latex.get(pos) == Some(&'[') {
396                pos += 1;
397                if let Some(right_square_bracket) = find_matching_right_square_bracket_char(&latex, pos) {
398                    num_of_args = latex[pos..right_square_bracket]
399                        .iter()
400                        .collect::<String>()
401                        .parse::<usize>()
402                        .map_err(|e| e.to_string())?;
403                    if num_of_args > 2 {
404                        return Err("Only unary and binary commands are supported".to_string());
405                    }
406                    pos = right_square_bracket;
407                } else {
408                    return Err("Unmatched square brackets".to_string());
409                }
410                pos += 1;
411            } else {
412                num_of_args = 0;
413            }
414
415            // check if there is a default value for the first argument
416            let default_value: Option<String>;
417            if latex.get(pos) == Some(&'[') {
418                pos += 1;
419                if let Some(right_square_bracket) = find_matching_right_square_bracket_char(&latex, pos) {
420                    default_value = Some(latex[pos..right_square_bracket].iter().collect::<String>());
421                    pos = right_square_bracket;
422                } else {
423                    return Err("Unmatched square brackets".to_string());
424                }
425                pos += 1;
426            } else {
427                default_value = None;
428            }
429
430            // extract the definition
431            let definition: String;
432            if latex.get(pos) != Some(&'{') {
433                return Err("Expecting { before the definition".to_string());
434            }
435            pos += 1;
436            if let Some(right_curly_bracket_pos) = find_matching_right_curly_bracket_char(&latex, pos) {
437                definition = latex[pos..right_curly_bracket_pos].iter().collect();
438                pos = right_curly_bracket_pos;
439            } else {
440                return Err("Unmatched curly brackets".to_string());
441            }
442
443            custom_macros.push(construct_custom_macro(
444                new_command_name,
445                num_of_args,
446                default_value,
447                definition,
448            )?);
449        }
450        pos += 1;
451    }
452
453    if custom_macros.is_empty() && latex.len() > 0 {
454        return Err("No custom macros found".to_string());
455    }
456
457    Ok(custom_macros)
458}
459
460fn construct_custom_macro(
461    new_command_name: String,
462    num_of_args: usize,
463    default_value: Option<String>,
464    definition: String,
465) -> Result<CustomMacro, String> {
466    let command_type: CommandType;
467    let implementation: Box<dyn Fn(&Vec<Vec<TexToken>>) -> Result<Vec<TexToken>, String>>;
468
469    if let Some(default_value) = default_value {
470        // default value provided, so it's an optional unary or optional binary command
471        match num_of_args {
472            0 => {
473                return Err("Default value provided for a command with no arguments".to_string());
474            }
475            1 => {
476                // optional unary command
477                command_type = CommandType::OptionalUnary;
478                implementation = Box::new(move |args: &Vec<Vec<TexToken>>| {
479                    let replaced_string: String;
480                    if args.is_empty() {
481                        replaced_string = definition.replace("#1", &default_value);
482                    } else {
483                        replaced_string = definition.replace(
484                            "#1",
485                            &args[0].iter().map(|token| token.value.clone()).collect::<String>(),
486                        );
487                    }
488                    tokenize(&replaced_string)
489                });
490            }
491            2 => {
492                // optional binary command
493                command_type = CommandType::OptionalBinary;
494                implementation = Box::new(move |args: &Vec<Vec<TexToken>>| {
495                    let replaced_string: String;
496                    if args.len() == 1 {
497                        replaced_string = definition.replace("#1", &default_value).replace(
498                            "#2",
499                            &args[0].iter().map(|token| token.value.clone()).collect::<String>(),
500                        );
501                    } else if args.len() == 2 {
502                        replaced_string = definition
503                            .replace(
504                                "#1",
505                                &args[0].iter().map(|token| token.value.clone()).collect::<String>(),
506                            )
507                            .replace(
508                                "#2",
509                                &args[1].iter().map(|token| token.value.clone()).collect::<String>(),
510                            );
511                    } else {
512                        return Err("Expecting one or two arguments".to_string());
513                    }
514                    tokenize(&replaced_string)
515                });
516            }
517            _ => {
518                return Err("Only unary and binary commands are supported".to_string());
519            }
520        }
521    } else {
522        // no default value, it's either a symbol, unary or binary command
523        match num_of_args {
524            0 => {
525                // symbol command
526                command_type = CommandType::Symbol;
527                implementation = Box::new(move |_| tokenize(&definition));
528            }
529            1 => {
530                // unary command
531                command_type = CommandType::Unary;
532                implementation = Box::new(move |args: &Vec<Vec<TexToken>>| {
533                    let replaced_string = definition.replace(
534                        "#1",
535                        &args[0].iter().map(|token| token.value.clone()).collect::<String>(),
536                    );
537                    tokenize(&replaced_string)
538                });
539            }
540            2 => {
541                // binary command
542                command_type = CommandType::Binary;
543                implementation = Box::new(move |args: &Vec<Vec<TexToken>>| {
544                    let replaced_string = definition
545                        .replace(
546                            "#1",
547                            &args[0].iter().map(|token| token.value.clone()).collect::<String>(),
548                        )
549                        .replace(
550                            "#2",
551                            &args[1].iter().map(|token| token.value.clone()).collect::<String>(),
552                        );
553                    tokenize(&replaced_string)
554                });
555            }
556            _ => {
557                return Err("Only unary and binary commands are supported".to_string());
558            }
559        }
560    }
561
562    Ok(CustomMacro {
563        name: new_command_name,
564        command_type,
565        implementation,
566    })
567}
568
// Unit tests for the command registry: manual macro registration, parsing of
// `\newcommand` definitions, and expansion of every supported argument shape.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::definitions::TexTokenType;
    use crate::tex_tokenizer::tokenize;

    // Sanity check of the tokenizer: a lone command becomes one Command token
    // whose value keeps the leading backslash.
    #[test]
    fn test_tokenize() {
        let tex = r"\alpha";
        let tokens = tokenize(tex).unwrap();
        assert_eq!(
            tokens,
            vec![TexToken {
                token_type: TexTokenType::Command,
                value: r"\alpha".to_string(),
            }]
        );
    }

    // A manually registered symbol macro is reported as `Symbol` and replaced
    // by the tokens its implementation returns.
    #[test]
    fn test_command_registry_symbol() {
        let mut registry = CommandRegistry::new();

        let implementation = |tokens: &Vec<Vec<TexToken>>| {
            Ok(vec![TexToken {
                token_type: TexTokenType::Command,
                value: r"\mycommandexpanded".to_string(),
            }])
        };
        registry.register_custom_macro(r"\mycommand", CommandType::Symbol, Box::new(implementation));

        assert_eq!(registry.get_command_type(r"\mycommand"), Some(CommandType::Symbol));

        let tokens = vec![TexToken {
            token_type: TexTokenType::Command,
            value: r"\mycommand".to_string(),
        }];
        let expanded_tokens = registry.expand_macros(&tokens).unwrap();
        assert_eq!(
            expanded_tokens,
            vec![TexToken {
                token_type: TexTokenType::Command,
                value: r"\mycommandexpanded".to_string(),
            }]
        );
    }

    // A manually registered unary macro receives the tokens of its braced
    // argument as `tokens[0]`.
    #[test]
    fn test_command_registry_simple_unary() {
        let mut registry = CommandRegistry::new();

        let implementation = |tokens: &Vec<Vec<TexToken>>| {
            let mut res = tokenize(r"\expanded{").unwrap();
            res.extend(tokens[0].iter().cloned());
            res.push(TexToken {
                token_type: TexTokenType::Control,
                value: "}".to_string(),
            });
            Ok(res)
        };
        registry.register_custom_macro(r"\mycommand", CommandType::Unary, Box::new(implementation));

        assert_eq!(registry.get_command_type(r"\mycommand"), Some(CommandType::Unary));

        let tokens = tokenize(r"\mycommand{a}").unwrap();
        let expanded_tokens = registry.expand_macros(&tokens).unwrap();
        assert_eq!(expanded_tokens, tokenize(r"\expanded{a}").unwrap(),);
    }

    // `\newcommand` without an argument count yields a symbol macro.
    #[test]
    fn test_parse_custom_macros_symbol() {
        let macro_string = r"\newcommand{\mycommand}{\expanded}";
        let tex = r"\mycommand";

        let custom_macros = parse_custom_macros(macro_string).unwrap();

        assert_eq!(custom_macros.len(), 1);
        assert_eq!(custom_macros[0].name, "\\mycommand");
        assert_eq!(custom_macros[0].command_type, CommandType::Symbol);
        assert_eq!(
            (custom_macros[0].implementation)(&vec![]).unwrap(),
            tokenize(r"\expanded").unwrap()
        );

        let mut registry = CommandRegistry::new();
        registry.register_custom_macros(custom_macros);
        let tokens = tokenize(tex).unwrap();
        let expanded_tokens = registry.expand_macros(&tokens).unwrap();
        assert_eq!(expanded_tokens, tokenize(r"\expanded").unwrap());
    }

    // `[1]` yields a unary macro whose `#1` is replaced by the argument.
    #[test]
    fn test_parse_custom_macros_unary() {
        let macro_string = r"\newcommand{\mycommand}[1]{\expanded{#1}}";
        let tex = r"\mycommand{a}";

        let custom_macros = parse_custom_macros(macro_string).unwrap();

        assert_eq!(custom_macros.len(), 1);
        assert_eq!(custom_macros[0].name, "\\mycommand");
        assert_eq!(custom_macros[0].command_type, CommandType::Unary);
        assert_eq!(
            (custom_macros[0].implementation)(&vec![tokenize("a").unwrap()]).unwrap(),
            tokenize(r"\expanded{a}").unwrap()
        );

        let mut registry = CommandRegistry::new();
        registry.register_custom_macros(custom_macros);
        let tokens = tokenize(tex).unwrap();
        let expanded_tokens = registry.expand_macros(&tokens).unwrap();
        assert_eq!(expanded_tokens, tokenize(r"\expanded{a}").unwrap());
    }

    // `[2]` yields a binary macro with `#1` and `#2` replaced in order.
    #[test]
    fn test_parse_custom_macros_binary() {
        let macro_string = r"\newcommand{\mycommand}[2]{\expanded{#1}\and{#2}}";
        let tex = r"\mycommand{a}{b}";

        let custom_macros = parse_custom_macros(macro_string).unwrap();

        assert_eq!(custom_macros.len(), 1);
        assert_eq!(custom_macros[0].name, "\\mycommand");
        assert_eq!(custom_macros[0].command_type, CommandType::Binary);
        assert_eq!(
            (custom_macros[0].implementation)(&vec![tokenize("a").unwrap(), tokenize("b").unwrap()]).unwrap(),
            tokenize(r"\expanded{a}\and{b}").unwrap()
        );

        let mut registry = CommandRegistry::new();
        registry.register_custom_macros(custom_macros);
        let tokens = tokenize(tex).unwrap();
        let expanded_tokens = registry.expand_macros(&tokens).unwrap();
        assert_eq!(expanded_tokens, tokenize(r"\expanded{a}\and{b}").unwrap());
    }

    // `[1][default]` yields an optional-unary macro: the default is used when
    // no `[...]` follows the command.
    #[test]
    fn test_parse_custom_macros_optional_unary() {
        let macro_string = r"\newcommand{\mycommand}[1][default]{\expanded{#1}}";
        let tex = r"\mycommand \mycommand[a]";

        let custom_macros = parse_custom_macros(macro_string).unwrap();

        assert_eq!(custom_macros.len(), 1);
        assert_eq!(custom_macros[0].name, "\\mycommand");
        assert_eq!(custom_macros[0].command_type, CommandType::OptionalUnary);

        let mut registry = CommandRegistry::new();
        registry.register_custom_macros(custom_macros);
        let tokens = tokenize(tex).unwrap();
        let expanded_tokens = registry.expand_macros(&tokens).unwrap();
        assert_eq!(expanded_tokens, tokenize(r"\expanded{default} \expanded{a}").unwrap());
    }

    // `[2][def]` yields an optional-binary macro: `#1` falls back to the
    // default when only the mandatory `{...}` argument is given.
    #[test]
    fn test_parse_custom_macros_optional_binary() {
        let macro_string = r"\newcommand{\mycommand}[2][def]{\expanded{#1}\and{#2}}";
        let tex = r"\mycommand{b} \mycommand[a]{b}";

        let custom_macros = parse_custom_macros(macro_string).unwrap();

        assert_eq!(custom_macros.len(), 1);
        assert_eq!(custom_macros[0].name, "\\mycommand");
        assert_eq!(custom_macros[0].command_type, CommandType::OptionalBinary);

        let mut registry = CommandRegistry::new();
        registry.register_custom_macros(custom_macros);
        let tokens = tokenize(tex).unwrap();
        let expanded_tokens = registry.expand_macros(&tokens).unwrap();
        assert_eq!(
            expanded_tokens,
            tokenize(r"\expanded{def}\and{b} \expanded{a}\and{b}").unwrap()
        );
    }

    // Several definitions separated by newlines and indentation are all
    // parsed, and each argument shape expands correctly within one input.
    #[test]
    fn test_multiple_custom_macros() {
        let macro_string = r"\newcommand{\mysym}{\texttt{sym}}
        \newcommand{\aunary}[1]{\expanded{#1}}
        \newcommand{\abinary}[2]{\expanded{#1}\and{#2}}
        \newcommand{\aoptionalunary}[1][def1]{\expanded{#1}}
        \newcommand{\aoptionalbinary}[2][def2]{\expanded{#1}\and{#2}}";
        let tex = r"\mysym \aunary{a} \abinary{a}{b} \aoptionalunary \aoptionalunary[a] \aoptionalbinary{b} \aoptionalbinary[a]{b}";

        let custom_macros = parse_custom_macros(macro_string).unwrap();

        assert_eq!(custom_macros.len(), 5);

        let mut registry = CommandRegistry::new();
        registry.register_custom_macros(custom_macros);
        let tokens = tokenize(tex).unwrap();
        let expanded_tokens = registry.expand_macros(&tokens).unwrap();
        assert_eq!(expanded_tokens, tokenize(r"\texttt{sym} \expanded{a} \expanded{a}\and{b} \expanded{def1} \expanded{a} \expanded{def2}\and{b} \expanded{a}\and{b}").unwrap());
    }

    // Square brackets nested inside an optional argument are balanced, and an
    // empty default value (`[]`) is accepted.
    #[test]
    fn test_recurive_square_brackets() {
        let macro_string = r"\newcommand{\pp}[2][]{\expanded{#1}{#2}}";
        let tex = r"\pp[f[x]]{x}";

        let custom_macros = parse_custom_macros(macro_string);

        let mut registry = CommandRegistry::new();
        registry.register_custom_macros(custom_macros.unwrap());
        let tokens = tokenize(tex).unwrap();
        let expaned = registry.expand_macros(&tokens).unwrap();
        assert_eq!(expaned, tokenize(r"\expanded{f[x]}{x}").unwrap());
    }
}
776}