rush_parser/
lexer.rs

1use std::{mem, str::Chars};
2
3use crate::{Error, Location, Result, Token, TokenKind};
4
5pub trait Lex<'src> {
6    fn next_token(&mut self) -> Result<'src, Token<'src>>;
7    fn source(&self) -> &'src str;
8}
9
10pub struct Lexer<'src> {
11    input: &'src str,
12    reader: Chars<'src>,
13    location: Location<'src>,
14    curr_char: Option<char>,
15    next_char: Option<char>,
16}
17
/// Returns from the surrounding function with the token built by
/// `Lexer::make_char_construct` for an operator character.
///
/// The arguments after the lexer name the `TokenKind` variants for the forms `c`, `c=`,
/// `cc` and `cc=` of the operator character `c`; `_` marks a form which does not exist.
macro_rules! char_construct {
    // internal rule: `_` stands for a missing operator form
    (@optional _) => {
        None
    };
    // internal rule: any identifier is the name of a `TokenKind` variant
    (@optional $variant:ident) => {
        Some(TokenKind::$variant)
    };
    (
        $lexer:ident,
        $single:ident,
        $with_eq:tt,
        $double:tt,
        $double_with_eq:tt $(,)?
    ) => {
        return Ok($lexer.make_char_construct(
            TokenKind::$single,
            char_construct!(@optional $with_eq),
            char_construct!(@optional $double),
            char_construct!(@optional $double_with_eq),
        ))
    };
}
30
31impl<'src> Lex<'src> for Lexer<'src> {
32    fn next_token(&mut self) -> Result<'src, Token<'src>> {
33        // skip comments, whitespaces and newlines
34        loop {
35            match (self.curr_char, self.next_char) {
36                (Some(' ' | '\t' | '\n' | '\r'), _) => self.next(),
37                (Some('/'), Some('/')) => self.skip_line_comment(),
38                (Some('/'), Some('*')) => self.skip_block_comment(),
39                _ => break,
40            }
41        }
42        let start_loc = self.location;
43        let kind = match self.curr_char {
44            None => TokenKind::Eof,
45            Some('\'') => return self.make_char(),
46            Some('(') => TokenKind::LParen,
47            Some(')') => TokenKind::RParen,
48            Some('{') => TokenKind::LBrace,
49            Some('}') => TokenKind::RBrace,
50            Some(',') => TokenKind::Comma,
51            Some(':') => TokenKind::Colon,
52            Some(';') => TokenKind::Semicolon,
53            Some('!') => char_construct!(self, Not, Neq, _, _),
54            Some('-') if self.next_char == Some('>') => {
55                self.next();
56                TokenKind::Arrow
57            }
58            Some('-') => char_construct!(self, Minus, MinusAssign, _, _),
59            Some('+') => char_construct!(self, Plus, PlusAssign, _, _),
60            Some('*') => char_construct!(self, Star, MulAssign, Pow, PowAssign),
61            Some('/') => char_construct!(self, Slash, DivAssign, _, _),
62            Some('%') => char_construct!(self, Percent, RemAssign, _, _),
63            Some('=') => char_construct!(self, Assign, Eq, _, _),
64            Some('<') => char_construct!(self, Lt, Lte, Shl, ShlAssign),
65            Some('>') => char_construct!(self, Gt, Gte, Shr, ShrAssign),
66            Some('|') => char_construct!(self, BitOr, BitOrAssign, Or, _),
67            Some('&') => char_construct!(self, BitAnd, BitAndAssign, And, _),
68            Some('^') => char_construct!(self, BitXor, BitXorAssign, _, _),
69            Some(char) if char.is_ascii_digit() => return self.make_number(),
70            Some(char) if char.is_ascii_alphabetic() || char == '_' => return Ok(self.make_name()),
71            Some(char) => {
72                self.next();
73                return Err(Error::new_boxed(
74                    format!("illegal character `{char}`"),
75                    start_loc.until(self.location),
76                    self.input,
77                ));
78            }
79        };
80        self.next();
81        Ok(kind.spanned(start_loc.until(self.location)))
82    }
83
84    fn source(&self) -> &'src str {
85        self.input
86    }
87}
88
89impl<'src> Lexer<'src> {
90    pub fn new(text: &'src str, path: &'src str) -> Self {
91        let mut lexer = Self {
92            input: text,
93            reader: text.chars(),
94            location: Location::new(path),
95            curr_char: None,
96            next_char: None,
97        };
98        // advance the lexer twice so that curr_char and next_char are populated
99        lexer.next();
100        lexer.next();
101        lexer
102    }
103
104    fn next(&mut self) {
105        if let Some(current_char) = self.curr_char {
106            self.location.advance(
107                current_char == '\n',
108                // byte count is specified because advance does not know about the current char
109                current_char.len_utf8(),
110            );
111        }
112        // swap the current and next char so that the old next is the new current
113        mem::swap(&mut self.curr_char, &mut self.next_char);
114        self.next_char = self.reader.next()
115    }
116
117    fn skip_line_comment(&mut self) {
118        self.next();
119        self.next();
120        while !matches!(self.curr_char, Some('\n') | None) {
121            self.next()
122        }
123        self.next();
124    }
125
126    fn skip_block_comment(&mut self) {
127        self.next();
128        self.next();
129        loop {
130            match (self.curr_char, self.next_char) {
131                // end of block comment
132                (Some('*'), Some('/')) => {
133                    self.next();
134                    self.next();
135                    break;
136                }
137                // any char in comment
138                (Some(_), _) => self.next(),
139                // end of file
140                _ => break,
141            }
142        }
143    }
144
145    fn make_char_construct(
146        &mut self,
147        kind_single: TokenKind<'src>,
148        kind_with_eq: Option<TokenKind<'src>>,
149        kind_double: Option<TokenKind<'src>>,
150        kind_double_with_eq: Option<TokenKind<'src>>,
151    ) -> Token<'src> {
152        let start_loc = self.location;
153        let char = self
154            .curr_char
155            .expect("this should only be called when self.curr_char is Some(_)");
156        self.next();
157        match (
158            kind_with_eq,
159            &kind_double,
160            &kind_double_with_eq,
161            self.curr_char,
162        ) {
163            (Some(kind), .., Some('=')) => {
164                self.next();
165                kind.spanned(start_loc.until(self.location))
166            }
167            (_, Some(_), _, Some(current_char)) | (_, _, Some(_), Some(current_char))
168                if current_char == char =>
169            {
170                self.next();
171                match (kind_double, kind_double_with_eq, self.curr_char) {
172                    (_, Some(kind), Some('=')) => {
173                        self.next();
174                        kind.spanned(start_loc.until(self.location))
175                    }
176                    (Some(kind), ..) => kind.spanned(start_loc.until(self.location)),
177                    // can panic when all this is true:
178                    // - `kind_double` is `None`
179                    // - `kind_double_with_eq` is `Some(_)`
180                    // - `self.curr_char` is not `Some('=')`
181                    // however, this function is never called in that context
182                    _ => unreachable!(),
183                }
184            }
185            _ => kind_single.spanned(start_loc.until(self.location)),
186        }
187    }
188
189    fn make_char(&mut self) -> Result<'src, Token<'src>> {
190        let start_loc = self.location;
191        self.next();
192
193        let char = match self.curr_char {
194            None => {
195                self.next();
196                return Err(Error::new_boxed(
197                    "unterminated char literal".to_string(),
198                    start_loc.until(self.location),
199                    self.input,
200                ));
201            }
202            Some('\\') => {
203                let char = match self.next_char {
204                    Some('\\') => b'\\',
205                    Some('\'') => b'\'',
206                    Some('b') => b'\x08',
207                    Some('n') => b'\n',
208                    Some('r') => b'\r',
209                    Some('t') => b'\t',
210                    Some('x') => {
211                        self.next();
212                        self.next();
213                        let start_hex = self.location.byte_idx;
214                        for i in 0..2 {
215                            if !self.curr_char.map_or(false, |c| c.is_ascii_hexdigit()) {
216                                return Err(Error::new_boxed(
217                                    format!("expected 2 hexadecimal digits, found {i}"),
218                                    start_loc.until(self.location),
219                                    self.input,
220                                ));
221                            }
222                            self.next();
223                        }
224                        return match self.curr_char {
225                            Some('\'') => {
226                                let char = u8::from_str_radix(
227                                    &self.input[start_hex..self.location.byte_idx],
228                                    16,
229                                )
230                                .expect("This string slice should be valid hexadecimal");
231                                self.next();
232                                Ok(Token::new(
233                                    TokenKind::Char(char),
234                                    start_loc.until(self.location),
235                                ))
236                            }
237                            _ => {
238                                self.next();
239                                Err(Error::new_boxed(
240                                    "unterminated char literal".to_string(),
241                                    start_loc.until(self.location),
242                                    self.input,
243                                ))
244                            }
245                        };
246                    }
247                    _ => {
248                        self.next();
249                        return Err(Error::new_boxed(
250                            format!(
251                                "expected escape character, found {}",
252                                self.curr_char.map_or("EOF".to_string(), |c| c.to_string())
253                            ),
254                            start_loc.until(self.location),
255                            self.input,
256                        ));
257                    }
258                };
259                self.next();
260                char
261            }
262            Some(char) if char.is_ascii() => char as u8,
263            Some(char) => {
264                self.next();
265                return Err(Error::new_boxed(
266                    format!("character `{char}` is not in ASCII range"),
267                    start_loc.until(self.location),
268                    self.input,
269                ));
270            }
271        };
272        self.next();
273        match self.curr_char {
274            Some('\'') => {
275                self.next();
276                Ok(Token::new(
277                    TokenKind::Char(char),
278                    start_loc.until(self.location),
279                ))
280            }
281            _ => {
282                self.next();
283                Err(Error::new_boxed(
284                    "unterminated char literal".to_string(),
285                    start_loc.until(self.location),
286                    self.input,
287                ))
288            }
289        }
290    }
291
292    fn make_number(&mut self) -> Result<'src, Token<'src>> {
293        let start_loc = self.location;
294
295        if self.curr_char == Some('0') && self.next_char == Some('x') {
296            self.next();
297            self.next();
298            let start_hex = self.location.byte_idx;
299
300            if !self.curr_char.map_or(false, |c| c.is_ascii_hexdigit()) {
301                self.next();
302                return Err(Error::new_boxed(
303                    "expected at least one hexadecimal digit".to_string(),
304                    start_loc.until(self.location),
305                    self.input,
306                ));
307            }
308
309            while self
310                .curr_char
311                .map_or(false, |c| c.is_ascii_hexdigit() || c == '_')
312            {
313                self.next();
314            }
315
316            let num = match i64::from_str_radix(
317                &self.input[start_hex..self.location.byte_idx].replace('_', ""),
318                16,
319            ) {
320                Ok(num) => num,
321                Err(_) => {
322                    return Err(Error::new_boxed(
323                        "integer too large for 64 bits".to_string(),
324                        start_loc.until(self.location),
325                        self.input,
326                    ))
327                }
328            };
329
330            return Ok(TokenKind::Int(num).spanned(start_loc.until(self.location)));
331        }
332
333        while self
334            .curr_char
335            .map_or(false, |c| c.is_ascii_digit() || c == '_')
336        {
337            self.next();
338        }
339
340        match self.curr_char {
341            Some('.') => {
342                self.next();
343
344                if !self.curr_char.map_or(false, |c| c.is_ascii_digit()) {
345                    let err_start = self.location;
346                    self.next();
347                    return Err(Error::new_boxed(
348                        format!(
349                            "expected digit, found `{}`",
350                            self.curr_char.map_or("EOF".to_string(), |c| c.to_string())
351                        ),
352                        err_start.until(self.location),
353                        self.input,
354                    ));
355                }
356
357                while self
358                    .curr_char
359                    .map_or(false, |c| c.is_ascii_digit() || c == '_')
360                {
361                    self.next();
362                }
363
364                let float = self.input[start_loc.byte_idx..self.location.byte_idx]
365                    .replace('_', "")
366                    .parse()
367                    .expect("The grammar guarantees correctly formed float literals");
368                Ok(Token::new(
369                    TokenKind::Float(float),
370                    start_loc.until(self.location),
371                ))
372            }
373            Some('f') => {
374                let float = self.input[start_loc.byte_idx..self.location.byte_idx]
375                    .replace('_', "")
376                    .parse()
377                    .expect("The grammar guarantees correctly formed float literals");
378                self.next();
379                Ok(Token::new(
380                    TokenKind::Float(float),
381                    start_loc.until(self.location),
382                ))
383            }
384            _ => {
385                let int = match self.input[start_loc.byte_idx..self.location.byte_idx]
386                    .replace('_', "")
387                    .parse()
388                {
389                    Ok(value) => value,
390                    Err(_) => {
391                        return Err(Error::new_boxed(
392                            "integer too large for 64 bits".to_string(),
393                            start_loc.until(self.location),
394                            self.input,
395                        ))
396                    }
397                };
398                Ok(Token::new(
399                    TokenKind::Int(int),
400                    start_loc.until(self.location),
401                ))
402            }
403        }
404    }
405
406    fn make_name(&mut self) -> Token<'src> {
407        let start_loc = self.location;
408        while self.curr_char.map_or(false, |c| {
409            c.is_ascii_alphabetic() || c.is_ascii_digit() || c == '_'
410        }) {
411            self.next()
412        }
413        let kind = match &self.input[start_loc.byte_idx..self.location.byte_idx] {
414            "true" => TokenKind::True,
415            "false" => TokenKind::False,
416            "fn" => TokenKind::Fn,
417            "let" => TokenKind::Let,
418            "mut" => TokenKind::Mut,
419            "return" => TokenKind::Return,
420            "loop" => TokenKind::Loop,
421            "while" => TokenKind::While,
422            "for" => TokenKind::For,
423            "break" => TokenKind::Break,
424            "continue" => TokenKind::Continue,
425            "if" => TokenKind::If,
426            "else" => TokenKind::Else,
427            "as" => TokenKind::As,
428            ident => TokenKind::Ident(ident),
429        };
430        kind.spanned(start_loc.until(self.location))
431    }
432}
433
434#[cfg(test)]
435mod tests {
436    use super::*;
437
438    #[test]
439    fn single_tokens() {
440        let tests = [
441            // Chars
442            ("'a'", Ok(TokenKind::Char(b'a').spanned(span!(0..3)))),
443            ("'*'", Ok(TokenKind::Char(b'*').spanned(span!(0..3)))),
444            ("'_'", Ok(TokenKind::Char(b'_').spanned(span!(0..3)))),
445            (r"'\'", Err("unterminated char literal")),
446            (r"'\\'", Ok(TokenKind::Char(b'\\').spanned(span!(0..4)))),
447            (r"'\a'", Err("expected escape character, found a")),
448            (r"'\x1b'", Ok(TokenKind::Char(b'\x1b').spanned(span!(0..6)))),
449            (r"'\x1b1'", Err("unterminated char literal")),
450            // Keywords
451            ("true", Ok(TokenKind::True.spanned(span!(0..4)))),
452            ("false", Ok(TokenKind::False.spanned(span!(0..5)))),
453            ("fn", Ok(TokenKind::Fn.spanned(span!(0..2)))),
454            ("let", Ok(TokenKind::Let.spanned(span!(0..3)))),
455            ("mut", Ok(TokenKind::Mut.spanned(span!(0..3)))),
456            ("return", Ok(TokenKind::Return.spanned(span!(0..6)))),
457            ("loop", Ok(TokenKind::Loop.spanned(span!(0..4)))),
458            ("while", Ok(TokenKind::While.spanned(span!(0..5)))),
459            ("for", Ok(TokenKind::For.spanned(span!(0..3)))),
460            ("break", Ok(TokenKind::Break.spanned(span!(0..5)))),
461            ("continue", Ok(TokenKind::Continue.spanned(span!(0..8)))),
462            ("if", Ok(TokenKind::If.spanned(span!(0..2)))),
463            ("else", Ok(TokenKind::Else.spanned(span!(0..4)))),
464            ("as", Ok(TokenKind::As.spanned(span!(0..2)))),
465            // Identifiers
466            ("foo", Ok(TokenKind::Ident("foo").spanned(span!(0..3)))),
467            ("_foo", Ok(TokenKind::Ident("_foo").spanned(span!(0..4)))),
468            ("f_0o", Ok(TokenKind::Ident("f_0o").spanned(span!(0..4)))),
469            // Numbers
470            ("1", Ok(TokenKind::Int(1).spanned(span!(0..1)))),
471            ("0x1b", Ok(TokenKind::Int(0x1b).spanned(span!(0..4)))),
472            ("42", Ok(TokenKind::Int(42).spanned(span!(0..2)))),
473            ("42f", Ok(TokenKind::Float(42.0).spanned(span!(0..3)))),
474            ("3.1", Ok(TokenKind::Float(3.1).spanned(span!(0..3)))),
475            (
476                "42.12345678",
477                Ok(TokenKind::Float(42.12345678).spanned(span!(0..11))),
478            ),
479            ("42.69", Ok(TokenKind::Float(42.69).spanned(span!(0..5)))),
480            // Parenthesis
481            ("(", Ok(TokenKind::LParen.spanned(span!(0..1)))),
482            (")", Ok(TokenKind::RParen.spanned(span!(0..1)))),
483            ("{", Ok(TokenKind::LBrace.spanned(span!(0..1)))),
484            ("}", Ok(TokenKind::RBrace.spanned(span!(0..1)))),
485            // Punctuation and delimiters
486            ("->", Ok(TokenKind::Arrow.spanned(span!(0..2)))),
487            (",", Ok(TokenKind::Comma.spanned(span!(0..1)))),
488            (":", Ok(TokenKind::Colon.spanned(span!(0..1)))),
489            (";", Ok(TokenKind::Semicolon.spanned(span!(0..1)))),
490            // Operators
491            ("!", Ok(TokenKind::Not.spanned(span!(0..1)))),
492            ("-", Ok(TokenKind::Minus.spanned(span!(0..1)))),
493            ("+", Ok(TokenKind::Plus.spanned(span!(0..1)))),
494            ("*", Ok(TokenKind::Star.spanned(span!(0..1)))),
495            ("/", Ok(TokenKind::Slash.spanned(span!(0..1)))),
496            ("%", Ok(TokenKind::Percent.spanned(span!(0..1)))),
497            ("**", Ok(TokenKind::Pow.spanned(span!(0..2)))),
498            ("==", Ok(TokenKind::Eq.spanned(span!(0..2)))),
499            ("!=", Ok(TokenKind::Neq.spanned(span!(0..2)))),
500            ("<", Ok(TokenKind::Lt.spanned(span!(0..1)))),
501            (">", Ok(TokenKind::Gt.spanned(span!(0..1)))),
502            ("<=", Ok(TokenKind::Lte.spanned(span!(0..2)))),
503            (">=", Ok(TokenKind::Gte.spanned(span!(0..2)))),
504            ("<<", Ok(TokenKind::Shl.spanned(span!(0..2)))),
505            (">>", Ok(TokenKind::Shr.spanned(span!(0..2)))),
506            ("|", Ok(TokenKind::BitOr.spanned(span!(0..1)))),
507            ("&", Ok(TokenKind::BitAnd.spanned(span!(0..1)))),
508            ("^", Ok(TokenKind::BitXor.spanned(span!(0..1)))),
509            ("&&", Ok(TokenKind::And.spanned(span!(0..2)))),
510            ("||", Ok(TokenKind::Or.spanned(span!(0..2)))),
511            // Assignments
512            ("=", Ok(TokenKind::Assign.spanned(span!(0..1)))),
513            ("+=", Ok(TokenKind::PlusAssign.spanned(span!(0..2)))),
514            ("-=", Ok(TokenKind::MinusAssign.spanned(span!(0..2)))),
515            ("*=", Ok(TokenKind::MulAssign.spanned(span!(0..2)))),
516            ("/=", Ok(TokenKind::DivAssign.spanned(span!(0..2)))),
517            ("%=", Ok(TokenKind::RemAssign.spanned(span!(0..2)))),
518            ("**=", Ok(TokenKind::PowAssign.spanned(span!(0..3)))),
519            ("<<=", Ok(TokenKind::ShlAssign.spanned(span!(0..3)))),
520            (">>=", Ok(TokenKind::ShrAssign.spanned(span!(0..3)))),
521            ("|=", Ok(TokenKind::BitOrAssign.spanned(span!(0..2)))),
522            ("&=", Ok(TokenKind::BitAndAssign.spanned(span!(0..2)))),
523            ("^=", Ok(TokenKind::BitXorAssign.spanned(span!(0..2)))),
524        ];
525        println!();
526        for (input, expected) in tests {
527            let mut lexer = Lexer::new(input, "");
528            let res = lexer.next_token();
529            match (res, expected) {
530                (Ok(_), Err(expected)) => panic!("Expected error: {:?}, got none", expected),
531                (Err(err), Ok(_)) => panic!("Unexpected error: {:?}", err),
532                (Err(got), Err(expected)) => assert_eq!(expected, got.message),
533                (Ok(got), Ok(expected)) => {
534                    match got.kind {
535                        TokenKind::Char(ch) => {
536                            println!("found char: {} ({ch})", got.kind)
537                        }
538                        _ => println!("{:?}", got),
539                    }
540                    assert_eq!(expected, got)
541                }
542            }
543        }
544    }
545
546    impl<'src> Iterator for Lexer<'src> {
547        type Item = Result<'src, Token<'src>>;
548
549        fn next(&mut self) -> Option<Self::Item> {
550            match self.next_token() {
551                Ok(Token {
552                    kind: TokenKind::Eof,
553                    span: _,
554                }) => None,
555                item => Some(item),
556            }
557        }
558    }
559
560    #[test]
561    fn call_expr() {
562        let lexer = Lexer::new("exit(1 + 3);", "");
563        assert_eq!(
564            lexer.collect::<Result<Vec<_>>>(),
565            Ok(vec![
566                TokenKind::Ident("exit").spanned(span!(0..4)),
567                TokenKind::LParen.spanned(span!(4..5)),
568                TokenKind::Int(1).spanned(span!(5..6)),
569                TokenKind::Plus.spanned(span!(7..8)),
570                TokenKind::Int(3).spanned(span!(9..10)),
571                TokenKind::RParen.spanned(span!(10..11)),
572                TokenKind::Semicolon.spanned(span!(11..12)),
573            ])
574        );
575    }
576}