Skip to main content

spg_sql/
lexer.rs

//! Lexer for the PG-dialect subset that SPG accepts.
//!
//! v0.2 token stream is value-only — no source spans yet. Errors do report
//! the byte offset where the offending construct started. Identifiers are
//! ASCII case-folded to lower-case (matches PG when un-quoted). Quoted
//! identifiers (`"..."`) preserve case; `""` is an embedded quote.
//! String literals (`'...'`) follow PG single-quote convention with `''`
//! as the embedded quote. Dollar-quoted strings (`$$ … $$` and
//! `$tag$ … $tag$`) are lexed as a single string token whose body is taken
//! verbatim. The lexer accepts but does not interpret E-strings — those
//! land in a later milestone.
10
11use alloc::string::{String, ToString};
12use alloc::vec::Vec;
13use core::fmt;
14
15#[derive(Debug, Clone, PartialEq)]
16pub enum Token {
17    // Keywords
18    Select,
19    From,
20    Where,
21    As,
22    Null,
23    True,
24    False,
25    And,
26    Or,
27    Not,
28    Create,
29    Table,
30    Insert,
31    Into,
32    Values,
33    Index,
34    On,
35    Begin,
36    Commit,
37    Rollback,
38    Order,
39    By,
40    Limit,
41
42    // Identifiers
43    Ident(String),       // ASCII case-folded
44    QuotedIdent(String), // original case, "" → "
45    /// v7.14.0 — MySQL session / user variable reference
46    /// (`@VAR` / `@@VAR`). The wrapped string is the verbatim
47    /// source form (including the `@` / `@@` prefix). Used by
48    /// mysqldump preamble (`SET @OLD_FOREIGN_KEY_CHECKS =
49    /// @@FOREIGN_KEY_CHECKS, …`); SPG accepts the token and
50    /// the SET parser treats the assignment as a no-op apart
51    /// from any second LHS that targets a real session
52    /// parameter (e.g. `FOREIGN_KEY_CHECKS=0`).
53    SessionVar(String),
54
55    // Literals
56    Integer(i64),
57    Float(f64),
58    String(String),
59
60    // Operators
61    Plus,
62    Minus,
63    Star,
64    Slash,
65    Eq,
66    NotEq,
67    Lt,
68    LtEq,
69    Gt,
70    GtEq,
71
72    // Punctuation
73    LParen,
74    RParen,
75    LBracket,
76    RBracket,
77    Comma,
78    Semicolon,
79    Dot,
80    /// pgvector L2 distance operator `<->`. Lexed as one token so the
81    /// parser can give it its own precedence rung.
82    /// v4.14 `->` — JSON object/array element access, returns json.
83    JsonGet,
84    /// v4.14 `->>` — same access, returns text.
85    JsonGetText,
86    /// v6.4.5 `#>` — JSON path walk, returns json. Path is the
87    /// right-hand TEXT with PG `{a,b,0}` syntax.
88    JsonGetPath,
89    /// v6.4.5 `#>>` — same walk, returns text.
90    JsonGetPathText,
91    /// v6.4.5 `@>` — JSON containment. `j @> sub` returns true if
92    /// every key/value in `sub` is present in `j` with structural
93    /// containment for objects + arrays.
94    JsonContains,
95    /// v7.12.2 `@@` — tsvector / tsquery match. Either ordering
96    /// (`vec @@ q` or `q @@ vec`) parses; engine eval normalises
97    /// before matching.
98    TsMatch,
99    L2Distance,
100    /// pgvector inner-product operator `<#>` (returns negative dot product
101    /// so smaller still means more similar — same semantics as pgvector).
102    InnerProduct,
103    /// pgvector cosine distance operator `<=>`.
104    CosineDistance,
105    /// PG-style cast `expr::type` — single token because we want it to bind
106    /// at postfix precedence.
107    DoubleColon,
108    /// v7.12.4 — PL/pgSQL assignment operator `:=`.
109    /// Outside PL/pgSQL bodies this token has no SQL-side meaning.
110    ColonEq,
111    /// v7.12.4 — bare `:` separator. Used inside `tsvector` external-form
112    /// literals (`'cat:1 dog:2'::tsvector`) and as the fallback path for
113    /// the PL/pgSQL assignment lexer.
114    Colon,
115    /// Standard SQL string concatenation `||`.
116    Concat,
117    /// `IS` keyword — postfix `IS NULL` / `IS NOT NULL` predicates.
118    Is,
119    Between,
120    In,
121    Like,
122    Group,
123    Distinct,
124    Union,
125    All,
126    Join,
127    Inner,
128    Left,
129    Cross,
130    Outer,
131    Default,
132    Savepoint,
133    Release,
134    To,
135    Having,
136    Show,
137    Extract,
138    Offset,
139    Asc,
140    Desc,
141    /// `INTERVAL` — followed by a string literal carrying the span text
142    /// (e.g. `INTERVAL '1 day 2 hours'`).
143    Interval,
144    /// v6.1.1 — `$N` parameter placeholder for the extended query
145    /// protocol. The number N is 1-based per PostgreSQL convention.
146    /// `0` and `$0` are not valid; the lexer rejects them.
147    Placeholder(u16),
148
149    /// v6.1.2 — `DROP` keyword. Used by `DROP PUBLICATION <name>`.
150    /// Reserved for future `DROP TABLE` / `DROP INDEX` / `DROP USER`
151    /// surface that currently goes through SHOW-shaped admin SQL.
152    Drop,
153    /// v6.1.2 — `FOR` keyword (publication scope).
154    For,
155    /// v6.1.2 — `TABLES` plural keyword (`FOR ALL TABLES`,
156    /// `FOR ALL TABLES EXCEPT …`). The existing `TABLE` keyword
157    /// stays a separate token so `CREATE TABLE`'s single-table
158    /// form keeps lexing as today.
159    Tables,
160    /// v6.1.3 (reserved at v6.1.2 to keep the AST shape stable) —
161    /// `EXCEPT` keyword for `FOR ALL TABLES EXCEPT t1, t2`.
162    Except,
163    /// v6.1.2 — `PUBLICATION` keyword.
164    Publication,
165    /// v6.1.4 (reserved at v6.1.2) — `SUBSCRIPTION` keyword.
166    Subscription,
167    /// v6.1.4 — `CONNECTION` keyword (for
168    /// `CREATE SUBSCRIPTION … CONNECTION '<conn_str>' …`).
169    Connection,
170
171    Eof,
172}
173
/// Category of failure reported by the lexer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LexErrorKind {
    /// A character that cannot start (or continue) any token at this
    /// position; carries the offending character.
    UnknownChar(char),
    /// A `'...'` literal or dollar-quoted string reached end of input
    /// before its closing delimiter.
    UnterminatedString,
    /// A `"..."` or backtick-quoted identifier reached end of input
    /// before its closing quote.
    UnterminatedQuotedIdent,
    /// A `/* ...` comment with no closing `*/`.
    UnterminatedBlockComment,
    /// A malformed numeric literal or an out-of-range `$N` placeholder;
    /// carries the offending source text.
    BadNumber(String),
}
182
/// A lexing failure: what went wrong and where.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LexError {
    /// The category of failure.
    pub kind: LexErrorKind,
    /// Byte offset into the input at which the offending construct starts.
    pub pos: usize,
}
188
189impl fmt::Display for LexError {
190    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        match &self.kind {
192            LexErrorKind::UnknownChar(c) => write!(f, "unknown char {c:?} at byte {}", self.pos),
193            LexErrorKind::UnterminatedString => {
194                write!(f, "unterminated string literal at byte {}", self.pos)
195            }
196            LexErrorKind::UnterminatedQuotedIdent => {
197                write!(f, "unterminated quoted identifier at byte {}", self.pos)
198            }
199            LexErrorKind::UnterminatedBlockComment => {
200                write!(f, "unterminated /* */ comment at byte {}", self.pos)
201            }
202            LexErrorKind::BadNumber(s) => {
203                write!(f, "invalid number literal {s:?} at byte {}", self.pos)
204            }
205        }
206    }
207}
208
/// Tokenize `input` into a `Vec<Token>` ending in `Token::Eof`.
///
/// The lexer walks the input byte-by-byte and dispatches on the current
/// byte. Arm order matters: guarded arms (for example `--`, `/*`, `*/`,
/// `.5`) must come before the plain single-byte arm for the same byte.
///
/// Skipped without producing tokens: ASCII whitespace, `-- …` line
/// comments, `/* … */` block comments (a nested `/*` is not tracked — the
/// first `*/` closes the comment), the `/*!NNNNN ` opener of a MySQL
/// versioned comment, and any bare `*/`.
///
/// # Errors
///
/// Returns a [`LexError`] whose `pos` is the byte offset where the
/// offending construct starts:
///
/// * `UnknownChar` — a byte that starts no token (including a `#` not
///   followed by `>`, an `@` with no variable name after it, and a
///   `$ident` prefix that is not closed by a second `$`).
/// * `UnterminatedString` — an unclosed `'...'` literal or a
///   dollar-quoted string whose closing tag is missing.
/// * `UnterminatedQuotedIdent` — an unclosed `"..."` or backtick
///   identifier.
/// * `UnterminatedBlockComment` — a `/*` with no matching `*/`.
/// * `BadNumber` — `$0`, a `$N` with N above `u16::MAX`, or whatever
///   `lex_number` rejects.
#[allow(clippy::too_many_lines)] // big match — splitting would obscure the dispatch table
pub fn tokenize(input: &str) -> Result<Vec<Token>, LexError> {
    let bytes = input.as_bytes();
    let mut i = 0usize;
    let mut out = Vec::new();

    while i < bytes.len() {
        let b = bytes[i];
        match b {
            b' ' | b'\t' | b'\n' | b'\r' => {
                i += 1;
            }
            // `--` line comment: skip up to (not including) the newline,
            // which the whitespace arm then consumes.
            b'-' if peek_eq(bytes, i + 1, b'-') => {
                i += 2;
                while i < bytes.len() && bytes[i] != b'\n' {
                    i += 1;
                }
            }
            b'/' if peek_eq(bytes, i + 1, b'*') => {
                let start = i;
                // v7.14.0 — MySQL versioned conditional comment
                // `/*!NNNNN <body> */`. The body is real SQL that
                // MySQL/MariaDB executes when the runtime version
                // matches the 5-digit code; PG strips the whole
                // thing as a block comment. SPG sides with MySQL
                // semantics for dump compatibility: skip the
                // `/*!NNNNN ` prefix and continue lexing the body
                // as ordinary tokens. The closing `*/` is later
                // matched + skipped by the symmetric arm below.
                if peek_eq(bytes, i + 2, b'!') {
                    let mut j = i + 3;
                    // skip the optional 5-digit version code +
                    // following single whitespace
                    // (any run of digits is accepted; the length is not
                    // checked)
                    while j < bytes.len() && bytes[j].is_ascii_digit() {
                        j += 1;
                    }
                    if j < bytes.len() && (bytes[j] == b' ' || bytes[j] == b'\t') {
                        j += 1;
                    }
                    i = j;
                    continue;
                }
                // Ordinary block comment: scan for the first `*/`.
                i += 2;
                let mut closed = false;
                while i + 1 < bytes.len() {
                    if bytes[i] == b'*' && bytes[i + 1] == b'/' {
                        i += 2;
                        closed = true;
                        break;
                    }
                    i += 1;
                }
                if !closed {
                    return Err(LexError {
                        kind: LexErrorKind::UnterminatedBlockComment,
                        pos: start,
                    });
                }
            }
            // v7.14.0 — bare `*/` (closing of the v7.14 MySQL
            // versioned-comment opener that didn't consume the
            // closer). We treat it as an inline comment terminator
            // and skip 2 bytes.
            // NOTE(review): this arm keeps no state about whether a
            // `/*!` opener was seen, so it fires on any `*` directly
            // followed by `/` — e.g. in `2*/*c*/3` the leading `*/` is
            // dropped rather than lexed as `*` followed by a comment.
            // Confirm that is acceptable for the inputs SPG targets.
            b'*' if peek_eq(bytes, i + 1, b'/') => {
                i += 2;
            }
            b'\'' => {
                let (tok, consumed) = lex_quoted(input, i, b'\'', false)?;
                out.push(tok);
                i += consumed;
            }
            b'"' => {
                let (tok, consumed) = lex_quoted(input, i, b'"', true)?;
                out.push(tok);
                i += consumed;
            }
            // MySQL-flavoured backtick-quoted identifier. Same semantics
            // as the standard `"..."` form, including embedded "``" as
            // a literal backtick.
            b'`' => {
                let (tok, consumed) = lex_quoted(input, i, b'`', true)?;
                out.push(tok);
                i += consumed;
            }
            // Bare word: `[A-Za-z_][A-Za-z0-9_]*`, then keyword lookup.
            b if b.is_ascii_alphabetic() || b == b'_' => {
                let start = i;
                i += 1;
                while i < bytes.len() {
                    let c = bytes[i];
                    if c.is_ascii_alphanumeric() || c == b'_' {
                        i += 1;
                    } else {
                        break;
                    }
                }
                let raw = &input[start..i];
                // v3.0.5: try the keyword table case-insensitively
                // without allocating; only the ident fall-through
                // pays for a lowercase String.
                out.push(keyword_or_ident_raw(raw));
            }
            b if b.is_ascii_digit() => {
                let (tok, consumed) =
                    lex_number(&input[i..]).map_err(|kind| LexError { kind, pos: i })?;
                out.push(tok);
                i += consumed;
            }
            // Leading-dot number such as `.5`; a `.` not followed by a
            // digit falls through to the `Token::Dot` arm further down.
            b'.' if peek_pred(bytes, i + 1, u8::is_ascii_digit) => {
                let (tok, consumed) =
                    lex_number(&input[i..]).map_err(|kind| LexError { kind, pos: i })?;
                out.push(tok);
                i += consumed;
            }
            b'+' => single(&mut out, Token::Plus, &mut i),
            // Reached only when the `--` comment arm above did not match.
            b'-' => {
                // v4.14: `->>` and `->` for JSON path access. `->>`
                // must be tried before `->` (longest match).
                if peek_eq(bytes, i + 1, b'>') && peek_eq(bytes, i + 2, b'>') {
                    out.push(Token::JsonGetText);
                    i += 3;
                } else if peek_eq(bytes, i + 1, b'>') {
                    out.push(Token::JsonGet);
                    i += 2;
                } else {
                    single(&mut out, Token::Minus, &mut i);
                }
            }
            // v6.4.5: `#>>` and `#>` JSON path walk.
            b'#' => {
                if peek_eq(bytes, i + 1, b'>') && peek_eq(bytes, i + 2, b'>') {
                    out.push(Token::JsonGetPathText);
                    i += 3;
                } else if peek_eq(bytes, i + 1, b'>') {
                    out.push(Token::JsonGetPath);
                    i += 2;
                } else {
                    return Err(LexError {
                        kind: LexErrorKind::UnknownChar('#'),
                        pos: i,
                    });
                }
            }
            // v6.4.5: `@>` JSON containment.
            // v7.12.2: `@@` tsvector / tsquery match.
            // v7.14.0: `@@NAME` MySQL session variable ref +
            //          `@NAME` user variable ref. mysqldump preamble
            //          uses both heavily (`SET @OLD_FOREIGN_KEY_CHECKS
            //          = @@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0`).
            //          We lex both as a single SessionVar token so
            //          the parser can accept and ignore them.
            b'@' => {
                if peek_eq(bytes, i + 1, b'>') {
                    out.push(Token::JsonContains);
                    i += 2;
                } else if peek_eq(bytes, i + 1, b'@')
                    && !is_session_var_ident_start(bytes.get(i + 2).copied())
                {
                    // `@@` not followed by an ident-start byte is
                    // the tsquery `@@` operator.
                    out.push(Token::TsMatch);
                    i += 2;
                } else {
                    // `@VAR` / `@@VAR` — MySQL user / session
                    // variable reference. Consume the ident-shaped
                    // tail and emit as Token::SessionVar so the
                    // SET parser can accept-and-ignore.
                    let prefix_end = if peek_eq(bytes, i + 1, b'@') {
                        i + 2
                    } else {
                        i + 1
                    };
                    let mut end = prefix_end;
                    while end < bytes.len() && is_session_var_ident_continue(bytes[end]) {
                        end += 1;
                    }
                    // A prefix with no name after it is not a variable.
                    if end == prefix_end {
                        return Err(LexError {
                            kind: LexErrorKind::UnknownChar('@'),
                            pos: i,
                        });
                    }
                    out.push(Token::SessionVar(input[i..end].to_string()));
                    i = end;
                }
            }
            b'*' => single(&mut out, Token::Star, &mut i),
            b'/' => single(&mut out, Token::Slash, &mut i),
            b'(' => single(&mut out, Token::LParen, &mut i),
            b')' => single(&mut out, Token::RParen, &mut i),
            b'[' => single(&mut out, Token::LBracket, &mut i),
            b']' => single(&mut out, Token::RBracket, &mut i),
            b',' => single(&mut out, Token::Comma, &mut i),
            b';' => single(&mut out, Token::Semicolon, &mut i),
            b'.' => single(&mut out, Token::Dot, &mut i),
            b'=' => single(&mut out, Token::Eq, &mut i),
            // `<` family, longest match first: the three-byte pgvector
            // operators, then `<=` / `<>`, then plain `<`.
            b'<' => {
                if peek_eq(bytes, i + 1, b'=') && peek_eq(bytes, i + 2, b'>') {
                    out.push(Token::CosineDistance);
                    i += 3;
                } else if peek_eq(bytes, i + 1, b'#') && peek_eq(bytes, i + 2, b'>') {
                    out.push(Token::InnerProduct);
                    i += 3;
                } else if peek_eq(bytes, i + 1, b'-') && peek_eq(bytes, i + 2, b'>') {
                    out.push(Token::L2Distance);
                    i += 3;
                } else if peek_eq(bytes, i + 1, b'=') {
                    out.push(Token::LtEq);
                    i += 2;
                } else if peek_eq(bytes, i + 1, b'>') {
                    out.push(Token::NotEq);
                    i += 2;
                } else {
                    out.push(Token::Lt);
                    i += 1;
                }
            }
            b':' if peek_eq(bytes, i + 1, b':') => {
                out.push(Token::DoubleColon);
                i += 2;
            }
            b':' if peek_eq(bytes, i + 1, b'=') => {
                // v7.12.4 — PL/pgSQL assignment operator `:=`.
                out.push(Token::ColonEq);
                i += 2;
            }
            b':' => {
                // v7.12.4 — bare `:`. Used inside `tsvector` external-form
                // literals which the cast parser consumes in-token, and as a
                // separator the PL/pgSQL assignment lexer can recover from.
                out.push(Token::Colon);
                i += 1;
            }
            // Only `||` is recognised; a lone `|` reaches the catch-all
            // arm and is reported as an unknown character.
            b'|' if peek_eq(bytes, i + 1, b'|') => {
                out.push(Token::Concat);
                i += 2;
            }
            b'>' => {
                if peek_eq(bytes, i + 1, b'=') {
                    out.push(Token::GtEq);
                    i += 2;
                } else {
                    out.push(Token::Gt);
                    i += 1;
                }
            }
            // Only `!=` is recognised; a lone `!` reaches the catch-all arm.
            b'!' if peek_eq(bytes, i + 1, b'=') => {
                out.push(Token::NotEq);
                i += 2;
            }
            // v7.9.27 — PG dollar-quoted string `$$ … $$` (or
            // `$tag$ … $tag$`). Used in `DO $$ … $$ LANGUAGE
            // plpgsql;` blocks that pg_dump emits for idempotent
            // migrations. SPG has no PL/pgSQL, so the lexer
            // consumes the entire string as a single Token::String
            // and the parser treats the surrounding `DO …;` as a
            // no-op. mailrs follow-up H1.
            b'$' if i + 1 < bytes.len() && bytes[i + 1] == b'$' => {
                // Empty tag form: `$$ … $$`.
                let end = find_dollar_tag_end(bytes, i + 2, b"$$");
                let body = match end {
                    Some(e) => &input[i + 2..e],
                    None => {
                        return Err(LexError {
                            kind: LexErrorKind::UnterminatedString,
                            pos: i,
                        });
                    }
                };
                out.push(Token::String(body.to_string()));
                // `end` is `Some` here: the `None` case returned above.
                i = end.unwrap() + 2;
            }
            b'$' if i + 1 < bytes.len()
                && (bytes[i + 1].is_ascii_alphabetic() || bytes[i + 1] == b'_') =>
            {
                // Tagged form: `$foo$ … $foo$`. Scan the tag
                // ident, find the closing copy.
                let mut j = i + 1;
                while j < bytes.len() && (bytes[j].is_ascii_alphanumeric() || bytes[j] == b'_') {
                    j += 1;
                }
                if j >= bytes.len() || bytes[j] != b'$' {
                    // Not a dollar-quoted string — fall through
                    // to the generic-unknown-char path.
                    let ch = input[i..].chars().next().unwrap_or('?');
                    return Err(LexError {
                        kind: LexErrorKind::UnknownChar(ch),
                        pos: i,
                    });
                }
                // `close` is the full delimiter including both `$` signs.
                let close: alloc::vec::Vec<u8> = bytes[i..=j].to_vec();
                let end = find_dollar_tag_end(bytes, j + 1, &close);
                let body = match end {
                    Some(e) => &input[j + 1..e],
                    None => {
                        return Err(LexError {
                            kind: LexErrorKind::UnterminatedString,
                            pos: i,
                        });
                    }
                };
                out.push(Token::String(body.to_string()));
                // `end` is `Some` here: the `None` case returned above.
                i = end.unwrap() + close.len();
            }
            // v6.1.1: `$N` parameter placeholder for the extended
            // query protocol. PG numbers them 1..=N; we reject $0
            // and a bare `$` not followed by a digit.
            b'$' if i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() => {
                let mut j = i + 1;
                let mut n: u32 = 0;
                // Saturating arithmetic keeps an over-long digit run from
                // wrapping; it ends up above `u16::MAX` and is rejected.
                while j < bytes.len() && bytes[j].is_ascii_digit() {
                    n = n
                        .saturating_mul(10)
                        .saturating_add(u32::from(bytes[j] - b'0'));
                    j += 1;
                }
                if n == 0 || n > u32::from(u16::MAX) {
                    return Err(LexError {
                        kind: LexErrorKind::BadNumber(input[i..j].to_string()),
                        pos: i,
                    });
                }
                // The range check above guarantees `n` fits in a `u16`.
                #[allow(clippy::cast_possible_truncation)]
                out.push(Token::Placeholder(n as u16));
                i = j;
            }
            // Anything else — including non-ASCII text outside quotes —
            // is reported with the full character found at `i`.
            _ => {
                let ch = input[i..].chars().next().unwrap_or('?');
                return Err(LexError {
                    kind: LexErrorKind::UnknownChar(ch),
                    pos: i,
                });
            }
        }
    }
    out.push(Token::Eof);
    Ok(out)
}
547
/// Returns `true` when `bytes` has a byte at index `i` and that byte is
/// `target`. An out-of-range index yields `false` rather than panicking,
/// so callers can look ahead past the end of the input.
fn peek_eq(bytes: &[u8], i: usize, target: u8) -> bool {
    bytes.get(i).is_some_and(|&found| found == target)
}
551
/// v7.14.0 — does `b` hold a byte that may begin the name part of a
/// `@@VAR` reference? Accepts an ASCII letter or `_`; `None` (end of
/// input) and every other byte are rejected. The lexer uses this to
/// tell `@@name` apart from the `@@` match operator.
fn is_session_var_ident_start(b: Option<u8>) -> bool {
    b.is_some_and(|first| first == b'_' || first.is_ascii_alphabetic())
}
559
/// Is `b` allowed inside the name part of a `@VAR` / `@@VAR` reference?
///
/// Accepted bytes: ASCII letters and digits, `_`, `.` (MySQL scope
/// qualifiers such as `@@global.sql_mode`) and `$` (accepted by some
/// MySQL versions).
fn is_session_var_ident_continue(b: u8) -> bool {
    matches!(
        b,
        b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'.' | b'$'
    )
}
567
/// v7.9.27 — locate the closing delimiter of a dollar-quoted string.
///
/// Searches `bytes` from index `from` onward for the first occurrence of
/// `tag` (e.g. `b"$$"` or `b"$foo$"`) and returns the index at which that
/// occurrence starts. Returns `None` when `tag` is empty, when `from`
/// lies beyond the end of `bytes`, or when no occurrence exists.
fn find_dollar_tag_end(bytes: &[u8], from: usize, tag: &[u8]) -> Option<usize> {
    // `windows(0)` would panic, so an empty tag is rejected up front.
    if tag.is_empty() {
        return None;
    }
    // `get` yields `None` for a start offset past the end.
    let haystack = bytes.get(from..)?;
    haystack
        .windows(tag.len())
        .position(|candidate| candidate == tag)
        .map(|offset| from + offset)
}
583
/// Applies `pred` to the byte at index `i` of `bytes`. Returns `false`
/// without calling `pred` when `i` is out of range.
fn peek_pred<F: Fn(&u8) -> bool>(bytes: &[u8], i: usize, pred: F) -> bool {
    match bytes.get(i) {
        Some(byte) => pred(byte),
        None => false,
    }
}
587
588fn single(out: &mut Vec<Token>, tok: Token, i: &mut usize) {
589    out.push(tok);
590    *i += 1;
591}
592
593/// Length-first ASCII-CI keyword lookup. Avoids allocating a
594/// lowercase `String` when the input matches a keyword; only the ident
595/// fall-through path pays for the lowercase copy.
596///
597/// Grouped by length so the outer `match` becomes a small jump table.
598/// Within a length bucket every keyword has either a unique first
599/// byte (cheap dispatch) or a small set of disambiguating
600/// trailing-byte comparisons. All comparisons are ASCII-CI (XOR
601/// 0x20 on each byte before the compare).
602fn keyword_or_ident_raw(raw: &str) -> Token {
603    let b = raw.as_bytes();
604    let tok = match b.len() {
605        2 => kw_len2(b),
606        3 => kw_len3(b),
607        4 => kw_len4(b),
608        5 => kw_len5(b),
609        6 => kw_len6(b),
610        7 => kw_len7(b),
611        8 => kw_len8(b),
612        9 => kw_len9(b),
613        10 => kw_len10(b),
614        11 => kw_len11(b),
615        12 => kw_len12(b),
616        _ => None,
617    };
618    match tok {
619        Some(t) => t,
620        // Ident fall-through: this is the only path that allocates.
621        None => Token::Ident(raw.to_ascii_lowercase()),
622    }
623}
624
/// ASCII case-insensitive equality of `input` against `lower`.
///
/// `lower` is expected to already be lower-case: only `input` is
/// folded (byte-wise, via `to_ascii_lowercase`) before the comparison,
/// so an upper-case byte in `lower` never matches. Slices of different
/// length are never equal.
#[inline]
fn eq_ci(input: &[u8], lower: &[u8]) -> bool {
    input.len() == lower.len()
        && input
            .iter()
            .zip(lower)
            .all(|(have, want)| have.to_ascii_lowercase() == *want)
}
642
643#[inline]
644fn kw_len2(b: &[u8]) -> Option<Token> {
645    // 7 keywords: as, by, in, is, on, or, to
646    if eq_ci(b, b"as") {
647        return Some(Token::As);
648    }
649    if eq_ci(b, b"by") {
650        return Some(Token::By);
651    }
652    if eq_ci(b, b"in") {
653        return Some(Token::In);
654    }
655    if eq_ci(b, b"is") {
656        return Some(Token::Is);
657    }
658    if eq_ci(b, b"on") {
659        return Some(Token::On);
660    }
661    if eq_ci(b, b"or") {
662        return Some(Token::Or);
663    }
664    if eq_ci(b, b"to") {
665        return Some(Token::To);
666    }
667    None
668}
669
670#[inline]
671fn kw_len3(b: &[u8]) -> Option<Token> {
672    // 5 keywords: all, and, asc, not, for
673    if eq_ci(b, b"for") {
674        return Some(Token::For);
675    }
676    if eq_ci(b, b"all") {
677        return Some(Token::All);
678    }
679    if eq_ci(b, b"and") {
680        return Some(Token::And);
681    }
682    if eq_ci(b, b"asc") {
683        return Some(Token::Asc);
684    }
685    if eq_ci(b, b"not") {
686        return Some(Token::Not);
687    }
688    None
689}
690
691#[inline]
692fn kw_len4(b: &[u8]) -> Option<Token> {
693    // 10 keywords: from, null, true, into, like, join, left, show, desc, drop
694    if eq_ci(b, b"from") {
695        return Some(Token::From);
696    }
697    if eq_ci(b, b"drop") {
698        return Some(Token::Drop);
699    }
700    if eq_ci(b, b"null") {
701        return Some(Token::Null);
702    }
703    if eq_ci(b, b"true") {
704        return Some(Token::True);
705    }
706    if eq_ci(b, b"into") {
707        return Some(Token::Into);
708    }
709    if eq_ci(b, b"like") {
710        return Some(Token::Like);
711    }
712    if eq_ci(b, b"join") {
713        return Some(Token::Join);
714    }
715    if eq_ci(b, b"left") {
716        return Some(Token::Left);
717    }
718    if eq_ci(b, b"show") {
719        return Some(Token::Show);
720    }
721    if eq_ci(b, b"desc") {
722        return Some(Token::Desc);
723    }
724    None
725}
726
727#[inline]
728fn kw_len5(b: &[u8]) -> Option<Token> {
729    // 12 keywords: false, where, table, index, begin, order, limit,
730    // group, union, inner, cross, outer
731    if eq_ci(b, b"false") {
732        return Some(Token::False);
733    }
734    if eq_ci(b, b"where") {
735        return Some(Token::Where);
736    }
737    if eq_ci(b, b"table") {
738        return Some(Token::Table);
739    }
740    if eq_ci(b, b"index") {
741        return Some(Token::Index);
742    }
743    if eq_ci(b, b"begin") {
744        return Some(Token::Begin);
745    }
746    if eq_ci(b, b"order") {
747        return Some(Token::Order);
748    }
749    if eq_ci(b, b"limit") {
750        return Some(Token::Limit);
751    }
752    if eq_ci(b, b"group") {
753        return Some(Token::Group);
754    }
755    if eq_ci(b, b"union") {
756        return Some(Token::Union);
757    }
758    if eq_ci(b, b"inner") {
759        return Some(Token::Inner);
760    }
761    if eq_ci(b, b"cross") {
762        return Some(Token::Cross);
763    }
764    if eq_ci(b, b"outer") {
765        return Some(Token::Outer);
766    }
767    None
768}
769
770#[inline]
771fn kw_len6(b: &[u8]) -> Option<Token> {
772    // 9 keywords: select, create, insert, values, commit, having, offset, tables, except
773    if eq_ci(b, b"select") {
774        return Some(Token::Select);
775    }
776    if eq_ci(b, b"tables") {
777        return Some(Token::Tables);
778    }
779    if eq_ci(b, b"except") {
780        return Some(Token::Except);
781    }
782    if eq_ci(b, b"create") {
783        return Some(Token::Create);
784    }
785    if eq_ci(b, b"insert") {
786        return Some(Token::Insert);
787    }
788    if eq_ci(b, b"values") {
789        return Some(Token::Values);
790    }
791    if eq_ci(b, b"commit") {
792        return Some(Token::Commit);
793    }
794    if eq_ci(b, b"having") {
795        return Some(Token::Having);
796    }
797    if eq_ci(b, b"offset") {
798        return Some(Token::Offset);
799    }
800    None
801}
802
803#[inline]
804fn kw_len7(b: &[u8]) -> Option<Token> {
805    // 4 keywords: between, default, release, extract
806    if eq_ci(b, b"between") {
807        return Some(Token::Between);
808    }
809    if eq_ci(b, b"default") {
810        return Some(Token::Default);
811    }
812    if eq_ci(b, b"release") {
813        return Some(Token::Release);
814    }
815    if eq_ci(b, b"extract") {
816        return Some(Token::Extract);
817    }
818    None
819}
820
821#[inline]
822fn kw_len8(b: &[u8]) -> Option<Token> {
823    // 3 keywords: rollback, distinct, interval
824    if eq_ci(b, b"rollback") {
825        return Some(Token::Rollback);
826    }
827    if eq_ci(b, b"distinct") {
828        return Some(Token::Distinct);
829    }
830    if eq_ci(b, b"interval") {
831        return Some(Token::Interval);
832    }
833    None
834}
835
836#[inline]
837fn kw_len9(b: &[u8]) -> Option<Token> {
838    // 1 keyword: savepoint
839    if eq_ci(b, b"savepoint") {
840        return Some(Token::Savepoint);
841    }
842    None
843}
844
845#[inline]
846fn kw_len10(b: &[u8]) -> Option<Token> {
847    // 1 keyword: connection
848    if eq_ci(b, b"connection") {
849        return Some(Token::Connection);
850    }
851    None
852}
853
854#[inline]
855fn kw_len11(b: &[u8]) -> Option<Token> {
856    // 1 keyword: publication
857    if eq_ci(b, b"publication") {
858        return Some(Token::Publication);
859    }
860    None
861}
862
863#[inline]
864fn kw_len12(b: &[u8]) -> Option<Token> {
865    // 1 keyword: subscription
866    if eq_ci(b, b"subscription") {
867        return Some(Token::Subscription);
868    }
869    None
870}
871
872/// Lex a `'...'` string literal or `"..."` quoted identifier. The opening
873/// quote sits at `input[start]`; `quote` is its byte value. `is_ident` selects
874/// the resulting token shape.
875///
876/// PG-style doubling escapes the quote: `''` inside `'...'` is a literal `'`,
877/// same for `""` inside `"..."`.
878fn lex_quoted(
879    input: &str,
880    start: usize,
881    quote: u8,
882    is_ident: bool,
883) -> Result<(Token, usize), LexError> {
884    let bytes = input.as_bytes();
885    let mut i = start + 1;
886    let mut s = String::new();
887    loop {
888        if i >= bytes.len() {
889            return Err(LexError {
890                kind: if is_ident {
891                    LexErrorKind::UnterminatedQuotedIdent
892                } else {
893                    LexErrorKind::UnterminatedString
894                },
895                pos: start,
896            });
897        }
898        if bytes[i] == quote {
899            if peek_eq(bytes, i + 1, quote) {
900                s.push(quote as char);
901                i += 2;
902            } else {
903                i += 1;
904                break;
905            }
906        } else {
907            let ch = input[i..].chars().next().expect("non-empty UTF-8 boundary");
908            s.push(ch);
909            i += ch.len_utf8();
910        }
911    }
912    let tok = if is_ident {
913        Token::QuotedIdent(s)
914    } else {
915        Token::String(s)
916    };
917    Ok((tok, i - start))
918}
919
920fn lex_number(s: &str) -> Result<(Token, usize), LexErrorKind> {
921    let bytes = s.as_bytes();
922    let mut i = 0usize;
923    let mut is_float = false;
924
925    while i < bytes.len() && bytes[i].is_ascii_digit() {
926        i += 1;
927    }
928    if i < bytes.len() && bytes[i] == b'.' {
929        is_float = true;
930        i += 1;
931        while i < bytes.len() && bytes[i].is_ascii_digit() {
932            i += 1;
933        }
934    }
935    if i < bytes.len() && (bytes[i] == b'e' || bytes[i] == b'E') {
936        is_float = true;
937        i += 1;
938        if i < bytes.len() && (bytes[i] == b'+' || bytes[i] == b'-') {
939            i += 1;
940        }
941        let exp_start = i;
942        while i < bytes.len() && bytes[i].is_ascii_digit() {
943            i += 1;
944        }
945        if exp_start == i {
946            return Err(LexErrorKind::BadNumber(s[..i].to_string()));
947        }
948    }
949
950    let lit = &s[..i];
951    if is_float {
952        lit.parse::<f64>()
953            .map(|v| (Token::Float(v), i))
954            .map_err(|_| LexErrorKind::BadNumber(lit.to_string()))
955    } else {
956        lit.parse::<i64>()
957            .map(|v| (Token::Integer(v), i))
958            .map_err(|_| LexErrorKind::BadNumber(lit.to_string()))
959    }
960}
961
/// Unit tests for the lexer's token stream and error reporting.
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Tokenize `s`, panicking when the lexer reports an error.
    fn lex(s: &str) -> Vec<Token> {
        tokenize(s).expect("lex ok")
    }

    /// Build the token for an unquoted identifier.
    fn ident(name: &str) -> Token {
        Token::Ident(name.into())
    }

    /// Check that `src` lexes to `a`, `op`, `b`, end-of-input.
    fn assert_infix(src: &str, op: Token) {
        let expected = vec![ident("a"), op, ident("b"), Token::Eof];
        assert_eq!(lex(src), expected);
    }

    #[test]
    fn empty_yields_only_eof() {
        let toks = lex("");
        assert_eq!(toks, vec![Token::Eof]);
    }

    #[test]
    fn whitespace_only_yields_only_eof() {
        let toks = lex("   \t\n  ");
        assert_eq!(toks, vec![Token::Eof]);
    }

    #[test]
    fn keywords_are_case_insensitive() {
        let expected = vec![Token::Select, Token::Select, Token::Select, Token::Eof];
        assert_eq!(lex("SELECT select Select"), expected);
    }

    #[test]
    fn identifiers_lowercase_ascii() {
        let expected = vec![
            ident("hello"),
            ident("world"),
            ident("_x"),
            ident("x1"),
            Token::Eof,
        ];
        assert_eq!(lex("hello WORLD _x x1"), expected);
    }

    #[test]
    fn quoted_identifier_keeps_case_and_handles_embedded_quote() {
        let expected = vec![
            Token::QuotedIdent("User Name".into()),
            Token::QuotedIdent("a\"b".into()),
            Token::Eof,
        ];
        assert_eq!(lex(r#""User Name" "a""b""#), expected);
    }

    #[test]
    fn integer_and_float_literals() {
        let expected = vec![
            Token::Integer(0),
            Token::Integer(42),
            Token::Float(1.5),
            Token::Float(0.5),
            Token::Float(1e10),
            Token::Float(2.5e-3),
            Token::Eof,
        ];
        assert_eq!(lex("0 42 1.5 .5 1e10 2.5e-3"), expected);
    }

    #[test]
    fn negative_number_is_minus_then_integer() {
        // Same as PG: the sign is its own token and the parser folds it
        // into the literal later.
        let expected = vec![Token::Minus, Token::Integer(42), Token::Eof];
        assert_eq!(lex("-42"), expected);
    }

    #[test]
    fn string_literal_doubled_quote_escape() {
        let expected = vec![
            Token::String("hello".into()),
            Token::String("it's".into()),
            Token::Eof,
        ];
        assert_eq!(lex("'hello' 'it''s'"), expected);
    }

    #[test]
    fn all_comparison_and_arithmetic_operators() {
        let expected = vec![
            Token::Eq,
            Token::NotEq,
            Token::NotEq,
            Token::Lt,
            Token::LtEq,
            Token::Gt,
            Token::GtEq,
            Token::Plus,
            Token::Minus,
            Token::Star,
            Token::Slash,
            Token::Eof,
        ];
        assert_eq!(lex("= <> != < <= > >= + - * /"), expected);
    }

    #[test]
    fn punctuation() {
        let expected = vec![
            Token::LParen,
            Token::RParen,
            Token::Comma,
            Token::Semicolon,
            Token::Dot,
            Token::Eof,
        ];
        assert_eq!(lex("( ) , ; ."), expected);
    }

    #[test]
    fn line_comment_skipped() {
        let expected = vec![Token::Select, Token::From, Token::Eof];
        assert_eq!(lex("SELECT -- trailing junk\nFROM"), expected);
    }

    #[test]
    fn block_comment_skipped() {
        let expected = vec![Token::Select, Token::Integer(1), Token::Eof];
        assert_eq!(lex("SELECT /* skipped */ 1"), expected);
    }

    #[test]
    fn unterminated_string_errors() {
        let failure = tokenize("'oops").unwrap_err();
        assert!(matches!(failure.kind, LexErrorKind::UnterminatedString));
        // The error points at the opening quote.
        assert_eq!(failure.pos, 0);
    }

    #[test]
    fn unterminated_block_comment_errors() {
        let failure = tokenize("/* never closed").unwrap_err();
        assert!(matches!(
            failure.kind,
            LexErrorKind::UnterminatedBlockComment
        ));
    }

    #[test]
    fn unknown_char_errors() {
        let failure = tokenize("@").unwrap_err();
        assert!(matches!(failure.kind, LexErrorKind::UnknownChar('@')));
    }

    #[test]
    fn dot_in_qualified_column() {
        let expected = vec![ident("t"), Token::Dot, ident("col"), Token::Eof];
        assert_eq!(lex("t.col"), expected);
    }

    // --- v0.11 brackets + distance op + vector keyword --------------------

    #[test]
    fn brackets_are_distinct_tokens() {
        let expected = vec![Token::LBracket, Token::RBracket, Token::Eof];
        assert_eq!(lex("[ ]"), expected);
    }

    #[test]
    fn l2_distance_is_three_char_token() {
        assert_infix("a <-> b", Token::L2Distance);

        // A bare `<-` must NOT be taken for L2Distance; it stays two tokens.
        let expected = vec![ident("a"), Token::Lt, Token::Minus, ident("b"), Token::Eof];
        assert_eq!(lex("a <- b"), expected);
    }

    #[test]
    fn order_by_limit_are_keywords() {
        let expected = vec![Token::Order, Token::By, Token::Limit, Token::Eof];
        assert_eq!(lex("ORDER BY LIMIT"), expected);
    }

    // --- v1.2: pgvector distance ops + PG cast --------------------------

    #[test]
    fn inner_product_operator_3char() {
        assert_infix("a <#> b", Token::InnerProduct);
    }

    #[test]
    fn cosine_distance_operator_3char() {
        assert_infix("a <=> b", Token::CosineDistance);

        // With `<=>` in the operator set, the shorter `<=` must still lex
        // on its own (the longest match wins, but only when it is present).
        assert_infix("a <= b", Token::LtEq);
    }

    #[test]
    fn double_colon_cast_token() {
        let expected = vec![ident("x"), Token::DoubleColon, ident("int"), Token::Eof];
        assert_eq!(lex("x::INT"), expected);
    }

    #[test]
    fn lone_single_colon_lexes_as_colon_token() {
        // v7.12.4 — a single `:` is a token of its own (needed by the
        // PL/pgSQL surface and by the tsvector external-form literal).
        // Before v7.12.4 it was rejected as an unknown char, which was
        // incidental rather than intended.
        let toks = tokenize(":x").expect("colon now lexes");
        let first = &toks[0];
        assert_eq!(*first, Token::Colon);
    }

    #[test]
    fn colon_eq_lexes_as_assignment() {
        // v7.12.4 — `:=` is the PL/pgSQL assignment operator. Only the
        // second token (right after the identifier `x`) is checked here.
        let toks = tokenize("x := 1").expect("colon-eq lexes");
        let second = &toks[1];
        assert!(matches!(*second, Token::ColonEq));
    }
}