use crate::token::{Span, Token, TokenKind};
use std::str::Chars;

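/// A hand-written lexer that walks the source text character by character,
/// producing `Token`s whose spans record byte offsets along with 1-based
/// line and column positions.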
pub struct Lexer<'a> {
    /// The full source text being lexed.
    source: &'a str,
    /// Iterator over the remaining characters of `source`.
    chars: Chars<'a>,
    /// Current byte offset into `source`.
    pos: usize,
    /// Current line number (1-based).
    line: u32,
    /// Current column number (1-based).
    column: u32,
    /// Byte offset at which the token currently being lexed starts.
    token_start: usize,
    /// Line on which the current token starts.
    token_start_line: u32,
    /// Column on which the current token starts.
    token_start_column: u32,
}

impl<'a> Lexer<'a> {
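    /// Creates a lexer positioned at the start of `source`.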
    pub fn new(source: &'a str) -> Self {
        Self {
            source,
            chars: source.chars(),
            pos: 0,
            line: 1,
            column: 1,
            token_start: 0,
            token_start_line: 1,
            token_start_column: 1,
        }
    }

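    /// Consumes the lexer and returns every token in the source, always
    /// ending with a single `Eof` token.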
    pub fn tokenize(mut self) -> Vec<Token> {
        let mut tokens = Vec::new();
        loop {
            let token = self.next_token();
            let is_eof = token.is_eof();
            tokens.push(token);
            if is_eof {
                break;
            }
        }
        tokens
    }

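    /// Scans and returns the next token. Leading whitespace is skipped first,
    /// then the current character selects comment, string, number,
    /// identifier/keyword, or operator lexing; end of input produces `Eof`.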
    pub fn next_token(&mut self) -> Token {
        self.skip_whitespace();
        self.mark_token_start();

        let Some(c) = self.peek() else {
            return self.make_token(TokenKind::Eof);
        };

        if c == '/' && self.peek_next() == Some('/') {
            return self.lex_comment();
        }

        if c == '/' && self.peek_next() == Some('*') {
            return self.lex_multiline_comment();
        }

        if c == '"' {
            return self.lex_string();
        }

        if c.is_ascii_digit() {
            return self.lex_number();
        }

        if c.is_alphabetic() || c == '_' {
            return self.lex_identifier();
        }

        self.lex_operator_or_punctuation()
    }

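    /// Advances past any run of whitespace characters, including newlines.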
    fn skip_whitespace(&mut self) {
        while let Some(c) = self.peek() {
            if c.is_whitespace() {
                self.advance();
            } else {
                break;
            }
        }
    }

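    /// Records the current position as the start of the next token's span.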
    fn mark_token_start(&mut self) {
        self.token_start = self.pos;
        self.token_start_line = self.line;
        self.token_start_column = self.column;
    }

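    /// Returns the current character without consuming it.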
    fn peek(&self) -> Option<char> {
        self.chars.clone().next()
    }

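    /// Returns the character after the current one without consuming anything.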
    fn peek_next(&self) -> Option<char> {
        let mut chars = self.chars.clone();
        chars.next();
        chars.next()
    }

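    /// Consumes one character, updating the byte offset and the line/column
    /// counters.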
    fn advance(&mut self) -> Option<char> {
        let c = self.chars.next()?;
        self.pos += c.len_utf8();
        if c == '\n' {
            self.line += 1;
            self.column = 1;
        } else {
            self.column += 1;
        }
        Some(c)
    }

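    /// Builds a token of the given kind spanning from the marked start to the
    /// current position.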
    fn make_token(&self, kind: TokenKind) -> Token {
        Token::new(
            kind,
            Span::new(
                self.token_start,
                self.pos,
                self.token_start_line,
                self.token_start_column,
            ),
        )
    }

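    /// Returns the slice of source text covered by the token being lexed.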
    fn token_text(&self) -> &'a str {
        &self.source[self.token_start..self.pos]
    }

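    /// Lexes a `//` line comment; `///` produces a `DocComment` token. A
    /// single space after the comment marker is not included in the content.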
    fn lex_comment(&mut self) -> Token {
        self.advance();
        self.advance();

        let is_doc = self.peek() == Some('/');
        if is_doc {
            self.advance();
        }

        if self.peek() == Some(' ') {
            self.advance();
        }

        let content_start = self.pos;

        while let Some(c) = self.peek() {
            if c == '\n' {
                break;
            }
            self.advance();
        }

        let content = self.source[content_start..self.pos].to_string();

        if is_doc {
            self.make_token(TokenKind::DocComment(content))
        } else {
            self.make_token(TokenKind::Comment(content))
        }
    }

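    /// Lexes a `/* ... */` comment, with support for nesting. Hitting end of
    /// input before the final `*/` yields an `Error` token.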
    fn lex_multiline_comment(&mut self) -> Token {
        self.advance();
        self.advance();

        let content_start = self.pos;
        let mut depth = 1;

        while depth > 0 {
            match self.peek() {
                None => {
                    return self.make_token(TokenKind::Error(
                        "unterminated multi-line comment".to_string(),
                    ));
                }
                Some('*') if self.peek_next() == Some('/') => {
                    self.advance();
                    self.advance();
                    depth -= 1;
                }
                Some('/') if self.peek_next() == Some('*') => {
                    self.advance();
                    self.advance();
                    depth += 1;
                }
                Some(_) => {
                    self.advance();
                }
            }
        }

        let content = self.source[content_start..self.pos - 2].to_string();
        self.make_token(TokenKind::Comment(content))
    }

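    /// Lexes a double-quoted string literal with `\n`, `\t`, `\r`, `\\`, and
    /// `\"` escapes. Unterminated strings (including a newline before the
    /// closing quote) and unknown escapes yield `Error` tokens.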
    fn lex_string(&mut self) -> Token {
        self.advance();

        let mut content = String::new();
        loop {
            match self.peek() {
                None | Some('\n') => {
                    return self
                        .make_token(TokenKind::Error("unterminated string literal".to_string()));
                }
                Some('"') => {
                    self.advance();
                    break;
                }
                Some('\\') => {
                    self.advance();
                    match self.peek() {
                        Some('n') => {
                            content.push('\n');
                            self.advance();
                        }
                        Some('t') => {
                            content.push('\t');
                            self.advance();
                        }
                        Some('r') => {
                            content.push('\r');
                            self.advance();
                        }
                        Some('\\') => {
                            content.push('\\');
                            self.advance();
                        }
                        Some('"') => {
                            content.push('"');
                            self.advance();
                        }
                        Some(c) => {
                            return self.make_token(TokenKind::Error(format!(
                                "invalid escape sequence: \\{}",
                                c
                            )));
                        }
                        None => {
                            return self.make_token(TokenKind::Error(
                                "unterminated string literal".to_string(),
                            ));
                        }
                    }
                }
                Some(c) => {
                    content.push(c);
                    self.advance();
                }
            }
        }

        self.make_token(TokenKind::StringLit(content))
    }

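    /// Lexes a decimal integer literal; values that do not fit in an `i64`
    /// yield an `Error` token.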
    fn lex_number(&mut self) -> Token {
        while let Some(c) = self.peek() {
            if c.is_ascii_digit() {
                self.advance();
            } else {
                break;
            }
        }

        let text = self.token_text();
        match text.parse::<i64>() {
            Ok(n) => self.make_token(TokenKind::Integer(n)),
            Err(_) => self.make_token(TokenKind::Error(format!("invalid integer: {}", text))),
        }
    }

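    /// Lexes an identifier, emitting a keyword token when the text matches an
    /// entry in `TokenKind::keyword`.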
    fn lex_identifier(&mut self) -> Token {
        while let Some(c) = self.peek() {
            if c.is_alphanumeric() || c == '_' {
                self.advance();
            } else {
                break;
            }
        }

        let text = self.token_text();

        if let Some(keyword) = TokenKind::keyword(text) {
            self.make_token(keyword)
        } else {
            self.make_token(TokenKind::Ident(text.to_string()))
        }
    }

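    /// Lexes single- and two-character operators and punctuation, falling
    /// back to an `Error` token for unexpected characters.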
    fn lex_operator_or_punctuation(&mut self) -> Token {
        let c = self.advance().unwrap();

        match c {
            '(' => self.make_token(TokenKind::LParen),
            ')' => self.make_token(TokenKind::RParen),
            '{' => self.make_token(TokenKind::LBrace),
            '}' => self.make_token(TokenKind::RBrace),
            '[' => self.make_token(TokenKind::LBracket),
            ']' => self.make_token(TokenKind::RBracket),
            ',' => self.make_token(TokenKind::Comma),
            ':' => self.make_token(TokenKind::Colon),
            ';' => self.make_token(TokenKind::Semicolon),
            '\'' => self.make_token(TokenKind::Prime),
            '|' => self.make_token(TokenKind::Pipe),
            '&' => self.make_token(TokenKind::Ampersand),
            '%' => self.make_token(TokenKind::Percent),
            '*' => self.make_token(TokenKind::Star),
            '.' => {
                if self.peek() == Some('.') {
                    self.advance();
                    self.make_token(TokenKind::DotDot)
                } else {
                    self.make_token(TokenKind::Dot)
                }
            }
            '+' => {
                if self.peek() == Some('+') {
                    self.advance();
                    self.make_token(TokenKind::PlusPlus)
                } else {
                    self.make_token(TokenKind::Plus)
                }
            }
            '-' => {
                if self.peek() == Some('>') {
                    self.advance();
                    self.make_token(TokenKind::Arrow)
                } else {
                    self.make_token(TokenKind::Minus)
                }
            }
            '/' => self.make_token(TokenKind::Slash),
            '=' => {
                if self.peek() == Some('=') {
                    self.advance();
                    self.make_token(TokenKind::Eq)
                } else if self.peek() == Some('>') {
                    self.advance();
                    self.make_token(TokenKind::FatArrow)
                } else {
                    self.make_token(TokenKind::Assign)
                }
            }
            '!' => {
                if self.peek() == Some('=') {
                    self.advance();
                    self.make_token(TokenKind::Ne)
                } else {
                    self.make_token(TokenKind::Error(format!("unexpected character: {}", c)))
                }
            }
            '<' => {
                if self.peek() == Some('=') {
                    self.advance();
                    self.make_token(TokenKind::Le)
                } else {
                    self.make_token(TokenKind::Lt)
                }
            }
            '>' => {
                if self.peek() == Some('=') {
                    self.advance();
                    self.make_token(TokenKind::Ge)
                } else {
                    self.make_token(TokenKind::Gt)
                }
            }
            _ => self.make_token(TokenKind::Error(format!("unexpected character: {}", c))),
        }
    }
}

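// Iterating over a `Lexer` yields tokens lazily and stops before the `Eof` token.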
impl<'a> Iterator for Lexer<'a> {
    type Item = Token;

    fn next(&mut self) -> Option<Self::Item> {
        let token = self.next_token();
        if token.is_eof() {
            None
        } else {
            Some(token)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn lex(source: &str) -> Vec<TokenKind> {
        Lexer::new(source)
            .tokenize()
            .into_iter()
            .map(|t| t.kind)
            .collect()
    }

    #[test]
    fn test_empty() {
        assert_eq!(lex(""), vec![TokenKind::Eof]);
    }

    #[test]
    fn test_whitespace() {
        assert_eq!(lex(" \n\t "), vec![TokenKind::Eof]);
    }

    #[test]
    fn test_keywords() {
        assert_eq!(
            lex("module action init"),
            vec![
                TokenKind::Module,
                TokenKind::Action,
                TokenKind::Init,
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_identifiers() {
        assert_eq!(
            lex("foo bar_baz _private"),
            vec![
                TokenKind::Ident("foo".to_string()),
                TokenKind::Ident("bar_baz".to_string()),
                TokenKind::Ident("_private".to_string()),
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_numbers() {
        assert_eq!(
            lex("0 42 123456"),
            vec![
                TokenKind::Integer(0),
                TokenKind::Integer(42),
                TokenKind::Integer(123456),
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_strings() {
        assert_eq!(
            lex(r#""hello" "world""#),
            vec![
                TokenKind::StringLit("hello".to_string()),
                TokenKind::StringLit("world".to_string()),
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_string_escapes() {
        assert_eq!(
            lex(r#""line1\nline2" "tab\there" "quote\"end""#),
            vec![
                TokenKind::StringLit("line1\nline2".to_string()),
                TokenKind::StringLit("tab\there".to_string()),
                TokenKind::StringLit("quote\"end".to_string()),
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_operators() {
        assert_eq!(
            lex("== != < <= > >= + - * / %"),
            vec![
                TokenKind::Eq,
                TokenKind::Ne,
                TokenKind::Lt,
                TokenKind::Le,
                TokenKind::Gt,
                TokenKind::Ge,
                TokenKind::Plus,
                TokenKind::Minus,
                TokenKind::Star,
                TokenKind::Slash,
                TokenKind::Percent,
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_punctuation() {
        assert_eq!(
            lex("( ) { } [ ] , : ; . .. -> => ' |"),
            vec![
                TokenKind::LParen,
                TokenKind::RParen,
                TokenKind::LBrace,
                TokenKind::RBrace,
                TokenKind::LBracket,
                TokenKind::RBracket,
                TokenKind::Comma,
                TokenKind::Colon,
                TokenKind::Semicolon,
                TokenKind::Dot,
                TokenKind::DotDot,
                TokenKind::Arrow,
                TokenKind::FatArrow,
                TokenKind::Prime,
                TokenKind::Pipe,
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_comments() {
        let tokens = lex("foo // comment\nbar");
        assert_eq!(
            tokens,
            vec![
                TokenKind::Ident("foo".to_string()),
                TokenKind::Comment("comment".to_string()),
                TokenKind::Ident("bar".to_string()),
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_doc_comments() {
        let tokens = lex("/// This is documentation\nfoo");
        assert_eq!(
            tokens,
            vec![
                TokenKind::DocComment("This is documentation".to_string()),
                TokenKind::Ident("foo".to_string()),
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_multiline_comment() {
        let tokens = lex("foo /* comment */ bar");
        assert_eq!(
            tokens,
            vec![
                TokenKind::Ident("foo".to_string()),
                TokenKind::Comment(" comment ".to_string()),
                TokenKind::Ident("bar".to_string()),
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_nested_multiline_comment() {
        let tokens = lex("/* outer /* inner */ outer */ foo");
        assert_eq!(
            tokens,
            vec![
                TokenKind::Comment(" outer /* inner */ outer ".to_string()),
                TokenKind::Ident("foo".to_string()),
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_logical_keywords() {
        assert_eq!(
            lex("and or not implies iff"),
            vec![
                TokenKind::And,
                TokenKind::Or,
                TokenKind::Not,
                TokenKind::Implies,
                TokenKind::Iff,
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_quantifiers() {
        assert_eq!(
            lex("all any fix in"),
            vec![
                TokenKind::All,
                TokenKind::Any,
                TokenKind::Choose,
                TokenKind::In,
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_temporal() {
        assert_eq!(
            lex("always eventually leads_to enabled"),
            vec![
                TokenKind::Always,
                TokenKind::Eventually,
                TokenKind::LeadsTo,
                TokenKind::Enabled,
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_types() {
        assert_eq!(
            lex("Nat Int Bool String Set Seq Dict Option"),
            vec![
                TokenKind::Nat,
                TokenKind::Int,
                TokenKind::Bool,
                TokenKind::String_,
                TokenKind::Set,
                TokenKind::Seq,
                TokenKind::Dict,
                TokenKind::Option_,
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_sample_spec() {
        let source = r#"
module Counter

var count: Nat

init {
    count == 0
}

action Increment() {
    count' == count + 1
}
"#;
        let tokens: Vec<_> = Lexer::new(source)
            .tokenize()
            .into_iter()
            .filter(|t| !t.kind.is_trivia())
            .map(|t| t.kind)
            .collect();

        assert_eq!(
            tokens,
            vec![
                TokenKind::Module,
                TokenKind::Ident("Counter".to_string()),
                TokenKind::Var,
                TokenKind::Ident("count".to_string()),
                TokenKind::Colon,
                TokenKind::Nat,
                TokenKind::Init,
                TokenKind::LBrace,
                TokenKind::Ident("count".to_string()),
                TokenKind::Eq,
                TokenKind::Integer(0),
                TokenKind::RBrace,
                TokenKind::Action,
                TokenKind::Ident("Increment".to_string()),
                TokenKind::LParen,
                TokenKind::RParen,
                TokenKind::LBrace,
                TokenKind::Ident("count".to_string()),
                TokenKind::Prime,
                TokenKind::Eq,
                TokenKind::Ident("count".to_string()),
                TokenKind::Plus,
                TokenKind::Integer(1),
                TokenKind::RBrace,
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn test_span_tracking() {
        let tokens = Lexer::new("foo bar").tokenize();
        assert_eq!(tokens[0].span.line, 1);
        assert_eq!(tokens[0].span.column, 1);
        assert_eq!(tokens[1].span.line, 1);
        assert_eq!(tokens[1].span.column, 5);
    }

    #[test]
    fn test_span_multiline() {
        let tokens = Lexer::new("foo\nbar").tokenize();
        assert_eq!(tokens[0].span.line, 1);
        assert_eq!(tokens[1].span.line, 2);
        assert_eq!(tokens[1].span.column, 1);
    }

    #[test]
    fn test_error_recovery() {
        let tokens = lex("foo @ bar");
        assert!(matches!(tokens[1], TokenKind::Error(_)));
        assert_eq!(tokens[2], TokenKind::Ident("bar".to_string()));
    }
}