//! tupa_lexer — lexer for the Tupa language (`lib.rs`).
1use nom::{
2    branch::alt,
3    bytes::complete::{escaped_transform, is_not, tag, take_while, take_while1},
4    character::complete::{char, digit1},
5    combinator::{map, opt, recognize, value},
6    sequence::{delimited, pair, tuple},
7    IResult,
8};
9use serde::{Deserialize, Serialize};
10use thiserror::Error;
11
/// A lexical token of the Tupa language, as produced by [`lex`] /
/// [`lex_with_spans`].
///
/// Numeric literal variants (`Int`, `Float`) carry the raw source text so
/// the parser decides how to interpret it; `Str` carries the contents with
/// escape sequences already resolved (see `string_lit`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Token {
    // Keywords.
    Fn,
    Enum,
    Trait,
    Pipeline,
    Step,
    Let,
    Return,
    If,
    Else,
    Match,
    While,
    For,
    Break,
    Continue,
    In,
    Await,
    True,
    False,
    Null,
    // Literals and identifiers (raw source text, except `Str` which is
    // unescaped).
    Ident(String),
    Int(String),
    Float(String),
    Str(String),
    // Delimiters and punctuation.
    LParen,
    RParen,
    LBrace,
    RBrace,
    LBracket,
    RBracket,
    Semicolon,
    Comma,
    Colon,
    Equal,
    Arrow,      // `=>`
    ThinArrow,  // `->`
    // Comparison and logical operators.
    EqualEqual,
    BangEqual,
    Less,
    LessEqual,
    Greater,
    GreaterEqual,
    AndAnd,
    OrOr,
    // Arithmetic operators and compound assignments.
    Plus,
    PlusEqual,
    Minus,
    MinusEqual,
    Star,
    StarEqual,
    Slash,
    SlashEqual,
    DoubleStar, // `**`
    DotDot,     // `..`
    Dot,
    Bang,
    Pipe,
    At,
    Percent,
    PercentEqual,
}
74
/// Byte range of a token within the original input string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct Span {
    /// Byte offset of the token's first byte (inclusive).
    pub start: usize,
    /// Byte offset one past the token's last byte (exclusive).
    pub end: usize,
}
80
/// A token paired with its source location.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenSpan {
    /// The lexed token.
    pub token: Token,
    /// Where the token appeared in the input.
    pub span: Span,
}
86
87#[derive(Debug, Error, Serialize, Deserialize)]
88pub enum LexerError {
89    #[error("unexpected character '{0}'")]
90    Unexpected(char, usize),
91}
92
93pub fn lex(input: &str) -> Result<Vec<Token>, LexerError> {
94    let tokens = lex_with_spans(input)?;
95    Ok(tokens.into_iter().map(|t| t.token).collect())
96}
97
/// Lex `input` into tokens, each annotated with its byte span.
///
/// Offsets are recovered from slice lengths: since `rest` is always a
/// suffix of `input`, the current position is `input.len() - rest.len()`.
///
/// # Errors
///
/// Returns [`LexerError::Unexpected`] on the first character that starts
/// no token, and also on two malformed float shapes that would otherwise
/// lex as separate tokens: a leading-dot float like `.5`, and an integer
/// immediately followed by a bare `.` (e.g. `1.`); `..` is exempt so
/// range expressions like `1..2` still work.
pub fn lex_with_spans(input: &str) -> Result<Vec<TokenSpan>, LexerError> {
    let mut rest = input;
    let mut tokens = Vec::new();

    loop {
        rest = skip_ws_and_comments(rest);
        if rest.is_empty() {
            break;
        }

        // Reject leading-dot floats (`.5`) before general tokenization,
        // but let `..` through (it is the range operator).
        if rest.starts_with('.')
            && !rest.starts_with("..")
            && rest.chars().nth(1).is_some_and(|c| c.is_ascii_digit())
        {
            let pos = input.len().saturating_sub(rest.len());
            return Err(LexerError::Unexpected('.', pos));
        }

        let start = input.len().saturating_sub(rest.len());
        match token(rest) {
            Ok((next, tok)) => {
                // An integer directly followed by a lone `.` is a malformed
                // float like `1.` — reject rather than emitting Int + Dot.
                if matches!(tok, Token::Int(_)) && next.starts_with('.') && !next.starts_with("..")
                {
                    let pos = input.len().saturating_sub(next.len());
                    return Err(LexerError::Unexpected(ch_at(input, next), pos));
                }
                let end = input.len().saturating_sub(next.len());
                tokens.push(TokenSpan {
                    token: tok,
                    span: Span { start, end },
                });
                rest = next;
            }
            Err(_) => {
                // No parser matched: report the offending character.
                let pos = input.len().saturating_sub(rest.len());
                let ch = rest.chars().next().unwrap();
                return Err(LexerError::Unexpected(ch, pos));
            }
        }
    }

    Ok(tokens)
}
141
/// Advance past leading whitespace, `//` line comments, and (non-nesting)
/// `/* ... */` block comments, returning the remaining input.
///
/// A line comment with no trailing newline, or an unterminated block
/// comment, swallows the rest of the input: the caller then sees an empty
/// string rather than an error.
fn skip_ws_and_comments(input: &str) -> &str {
    let mut cur = input.trim_start();
    loop {
        if cur.starts_with("//") {
            // Line comment: resume just past the newline, if there is one.
            match cur.find('\n') {
                Some(nl) => cur = cur[nl + 1..].trim_start(),
                None => return "",
            }
        } else if cur.starts_with("/*") {
            // Block comment: resume just past the closing `*/`. The search
            // deliberately starts at the opening `/*`, matching the
            // original scan.
            match cur.find("*/") {
                Some(end) => cur = cur[end + 2..].trim_start(),
                None => return "",
            }
        } else {
            return cur;
        }
    }
}
166
/// Parse a single token from the front of `input`, trying in order:
/// literals, operator/punctuation symbols, then identifiers/keywords.
fn token(input: &str) -> IResult<&str, Token> {
    alt((literal, symbol, ident_or_keyword))(input)
}
170
171fn ident_or_keyword(input: &str) -> IResult<&str, Token> {
172    map(
173        recognize(pair(
174            alt((
175                take_while1(|c: char| c.is_alphabetic() || c == '_'),
176                tag("_"),
177            )),
178            take_while(|c: char| c.is_alphanumeric() || c == '_'),
179        )),
180        |s: &str| match s {
181            "fn" => Token::Fn,
182            "enum" => Token::Enum,
183            "trait" => Token::Trait,
184            "pipeline" => Token::Pipeline,
185            "step" => Token::Step,
186            "let" => Token::Let,
187            "return" => Token::Return,
188            "if" => Token::If,
189            "else" => Token::Else,
190            "match" => Token::Match,
191            "while" => Token::While,
192            "for" => Token::For,
193            "break" => Token::Break,
194            "continue" => Token::Continue,
195            "in" => Token::In,
196            "await" => Token::Await,
197            "true" => Token::True,
198            "false" => Token::False,
199            "null" => Token::Null,
200            _ => Token::Ident(s.to_string()),
201        },
202    )(input)
203}
204
/// Parse a literal token (float, integer, or string).
///
/// `float_lit` must be tried before `int_lit`: otherwise `1.5` would lex
/// as the integer `1` with `.5` left over.
fn literal(input: &str) -> IResult<&str, Token> {
    alt((
        map(float_lit, Token::Float),
        map(int_lit, Token::Int),
        map(string_lit, Token::Str),
    ))(input)
}
212
/// Parse an operator or punctuation token; two-character symbols are tried
/// first so that e.g. `==` is not lexed as two `=` tokens.
fn symbol(input: &str) -> IResult<&str, Token> {
    alt((symbol_two_char, symbol_one_char))(input)
}
216
/// Parse a two-character operator. All tags are distinct two-byte strings,
/// so ordering within this table is not significant.
fn symbol_two_char(input: &str) -> IResult<&str, Token> {
    alt((
        value(Token::ThinArrow, tag("->")),
        value(Token::Arrow, tag("=>")),
        value(Token::EqualEqual, tag("==")),
        value(Token::BangEqual, tag("!=")),
        value(Token::LessEqual, tag("<=")),
        value(Token::GreaterEqual, tag(">=")),
        value(Token::AndAnd, tag("&&")),
        value(Token::OrOr, tag("||")),
        value(Token::PlusEqual, tag("+=")),
        value(Token::MinusEqual, tag("-=")),
        value(Token::StarEqual, tag("*=")),
        value(Token::SlashEqual, tag("/=")),
        value(Token::DoubleStar, tag("**")),
        value(Token::DotDot, tag("..")),
        value(Token::PercentEqual, tag("%=")),
    ))(input)
}
236
/// Parse a single-character delimiter, operator, or punctuation token.
/// Only reached after `symbol_two_char` fails, so e.g. `=` here is a lone
/// assignment, not the start of `==`.
fn symbol_one_char(input: &str) -> IResult<&str, Token> {
    alt((
        value(Token::LParen, char('(')),
        value(Token::RParen, char(')')),
        value(Token::LBrace, char('{')),
        value(Token::RBrace, char('}')),
        value(Token::LBracket, char('[')),
        value(Token::RBracket, char(']')),
        value(Token::Semicolon, char(';')),
        value(Token::Comma, char(',')),
        value(Token::Colon, char(':')),
        value(Token::Equal, char('=')),
        value(Token::Less, char('<')),
        value(Token::Greater, char('>')),
        value(Token::Plus, char('+')),
        value(Token::Minus, char('-')),
        value(Token::Star, char('*')),
        value(Token::Slash, char('/')),
        value(Token::Dot, char('.')),
        value(Token::Bang, char('!')),
        value(Token::Pipe, char('|')),
        value(Token::At, char('@')),
        value(Token::Percent, char('%')),
    ))(input)
}
262
263fn int_lit(input: &str) -> IResult<&str, String> {
264    map(recognize(digit1), |s: &str| s.to_string())(input)
265}
266
/// Parse a float literal: `digit+ '.' digit+`, with an optional
/// exponent (`e`/`E`, optional `+`/`-` sign, digits).
///
/// Digits are required on both sides of the dot; the dangling forms `.5`
/// and `1.` are rejected by `lex_with_spans`. An exponent-only form like
/// `1e5` is not a float here (it would lex as `1` then `e5` — presumably
/// intentional; confirm against the language spec).
fn float_lit(input: &str) -> IResult<&str, String> {
    map(
        recognize(tuple((
            digit1,
            char('.'),
            digit1,
            opt(tuple((
                alt((char('e'), char('E'))),
                opt(alt((char('+'), char('-')))),
                digit1,
            ))),
        ))),
        |s: &str| s.to_string(),
    )(input)
}
282
/// Parse a double-quoted string literal, resolving escape sequences.
///
/// Supported escapes: `\\`, `\"`, `\n`, `\r`, `\t`. The result is the
/// unescaped contents. `escaped_transform` needs at least one character,
/// so the body is wrapped in `opt` to accept the empty string `""`.
fn string_lit(input: &str) -> IResult<&str, String> {
    delimited(
        char('"'),
        map(
            opt(escaped_transform(
                is_not("\\\""),
                '\\',
                alt((
                    value("\\", tag("\\")),
                    value("\"", tag("\"")),
                    value("\n", tag("n")),
                    value("\r", tag("r")),
                    value("\t", tag("t")),
                )),
            )),
            |s| s.unwrap_or_default(),
        ),
        char('"'),
    )(input)
}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Keywords map to their dedicated variants; unknown words become
    /// `Token::Ident`.
    #[test]
    fn lex_keywords_and_idents() {
        let tokens =
            lex("fn let if else match while for in return await true false null foo bar").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::Fn,
                Token::Let,
                Token::If,
                Token::Else,
                Token::Match,
                Token::While,
                Token::For,
                Token::In,
                Token::Return,
                Token::Await,
                Token::True,
                Token::False,
                Token::Null,
                Token::Ident("foo".to_string()),
                Token::Ident("bar".to_string()),
            ]
        );
    }
}
333}