1use nom::{
6 branch::alt,
7 bytes::complete::{tag, tag_no_case, take_while, take_while1},
8 character::complete::{char, multispace0, multispace1, one_of},
9 combinator::{map, opt, recognize},
10 multi::many0,
11 sequence::{delimited, pair, preceded, tuple},
12 IResult,
13};
14use serde::{Deserialize, Serialize};
15use std::fmt;
16
/// A single lexical token: its classification, the source text it came from,
/// and the position of its first character in the input.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Token {
    // What kind of token this is (keyword, literal, operator, ...).
    pub kind: TokenKind,
    // Raw source slice the token was produced from (empty for `Eof`;
    // for string literals this is the content without the quotes).
    pub lexeme: String,
    // Where the token starts in the source text.
    pub position: Position,
}
24
/// A location in the source text.
///
/// `line` and `column` are 1-based (tokenize starts both at 1);
/// `offset` is a 0-based byte offset (advanced by `char::len_utf8`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct Position {
    pub line: usize,
    pub column: usize,
    pub offset: usize,
}
32
33#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
35pub enum TokenKind {
36 Match,
38 OptionalMatch,
39 Where,
40 Return,
41 Create,
42 Merge,
43 Delete,
44 DetachDelete,
45 Set,
46 Remove,
47 With,
48 OrderBy,
49 Limit,
50 Skip,
51 Distinct,
52 As,
53 Asc,
54 Desc,
55 Case,
56 When,
57 Then,
58 Else,
59 End,
60 And,
61 Or,
62 Xor,
63 Not,
64 In,
65 Is,
66 Null,
67 True,
68 False,
69 OnCreate,
70 OnMatch,
71
72 Identifier(String),
74 Integer(i64),
75 Float(f64),
76 String(String),
77
78 Plus,
80 Minus,
81 Star,
82 Slash,
83 Percent,
84 Caret,
85 Equal,
86 NotEqual,
87 LessThan,
88 LessThanOrEqual,
89 GreaterThan,
90 GreaterThanOrEqual,
91 Arrow, LeftArrow, Dash, LeftParen,
97 RightParen,
98 LeftBracket,
99 RightBracket,
100 LeftBrace,
101 RightBrace,
102 Comma,
103 Dot,
104 Colon,
105 Semicolon,
106 Pipe,
107
108 DotDot, Eof,
111}
112
113impl fmt::Display for TokenKind {
114 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115 match self {
116 TokenKind::Identifier(s) => write!(f, "identifier '{}'", s),
117 TokenKind::Integer(n) => write!(f, "integer {}", n),
118 TokenKind::Float(n) => write!(f, "float {}", n),
119 TokenKind::String(s) => write!(f, "string \"{}\"", s),
120 _ => write!(f, "{:?}", self),
121 }
122 }
123}
124
125pub fn tokenize(input: &str) -> Result<Vec<Token>, LexerError> {
127 let mut tokens = Vec::new();
128 let mut remaining = input;
129 let mut position = Position {
130 line: 1,
131 column: 1,
132 offset: 0,
133 };
134
135 while !remaining.is_empty() {
136 if let Ok((rest, _)) = multispace1::<_, nom::error::Error<_>>(remaining) {
138 let consumed = remaining.len() - rest.len();
139 update_position(&mut position, &remaining[..consumed]);
140 remaining = rest;
141 continue;
142 }
143
144 match parse_token(remaining) {
146 Ok((rest, (kind, lexeme))) => {
147 tokens.push(Token {
148 kind,
149 lexeme: lexeme.to_string(),
150 position,
151 });
152 update_position(&mut position, lexeme);
153 remaining = rest;
154 }
155 Err(_) => {
156 return Err(LexerError::UnexpectedCharacter {
157 character: remaining.chars().next().unwrap(),
158 position,
159 });
160 }
161 }
162 }
163
164 tokens.push(Token {
165 kind: TokenKind::Eof,
166 lexeme: String::new(),
167 position,
168 });
169
170 Ok(tokens)
171}
172
173fn update_position(pos: &mut Position, text: &str) {
174 for ch in text.chars() {
175 pos.offset += ch.len_utf8();
176 if ch == '\n' {
177 pos.line += 1;
178 pos.column = 1;
179 } else {
180 pos.column += 1;
181 }
182 }
183}
184
/// Parses one token from the front of `input`.
///
/// The order of the alternatives is significant: keywords are tried before
/// identifiers (so `MATCH` lexes as a keyword rather than an identifier),
/// and numbers are tried before operators (so `-89` lexes as `Integer(-89)`
/// rather than a `Dash` followed by an integer).
fn parse_token(input: &str) -> IResult<&str, (TokenKind, &str)> {
    alt((
        parse_keyword,
        parse_number,
        parse_string,
        parse_identifier,
        parse_operator,
        parse_delimiter,
    ))(input)
}
195
196fn parse_keyword(input: &str) -> IResult<&str, (TokenKind, &str)> {
197 let (input, _) = multispace0(input)?;
198
199 alt((
201 alt((
202 map(tag_no_case("OPTIONAL MATCH"), |s: &str| {
203 (TokenKind::OptionalMatch, s)
204 }),
205 map(tag_no_case("DETACH DELETE"), |s: &str| {
206 (TokenKind::DetachDelete, s)
207 }),
208 map(tag_no_case("ORDER BY"), |s: &str| (TokenKind::OrderBy, s)),
209 map(tag_no_case("ON CREATE"), |s: &str| (TokenKind::OnCreate, s)),
210 map(tag_no_case("ON MATCH"), |s: &str| (TokenKind::OnMatch, s)),
211 map(tag_no_case("MATCH"), |s: &str| (TokenKind::Match, s)),
212 map(tag_no_case("WHERE"), |s: &str| (TokenKind::Where, s)),
213 map(tag_no_case("RETURN"), |s: &str| (TokenKind::Return, s)),
214 map(tag_no_case("CREATE"), |s: &str| (TokenKind::Create, s)),
215 map(tag_no_case("MERGE"), |s: &str| (TokenKind::Merge, s)),
216 map(tag_no_case("DELETE"), |s: &str| (TokenKind::Delete, s)),
217 map(tag_no_case("SET"), |s: &str| (TokenKind::Set, s)),
218 map(tag_no_case("REMOVE"), |s: &str| (TokenKind::Remove, s)),
219 map(tag_no_case("WITH"), |s: &str| (TokenKind::With, s)),
220 map(tag_no_case("LIMIT"), |s: &str| (TokenKind::Limit, s)),
221 map(tag_no_case("SKIP"), |s: &str| (TokenKind::Skip, s)),
222 map(tag_no_case("DISTINCT"), |s: &str| (TokenKind::Distinct, s)),
223 )),
224 alt((
225 map(tag_no_case("ASC"), |s: &str| (TokenKind::Asc, s)),
226 map(tag_no_case("DESC"), |s: &str| (TokenKind::Desc, s)),
227 map(tag_no_case("CASE"), |s: &str| (TokenKind::Case, s)),
228 map(tag_no_case("WHEN"), |s: &str| (TokenKind::When, s)),
229 map(tag_no_case("THEN"), |s: &str| (TokenKind::Then, s)),
230 map(tag_no_case("ELSE"), |s: &str| (TokenKind::Else, s)),
231 map(tag_no_case("END"), |s: &str| (TokenKind::End, s)),
232 map(tag_no_case("AND"), |s: &str| (TokenKind::And, s)),
233 map(tag_no_case("OR"), |s: &str| (TokenKind::Or, s)),
234 map(tag_no_case("XOR"), |s: &str| (TokenKind::Xor, s)),
235 map(tag_no_case("NOT"), |s: &str| (TokenKind::Not, s)),
236 map(tag_no_case("IN"), |s: &str| (TokenKind::In, s)),
237 map(tag_no_case("IS"), |s: &str| (TokenKind::Is, s)),
238 map(tag_no_case("NULL"), |s: &str| (TokenKind::Null, s)),
239 map(tag_no_case("TRUE"), |s: &str| (TokenKind::True, s)),
240 map(tag_no_case("FALSE"), |s: &str| (TokenKind::False, s)),
241 map(tag_no_case("AS"), |s: &str| (TokenKind::As, s)),
242 )),
243 ))(input)
244}
245
/// Parses a numeric literal, trying floats first so that `45.67` is not lexed
/// as `Integer(45)` followed by `.67`.
///
/// An optional leading `-` is part of the literal, so `-89` lexes as
/// `Integer(-89)` (and consequently `1-2` lexes as `Integer(1)`,
/// `Integer(-2)` — the `-` binds to the number, not to a `Dash` token).
fn parse_number(input: &str) -> IResult<&str, (TokenKind, &str)> {
    let (input, _) = multispace0(input)?;

    // Float: digits '.' digits, with an optional exponent (e.g. `3.14e-2`).
    // Both digit runs are required, so `1.` and `.5` are not floats and
    // `1..3` lexes as `Integer(1)` followed by `DotDot`.
    if let Ok((rest, num_str)) = recognize::<_, _, nom::error::Error<_>, _>(tuple((
        opt(char('-')),
        take_while1(|c: char| c.is_ascii_digit()),
        char('.'),
        take_while1(|c: char| c.is_ascii_digit()),
        opt(tuple((
            one_of("eE"),
            opt(one_of("+-")),
            take_while1(|c: char| c.is_ascii_digit()),
        ))),
    )))(input)
    {
        if let Ok(n) = num_str.parse::<f64>() {
            return Ok((rest, (TokenKind::Float(n), num_str)));
        }
    }

    // Integer: optional sign followed by one or more digits.
    let (rest, num_str) = recognize(tuple((
        opt(char('-')),
        take_while1(|c: char| c.is_ascii_digit()),
    )))(input)?;

    // Rejects the token on i64 overflow (the digits then fail to lex at all).
    let n = num_str.parse::<i64>().map_err(|_| {
        nom::Err::Error(nom::error::Error::new(input, nom::error::ErrorKind::Digit))
    })?;

    Ok((rest, (TokenKind::Integer(n), num_str)))
}
279
/// Parses a single- or double-quoted string literal.
///
/// Only the escape sequences `\'` (single-quoted), `\"` (double-quoted) and
/// `\\` are recognized; any other backslash use (e.g. `\n`) stops the content
/// parser and makes the whole literal fail to lex. The returned
/// `TokenKind::String` holds the unescaped text, while the lexeme slice is
/// the raw, still-escaped content WITHOUT the surrounding quotes.
fn parse_string(input: &str) -> IResult<&str, (TokenKind, &str)> {
    let (input, _) = multispace0(input)?;

    let (rest, s) = alt((
        delimited(
            char('\''),
            // Content is a run of: escaped quote, escaped backslash, or any
            // chars other than the delimiter and backslash.
            recognize(many0(alt((
                tag("\\'"),
                tag("\\\\"),
                take_while1(|c| c != '\'' && c != '\\'),
            )))),
            char('\''),
        ),
        delimited(
            char('"'),
            recognize(many0(alt((
                tag("\\\""),
                tag("\\\\"),
                take_while1(|c| c != '"' && c != '\\'),
            )))),
            char('"'),
        ),
    ))(input)?;

    // Unescape. Replacing the quote escapes before `\\` is safe for the
    // sequences the parsers above can produce, because every backslash in `s`
    // begins one of the recognized two-character escapes.
    let unescaped = s
        .replace("\\'", "'")
        .replace("\\\"", "\"")
        .replace("\\\\", "\\");

    Ok((rest, (TokenKind::String(unescaped), s)))
}
312
313fn parse_identifier(input: &str) -> IResult<&str, (TokenKind, &str)> {
314 let (input, _) = multispace0(input)?;
315
316 let backtick_result: IResult<&str, &str> =
318 delimited(char('`'), take_while1(|c| c != '`'), char('`'))(input);
319 if let Ok((rest, id)) = backtick_result {
320 return Ok((rest, (TokenKind::Identifier(id.to_string()), id)));
321 }
322
323 let (rest, id) = recognize(pair(
325 alt((
326 take_while1(|c: char| c.is_ascii_alphabetic() || c == '_'),
327 tag("$"),
328 )),
329 take_while(|c: char| c.is_ascii_alphanumeric() || c == '_'),
330 ))(input)?;
331
332 Ok((rest, (TokenKind::Identifier(id.to_string()), id)))
333}
334
335fn parse_operator(input: &str) -> IResult<&str, (TokenKind, &str)> {
336 let (input, _) = multispace0(input)?;
337
338 alt((
339 map(tag("<="), |s| (TokenKind::LessThanOrEqual, s)),
340 map(tag(">="), |s| (TokenKind::GreaterThanOrEqual, s)),
341 map(tag("<>"), |s| (TokenKind::NotEqual, s)),
342 map(tag("!="), |s| (TokenKind::NotEqual, s)),
343 map(tag("->"), |s| (TokenKind::Arrow, s)),
344 map(tag("<-"), |s| (TokenKind::LeftArrow, s)),
345 map(tag(".."), |s| (TokenKind::DotDot, s)),
346 map(char('+'), |_| (TokenKind::Plus, "+")),
347 map(char('-'), |_| (TokenKind::Dash, "-")),
348 map(char('*'), |_| (TokenKind::Star, "*")),
349 map(char('/'), |_| (TokenKind::Slash, "/")),
350 map(char('%'), |_| (TokenKind::Percent, "%")),
351 map(char('^'), |_| (TokenKind::Caret, "^")),
352 map(char('='), |_| (TokenKind::Equal, "=")),
353 map(char('<'), |_| (TokenKind::LessThan, "<")),
354 map(char('>'), |_| (TokenKind::GreaterThan, ">")),
355 ))(input)
356}
357
358fn parse_delimiter(input: &str) -> IResult<&str, (TokenKind, &str)> {
359 let (input, _) = multispace0(input)?;
360
361 alt((
362 map(char('('), |_| (TokenKind::LeftParen, "(")),
363 map(char(')'), |_| (TokenKind::RightParen, ")")),
364 map(char('['), |_| (TokenKind::LeftBracket, "[")),
365 map(char(']'), |_| (TokenKind::RightBracket, "]")),
366 map(char('{'), |_| (TokenKind::LeftBrace, "{")),
367 map(char('}'), |_| (TokenKind::RightBrace, "}")),
368 map(char(','), |_| (TokenKind::Comma, ",")),
369 map(char('.'), |_| (TokenKind::Dot, ".")),
370 map(char(':'), |_| (TokenKind::Colon, ":")),
371 map(char(';'), |_| (TokenKind::Semicolon, ";")),
372 map(char('|'), |_| (TokenKind::Pipe, "|")),
373 ))(input)
374}
375
/// Errors produced by [`tokenize`].
#[derive(Debug, thiserror::Error)]
pub enum LexerError {
    /// No token parser accepted the input at `position`; `character` is the
    /// first unconsumed character.
    #[error("Unexpected character '{character}' at line {}, column {}", position.line, position.column)]
    UnexpectedCharacter { character: char, position: Position },
}
381
#[cfg(test)]
mod tests {
    use super::*;

    // End-to-end: a minimal MATCH/RETURN query produces the expected
    // keyword / delimiter / identifier sequence.
    #[test]
    fn test_tokenize_simple_match() {
        let input = "MATCH (n:Person) RETURN n";
        let tokens = tokenize(input).unwrap();

        assert_eq!(tokens[0].kind, TokenKind::Match);
        assert_eq!(tokens[1].kind, TokenKind::LeftParen);
        assert_eq!(tokens[2].kind, TokenKind::Identifier("n".to_string()));
        assert_eq!(tokens[3].kind, TokenKind::Colon);
        assert_eq!(tokens[4].kind, TokenKind::Identifier("Person".to_string()));
        assert_eq!(tokens[5].kind, TokenKind::RightParen);
        assert_eq!(tokens[6].kind, TokenKind::Return);
        assert_eq!(tokens[7].kind, TokenKind::Identifier("n".to_string()));
    }

    // Numbers: integers, floats, a signed integer, and exponent notation.
    #[test]
    fn test_tokenize_numbers() {
        let tokens = tokenize("123 45.67 -89 3.14e-2").unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Integer(123));
        assert_eq!(tokens[1].kind, TokenKind::Float(45.67));
        // The leading '-' binds to the literal, not to a Dash token.
        assert_eq!(tokens[2].kind, TokenKind::Integer(-89));
        assert!(matches!(tokens[3].kind, TokenKind::Float(_)));
    }

    // Strings: both quote styles; a single quote is fine inside "…".
    #[test]
    fn test_tokenize_strings() {
        let tokens = tokenize(r#"'Alice' "Bob's friend""#).unwrap();
        assert_eq!(tokens[0].kind, TokenKind::String("Alice".to_string()));
        assert_eq!(
            tokens[1].kind,
            TokenKind::String("Bob's friend".to_string())
        );
    }

    // Operators: two-character forms must win over their one-char prefixes.
    #[test]
    fn test_tokenize_operators() {
        let tokens = tokenize("-> <- = <> >= <=").unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Arrow);
        assert_eq!(tokens[1].kind, TokenKind::LeftArrow);
        assert_eq!(tokens[2].kind, TokenKind::Equal);
        assert_eq!(tokens[3].kind, TokenKind::NotEqual);
        assert_eq!(tokens[4].kind, TokenKind::GreaterThanOrEqual);
        assert_eq!(tokens[5].kind, TokenKind::LessThanOrEqual);
    }
}