rustpython_parser/
parser.rs

1//! Contains the interface to the Python parser.
2//!
3//! Functions in this module can be used to parse Python code into an [Abstract Syntax Tree]
4//! (AST) that is then transformed into bytecode.
5//!
6//! There are three ways to parse Python code corresponding to the different [`Mode`]s
7//! defined in the [`mode`] module.
8//!
9//! All functions return a [`Result`](std::result::Result) containing the parsed AST or
10//! a [`ParseError`] if parsing failed.
11//!
12//! [Abstract Syntax Tree]: https://en.wikipedia.org/wiki/Abstract_syntax_tree
//! [`Mode`]: crate::Mode
//! [`mode`]: crate::mode
14
15use crate::{
16    ast::{self, OptionalRange, Ranged},
17    lexer::{self, LexResult, LexicalError, LexicalErrorType},
18    python,
19    text_size::TextSize,
20    token::Tok,
21    Mode,
22};
23use itertools::Itertools;
24use std::iter;
25
26use crate::{lexer::Lexer, soft_keywords::SoftKeywordTransformer, text_size::TextRange};
27pub(super) use lalrpop_util::ParseError as LalrpopError;
28
29/// Parse Python code string to implementor's type.
30///
31/// # Example
32///
33/// For example, parsing a simple function definition and a call to that function:
34///
35/// ```
36/// use rustpython_parser::{self as parser, ast, Parse};
37/// let source = r#"
38/// def foo():
39///    return 42
40///
41/// print(foo())
42/// "#;
43/// let program = ast::Suite::parse(source, "<embedded>");
44/// assert!(program.is_ok());
45/// ```
46///
47/// Parsing a single expression denoting the addition of two numbers, but this time specifying a different,
48/// somewhat silly, location:
49///
50/// ```
51/// use rustpython_parser::{self as parser, ast, Parse, text_size::TextSize};
52///
53/// let expr = ast::Expr::parse_starts_at("1 + 2", "<embedded>", TextSize::from(400));
54/// assert!(expr.is_ok());
55pub trait Parse
56where
57    Self: Sized,
58{
59    fn parse(source: &str, source_path: &str) -> Result<Self, ParseError> {
60        Self::parse_starts_at(source, source_path, TextSize::default())
61    }
62    fn parse_without_path(source: &str) -> Result<Self, ParseError> {
63        Self::parse(source, "<unknown>")
64    }
65    fn parse_starts_at(
66        source: &str,
67        source_path: &str,
68        offset: TextSize,
69    ) -> Result<Self, ParseError> {
70        let lxr = Self::lex_starts_at(source, offset);
71        #[cfg(feature = "full-lexer")]
72        let lxr =
73            lxr.filter_ok(|(tok, _)| !matches!(tok, Tok::Comment { .. } | Tok::NonLogicalNewline));
74        Self::parse_tokens(lxr, source_path)
75    }
76    fn lex_starts_at(
77        source: &str,
78        offset: TextSize,
79    ) -> SoftKeywordTransformer<Lexer<std::str::Chars>>;
80    fn parse_tokens(
81        lxr: impl IntoIterator<Item = LexResult>,
82        source_path: &str,
83    ) -> Result<Self, ParseError>;
84}
85
86impl Parse for ast::ModModule {
87    fn lex_starts_at(
88        source: &str,
89        offset: TextSize,
90    ) -> SoftKeywordTransformer<Lexer<std::str::Chars>> {
91        lexer::lex_starts_at(source, Mode::Module, offset)
92    }
93    fn parse_tokens(
94        lxr: impl IntoIterator<Item = LexResult>,
95        source_path: &str,
96    ) -> Result<Self, ParseError> {
97        match parse_filtered_tokens(lxr, Mode::Module, source_path)? {
98            ast::Mod::Module(m) => Ok(m),
99            _ => unreachable!("Mode::Module doesn't return other variant"),
100        }
101    }
102}
103
104impl Parse for ast::ModExpression {
105    fn lex_starts_at(
106        source: &str,
107        offset: TextSize,
108    ) -> SoftKeywordTransformer<Lexer<std::str::Chars>> {
109        lexer::lex_starts_at(source, Mode::Expression, offset)
110    }
111    fn parse_tokens(
112        lxr: impl IntoIterator<Item = LexResult>,
113        source_path: &str,
114    ) -> Result<Self, ParseError> {
115        match parse_filtered_tokens(lxr, Mode::Expression, source_path)? {
116            ast::Mod::Expression(m) => Ok(m),
117            _ => unreachable!("Mode::Module doesn't return other variant"),
118        }
119    }
120}
121
122impl Parse for ast::ModInteractive {
123    fn lex_starts_at(
124        source: &str,
125        offset: TextSize,
126    ) -> SoftKeywordTransformer<Lexer<std::str::Chars>> {
127        lexer::lex_starts_at(source, Mode::Interactive, offset)
128    }
129    fn parse_tokens(
130        lxr: impl IntoIterator<Item = LexResult>,
131        source_path: &str,
132    ) -> Result<Self, ParseError> {
133        match parse_filtered_tokens(lxr, Mode::Interactive, source_path)? {
134            ast::Mod::Interactive(m) => Ok(m),
135            _ => unreachable!("Mode::Module doesn't return other variant"),
136        }
137    }
138}
139
140impl Parse for ast::Suite {
141    fn lex_starts_at(
142        source: &str,
143        offset: TextSize,
144    ) -> SoftKeywordTransformer<Lexer<std::str::Chars>> {
145        ast::ModModule::lex_starts_at(source, offset)
146    }
147    fn parse_tokens(
148        lxr: impl IntoIterator<Item = LexResult>,
149        source_path: &str,
150    ) -> Result<Self, ParseError> {
151        Ok(ast::ModModule::parse_tokens(lxr, source_path)?.body)
152    }
153}
154
155impl Parse for ast::Stmt {
156    fn lex_starts_at(
157        source: &str,
158        offset: TextSize,
159    ) -> SoftKeywordTransformer<Lexer<std::str::Chars>> {
160        ast::ModModule::lex_starts_at(source, offset)
161    }
162    fn parse_tokens(
163        lxr: impl IntoIterator<Item = LexResult>,
164        source_path: &str,
165    ) -> Result<Self, ParseError> {
166        let mut statements = ast::ModModule::parse_tokens(lxr, source_path)?.body;
167        let statement = match statements.len() {
168            0 => {
169                return Err(ParseError {
170                    error: ParseErrorType::Eof,
171                    offset: TextSize::default(),
172                    source_path: source_path.to_owned(),
173                })
174            }
175            1 => statements.pop().unwrap(),
176            _ => {
177                return Err(ParseError {
178                    error: ParseErrorType::InvalidToken,
179                    offset: statements[1].range().start(),
180                    source_path: source_path.to_owned(),
181                })
182            }
183        };
184        Ok(statement)
185    }
186}
187
188impl Parse for ast::Expr {
189    fn lex_starts_at(
190        source: &str,
191        offset: TextSize,
192    ) -> SoftKeywordTransformer<Lexer<std::str::Chars>> {
193        ast::ModExpression::lex_starts_at(source, offset)
194    }
195    fn parse_tokens(
196        lxr: impl IntoIterator<Item = LexResult>,
197        source_path: &str,
198    ) -> Result<Self, ParseError> {
199        Ok(*ast::ModExpression::parse_tokens(lxr, source_path)?.body)
200    }
201}
202
203impl Parse for ast::Identifier {
204    fn lex_starts_at(
205        source: &str,
206        offset: TextSize,
207    ) -> SoftKeywordTransformer<Lexer<std::str::Chars>> {
208        ast::Expr::lex_starts_at(source, offset)
209    }
210    fn parse_tokens(
211        lxr: impl IntoIterator<Item = LexResult>,
212        source_path: &str,
213    ) -> Result<Self, ParseError> {
214        let expr = ast::Expr::parse_tokens(lxr, source_path)?;
215        match expr {
216            ast::Expr::Name(name) => Ok(name.id),
217            expr => Err(ParseError {
218                error: ParseErrorType::InvalidToken,
219                offset: expr.range().start(),
220                source_path: source_path.to_owned(),
221            }),
222        }
223    }
224}
225
226impl Parse for ast::Constant {
227    fn lex_starts_at(
228        source: &str,
229        offset: TextSize,
230    ) -> SoftKeywordTransformer<Lexer<std::str::Chars>> {
231        ast::Expr::lex_starts_at(source, offset)
232    }
233    fn parse_tokens(
234        lxr: impl IntoIterator<Item = LexResult>,
235        source_path: &str,
236    ) -> Result<Self, ParseError> {
237        let expr = ast::Expr::parse_tokens(lxr, source_path)?;
238        match expr {
239            ast::Expr::Constant(c) => Ok(c.value),
240            expr => Err(ParseError {
241                error: ParseErrorType::InvalidToken,
242                offset: expr.range().start(),
243                source_path: source_path.to_owned(),
244            }),
245        }
246    }
247}
248
249/// Parse a full Python program usually consisting of multiple lines.
250///  
251/// This is a convenience function that can be used to parse a full Python program without having to
252/// specify the [`Mode`] or the location. It is probably what you want to use most of the time.
253///
254/// # Example
255///
256/// For example, parsing a simple function definition and a call to that function:
257///
258/// ```
259/// use rustpython_parser as parser;
260/// let source = r#"
261/// def foo():
262///    return 42
263///
264/// print(foo())
265/// "#;
266/// let program = parser::parse_program(source, "<embedded>");
267/// assert!(program.is_ok());
268/// ```
269#[deprecated = "Use ast::Suite::parse from rustpython_parser::Parse trait."]
270pub fn parse_program(source: &str, source_path: &str) -> Result<ast::Suite, ParseError> {
271    parse(source, Mode::Module, source_path).map(|top| match top {
272        ast::Mod::Module(ast::ModModule { body, .. }) => body,
273        _ => unreachable!(),
274    })
275}
276
277/// Parses a single Python expression.
278///
279/// This convenience function can be used to parse a single expression without having to
280/// specify the Mode or the location.
281///
282/// # Example
283///
284/// For example, parsing a single expression denoting the addition of two numbers:
285///
286///  ```
287/// use rustpython_parser as parser;
288/// let expr = parser::parse_expression("1 + 2", "<embedded>");
289///
290/// assert!(expr.is_ok());
291///
292/// ```
293#[deprecated = "Use ast::Expr::parse from rustpython_parser::Parse trait."]
294pub fn parse_expression(source: &str, path: &str) -> Result<ast::Expr, ParseError> {
295    ast::Expr::parse(source, path)
296}
297
298/// Parses a Python expression from a given location.
299///
300/// This function allows to specify the location of the expression in the source code, other than
301/// that, it behaves exactly like [`parse_expression`].
302///
303/// # Example
304///
305/// Parsing a single expression denoting the addition of two numbers, but this time specifying a different,
306/// somewhat silly, location:
307///
308/// ```
309/// use rustpython_parser::{text_size::TextSize, parse_expression_starts_at};
310///
311/// let expr = parse_expression_starts_at("1 + 2", "<embedded>", TextSize::from(400));
312/// assert!(expr.is_ok());
313/// ```
314#[deprecated = "Use ast::Expr::parse_starts_at from rustpython_parser::Parse trait."]
315pub fn parse_expression_starts_at(
316    source: &str,
317    path: &str,
318    offset: TextSize,
319) -> Result<ast::Expr, ParseError> {
320    ast::Expr::parse_starts_at(source, path, offset)
321}
322
323/// Parse the given Python source code using the specified [`Mode`].
324///
325/// This function is the most general function to parse Python code. Based on the [`Mode`] supplied,
326/// it can be used to parse a single expression, a full Python program or an interactive expression.
327///
328/// # Example
329///
330/// If we want to parse a simple expression, we can use the [`Mode::Expression`] mode during
331/// parsing:
332///
333/// ```
334/// use rustpython_parser::{Mode, parse};
335///
336/// let expr = parse("1 + 2", Mode::Expression, "<embedded>");
337/// assert!(expr.is_ok());
338/// ```
339///
340/// Alternatively, we can parse a full Python program consisting of multiple lines:
341///
342/// ```
343/// use rustpython_parser::{Mode, parse};
344///
345/// let source = r#"
346/// class Greeter:
347///
348///   def greet(self):
349///    print("Hello, world!")
350/// "#;
351/// let program = parse(source, Mode::Module, "<embedded>");
352/// assert!(program.is_ok());
353/// ```
354pub fn parse(source: &str, mode: Mode, source_path: &str) -> Result<ast::Mod, ParseError> {
355    parse_starts_at(source, mode, source_path, TextSize::default())
356}
357
358/// Parse the given Python source code using the specified [`Mode`] and [`Location`].
359///
360/// This function allows to specify the location of the the source code, other than
361/// that, it behaves exactly like [`parse`].
362///
363/// # Example
364///
365/// ```
366/// use rustpython_parser::{text_size::TextSize, Mode, parse_starts_at};
367///
368/// let source = r#"
369/// def fib(i):
370///    a, b = 0, 1
371///    for _ in range(i):
372///       a, b = b, a + b
373///    return a
374///
375/// print(fib(42))
376/// "#;
377/// let program = parse_starts_at(source, Mode::Module, "<embedded>", TextSize::from(0));
378/// assert!(program.is_ok());
379/// ```
380pub fn parse_starts_at(
381    source: &str,
382    mode: Mode,
383    source_path: &str,
384    offset: TextSize,
385) -> Result<ast::Mod, ParseError> {
386    let lxr = lexer::lex_starts_at(source, mode, offset);
387    parse_tokens(lxr, mode, source_path)
388}
389
390/// Parse an iterator of [`LexResult`]s using the specified [`Mode`].
391///
392/// This could allow you to perform some preprocessing on the tokens before parsing them.
393///
394/// # Example
395///
396/// As an example, instead of parsing a string, we can parse a list of tokens after we generate
397/// them using the [`lexer::lex`] function:
398///
399/// ```
400/// use rustpython_parser::{lexer::lex, Mode, parse_tokens};
401///
402/// let expr = parse_tokens(lex("1 + 2", Mode::Expression), Mode::Expression, "<embedded>");
403/// assert!(expr.is_ok());
404/// ```
405pub fn parse_tokens(
406    lxr: impl IntoIterator<Item = LexResult>,
407    mode: Mode,
408    source_path: &str,
409) -> Result<ast::Mod, ParseError> {
410    let lxr = lxr.into_iter();
411    #[cfg(feature = "full-lexer")]
412    let lxr =
413        lxr.filter_ok(|(tok, _)| !matches!(tok, Tok::Comment { .. } | Tok::NonLogicalNewline));
414    parse_filtered_tokens(lxr, mode, source_path)
415}
416
417fn parse_filtered_tokens(
418    lxr: impl IntoIterator<Item = LexResult>,
419    mode: Mode,
420    source_path: &str,
421) -> Result<ast::Mod, ParseError> {
422    let marker_token = (Tok::start_marker(mode), Default::default());
423    let lexer = iter::once(Ok(marker_token)).chain(lxr);
424    python::TopParser::new()
425        .parse(
426            lexer
427                .into_iter()
428                .map_ok(|(t, range)| (range.start(), t, range.end())),
429        )
430        .map_err(|e| parse_error_from_lalrpop(e, source_path))
431}
432
433/// Represents represent errors that occur during parsing and are
434/// returned by the `parse_*` functions.
435pub type ParseError = rustpython_parser_core::BaseError<ParseErrorType>;
436
/// Represents the different types of errors that can occur during parsing.
///
/// Values are produced from `lalrpop_util` errors by `parse_error_from_lalrpop`, and directly
/// by some [`Parse`] implementations in this module.
#[derive(Debug, PartialEq)]
pub enum ParseErrorType {
    /// Parser encountered an unexpected end of input.
    ///
    /// Also used for lalrpop's `InvalidToken` error, and when a single statement was requested
    /// but the input contained none.
    Eof,
    /// Parser encountered an extra token
    ExtraToken(Tok),
    /// Parser encountered an invalid token.
    ///
    /// Also used when the parsed input has the wrong shape for the requested type (e.g. more
    /// than one statement, or an expression that is not a name / constant).
    InvalidToken,
    /// Parser encountered an unexpected token.
    ///
    /// The `Option<String>` holds the expected token's name only when the grammar expected
    /// exactly one token; otherwise it is `None`.
    UnrecognizedToken(Tok, Option<String>),
    // Maps to `User` type from `lalrpop-util`
    /// Parser encountered an error during lexing.
    ///
    /// An unexpected EOF where only `Indent` was expected is also reported here, as
    /// `LexicalErrorType::IndentationError`.
    Lexical(LexicalErrorType),
}

// The message text comes from the `Display` impl for `ParseErrorType` in this module.
impl std::error::Error for ParseErrorType {}
454
455// Convert `lalrpop_util::ParseError` to our internal type
456fn parse_error_from_lalrpop(
457    err: LalrpopError<TextSize, Tok, LexicalError>,
458    source_path: &str,
459) -> ParseError {
460    let source_path = source_path.to_owned();
461
462    match err {
463        // TODO: Are there cases where this isn't an EOF?
464        LalrpopError::InvalidToken { location } => ParseError {
465            error: ParseErrorType::Eof,
466            offset: location,
467            source_path,
468        },
469        LalrpopError::ExtraToken { token } => ParseError {
470            error: ParseErrorType::ExtraToken(token.1),
471            offset: token.0,
472            source_path,
473        },
474        LalrpopError::User { error } => ParseError {
475            error: ParseErrorType::Lexical(error.error),
476            offset: error.location,
477            source_path,
478        },
479        LalrpopError::UnrecognizedToken { token, expected } => {
480            // Hacky, but it's how CPython does it. See PyParser_AddToken,
481            // in particular "Only one possible expected token" comment.
482            let expected = (expected.len() == 1).then(|| expected[0].clone());
483            ParseError {
484                error: ParseErrorType::UnrecognizedToken(token.1, expected),
485                offset: token.0,
486                source_path,
487            }
488        }
489        LalrpopError::UnrecognizedEof { location, expected } => {
490            // This could be an initial indentation error that we should ignore
491            let indent_error = expected == ["Indent"];
492            if indent_error {
493                ParseError {
494                    error: ParseErrorType::Lexical(LexicalErrorType::IndentationError),
495                    offset: location,
496                    source_path,
497                }
498            } else {
499                ParseError {
500                    error: ParseErrorType::Eof,
501                    offset: location,
502                    source_path,
503                }
504            }
505        }
506    }
507}
508
509impl std::fmt::Display for ParseErrorType {
510    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
511        match *self {
512            ParseErrorType::Eof => write!(f, "Got unexpected EOF"),
513            ParseErrorType::ExtraToken(ref tok) => write!(f, "Got extraneous token: {tok:?}"),
514            ParseErrorType::InvalidToken => write!(f, "Got invalid token"),
515            ParseErrorType::UnrecognizedToken(ref tok, ref expected) => {
516                if *tok == Tok::Indent {
517                    write!(f, "unexpected indent")
518                } else if expected.as_deref() == Some("Indent") {
519                    write!(f, "expected an indented block")
520                } else {
521                    write!(f, "invalid syntax. Got unexpected token {tok}")
522                }
523            }
524            ParseErrorType::Lexical(ref error) => write!(f, "{error}"),
525        }
526    }
527}
528
529impl ParseErrorType {
530    /// Returns true if the error is an indentation error.
531    pub fn is_indentation_error(&self) -> bool {
532        match self {
533            ParseErrorType::Lexical(LexicalErrorType::IndentationError) => true,
534            ParseErrorType::UnrecognizedToken(token, expected) => {
535                *token == Tok::Indent || expected.clone() == Some("Indent".to_owned())
536            }
537            _ => false,
538        }
539    }
540
541    /// Returns true if the error is a tab error.
542    pub fn is_tab_error(&self) -> bool {
543        matches!(
544            self,
545            ParseErrorType::Lexical(LexicalErrorType::TabError)
546                | ParseErrorType::Lexical(LexicalErrorType::TabsAfterSpaces)
547        )
548    }
549}
550
551#[inline(always)]
552pub(super) fn optional_range(start: TextSize, end: TextSize) -> OptionalRange<TextRange> {
553    OptionalRange::<TextRange>::new(start, end)
554}
555
// Splices the contents of `gen/parse.rs` into this module at compile time.
// NOTE(review): presumably generated code, given the `gen/` directory — confirm how it is
// produced before editing it by hand.
include!("gen/parse.rs");
557
558#[cfg(test)]
559mod tests {
560    use super::*;
561    use crate::{ast, Parse};
562
563    #[test]
564    fn test_parse_empty() {
565        let parse_ast = ast::Suite::parse("", "<test>").unwrap();
566        insta::assert_debug_snapshot!(parse_ast);
567    }
568
569    #[test]
570    fn test_parse_string() {
571        let source = "'Hello world'";
572        let parse_ast = ast::Suite::parse(source, "<test>").unwrap();
573        insta::assert_debug_snapshot!(parse_ast);
574    }
575
576    #[test]
577    fn test_parse_f_string() {
578        let source = "f'Hello world'";
579        let parse_ast = ast::Suite::parse(source, "<test>").unwrap();
580        insta::assert_debug_snapshot!(parse_ast);
581    }
582
583    #[test]
584    fn test_parse_print_hello() {
585        let source = "print('Hello world')";
586        let parse_ast = ast::Suite::parse(source, "<test>").unwrap();
587        insta::assert_debug_snapshot!(parse_ast);
588    }
589
590    #[test]
591    fn test_parse_print_2() {
592        let source = "print('Hello world', 2)";
593        let parse_ast = ast::Suite::parse(source, "<test>").unwrap();
594        insta::assert_debug_snapshot!(parse_ast);
595    }
596
597    #[test]
598    fn test_parse_kwargs() {
599        let source = "my_func('positional', keyword=2)";
600        let parse_ast = ast::Suite::parse(source, "<test>").unwrap();
601        insta::assert_debug_snapshot!(parse_ast);
602    }
603
604    #[test]
605    fn test_parse_if_elif_else() {
606        let source = "if 1: 10\nelif 2: 20\nelse: 30";
607        let parse_ast = ast::Suite::parse(source, "<test>").unwrap();
608        insta::assert_debug_snapshot!(parse_ast);
609    }
610
611    #[test]
612    #[cfg(feature = "all-nodes-with-ranges")]
613    fn test_parse_lambda() {
614        let source = "lambda x, y: x * y"; // lambda(x, y): x * y";
615        let parse_ast = ast::Suite::parse(source, "<test>").unwrap();
616        insta::assert_debug_snapshot!(parse_ast);
617    }
618
619    #[test]
620    fn test_parse_tuples() {
621        let source = "a, b = 4, 5";
622
623        insta::assert_debug_snapshot!(ast::Suite::parse(source, "<test>").unwrap());
624    }
625
626    #[test]
627    #[cfg(feature = "all-nodes-with-ranges")]
628    fn test_parse_class() {
629        let source = "\
630class Foo(A, B):
631 def __init__(self):
632  pass
633 def method_with_default(self, arg='default'):
634  pass
635";
636        insta::assert_debug_snapshot!(ast::Suite::parse(source, "<test>").unwrap());
637    }
638
639    #[test]
640    #[cfg(feature = "all-nodes-with-ranges")]
641    fn test_parse_class_generic_types() {
642        let source = "\
643# TypeVar
644class Foo[T](): ...
645
646# TypeVar with bound
647class Foo[T: str](): ...
648
649# TypeVar with tuple bound
650class Foo[T: (str, bytes)](): ...
651
652# Multiple TypeVar
653class Foo[T, U](): ...
654
655# Trailing comma
656class Foo[T, U,](): ...
657
658# TypeVarTuple
659class Foo[*Ts](): ...
660
661# ParamSpec
662class Foo[**P](): ...
663
664# Mixed types
665class Foo[X, Y: str, *U, **P]():
666  pass
667";
668        insta::assert_debug_snapshot!(ast::Suite::parse(source, "<test>").unwrap());
669    }
670    #[test]
671    #[cfg(feature = "all-nodes-with-ranges")]
672    fn test_parse_function_definition() {
673        let source = "\
674def func(a):
675    ...
676
677def func[T](a: T) -> T:
678    ...
679
680def func[T: str](a: T) -> T:
681    ...
682
683def func[T: (str, bytes)](a: T) -> T:
684    ...
685
686def func[*Ts](*a: *Ts):
687    ...
688
689def func[**P](*args: P.args, **kwargs: P.kwargs):
690    ...
691
692def func[T, U: str, *Ts, **P]():
693    pass
694  ";
695        insta::assert_debug_snapshot!(ast::Suite::parse(source, "<test>").unwrap());
696    }
697
698    #[test]
699    #[cfg(feature = "all-nodes-with-ranges")]
700    fn test_parse_dict_comprehension() {
701        let source = "{x1: x2 for y in z}";
702        let parse_ast = ast::Expr::parse(source, "<test>").unwrap();
703        insta::assert_debug_snapshot!(parse_ast);
704    }
705
706    #[test]
707    #[cfg(feature = "all-nodes-with-ranges")]
708    fn test_parse_list_comprehension() {
709        let source = "[x for y in z]";
710        let parse_ast = ast::Expr::parse(source, "<test>").unwrap();
711        insta::assert_debug_snapshot!(parse_ast);
712    }
713
714    #[test]
715    #[cfg(feature = "all-nodes-with-ranges")]
716    fn test_parse_double_list_comprehension() {
717        let source = "[x for y, y2 in z for a in b if a < 5 if a > 10]";
718        let parse_ast = ast::Expr::parse(source, "<test>").unwrap();
719        insta::assert_debug_snapshot!(parse_ast);
720    }
721
722    #[test]
723    #[cfg(feature = "all-nodes-with-ranges")]
724    fn test_parse_generator_comprehension() {
725        let source = "(x for y in z)";
726        let parse_ast = ast::Expr::parse(source, "<test>").unwrap();
727        insta::assert_debug_snapshot!(parse_ast);
728    }
729
730    #[test]
731    #[cfg(feature = "all-nodes-with-ranges")]
732    fn test_parse_named_expression_generator_comprehension() {
733        let source = "(x := y + 1 for y in z)";
734        let parse_ast = ast::Expr::parse(source, "<test>").unwrap();
735        insta::assert_debug_snapshot!(parse_ast);
736    }
737
738    #[test]
739    #[cfg(feature = "all-nodes-with-ranges")]
740    fn test_parse_if_else_generator_comprehension() {
741        let source = "(x if y else y for y in z)";
742        let parse_ast = ast::Expr::parse(source, "<test>").unwrap();
743        insta::assert_debug_snapshot!(parse_ast);
744    }
745
746    #[test]
747    fn test_parse_bool_op_or() {
748        let source = "x or y";
749        let parse_ast = ast::Expr::parse(source, "<test>").unwrap();
750        insta::assert_debug_snapshot!(parse_ast);
751    }
752
753    #[test]
754    fn test_parse_bool_op_and() {
755        let source = "x and y";
756        let parse_ast = ast::Expr::parse(source, "<test>").unwrap();
757        insta::assert_debug_snapshot!(parse_ast);
758    }
759
760    #[test]
761    fn test_slice() {
762        let source = "x[1:2:3]";
763        let parse_ast = ast::Expr::parse(source, "<test>").unwrap();
764        insta::assert_debug_snapshot!(parse_ast);
765    }
766
767    #[test]
768    #[cfg(feature = "all-nodes-with-ranges")]
769    fn test_with_statement() {
770        let source = "\
771with 0: pass
772with 0 as x: pass
773with 0, 1: pass
774with 0 as x, 1 as y: pass
775with 0 if 1 else 2: pass
776with 0 if 1 else 2 as x: pass
777with (): pass
778with () as x: pass
779with (0): pass
780with (0) as x: pass
781with (0,): pass
782with (0,) as x: pass
783with (0, 1): pass
784with (0, 1) as x: pass
785with (*a,): pass
786with (*a,) as x: pass
787with (0, *a): pass
788with (0, *a) as x: pass
789with (a := 0): pass
790with (a := 0) as x: pass
791with (a := 0, b := 1): pass
792with (a := 0, b := 1) as x: pass
793with (0 as a): pass
794with (0 as a,): pass
795with (0 as a, 1 as b): pass
796with (0 as a, 1 as b,): pass
797";
798        insta::assert_debug_snapshot!(ast::Suite::parse(source, "<test>").unwrap());
799    }
800
801    #[test]
802    fn test_with_statement_invalid() {
803        for source in [
804            "with 0,: pass",
805            "with 0 as x,: pass",
806            "with 0 as *x: pass",
807            "with *a: pass",
808            "with *a as x: pass",
809            "with (*a): pass",
810            "with (*a) as x: pass",
811            "with *a, 0 as x: pass",
812            "with (*a, 0 as x): pass",
813            "with 0 as x, *a: pass",
814            "with (0 as x, *a): pass",
815            "with (0 as x) as y: pass",
816            "with (0 as x), 1: pass",
817            "with ((0 as x)): pass",
818            "with a := 0 as x: pass",
819            "with (a := 0 as x): pass",
820        ] {
821            assert!(ast::Suite::parse(source, "<test>").is_err());
822        }
823    }
824
825    #[test]
826    fn test_star_index() {
827        let source = "\
828array_slice = array[0, *indexes, -1]
829array[0, *indexes, -1] = array_slice
830array[*indexes_to_select, *indexes_to_select]
831array[3:5, *indexes_to_select]
832";
833        let parse_ast = ast::Suite::parse(source, "<test>").unwrap();
834        insta::assert_debug_snapshot!(parse_ast);
835    }
836
837    #[test]
838    #[cfg(feature = "all-nodes-with-ranges")]
839    fn test_generator_expression_argument() {
840        let source = r#"' '.join(
841    sql
842    for sql in (
843        "LIMIT %d" % limit if limit else None,
844        ("OFFSET %d" % offset) if offset else None,
845    )
846)"#;
847        let parse_ast = ast::Expr::parse(source, "<test>").unwrap();
848        insta::assert_debug_snapshot!(parse_ast);
849    }
850
851    #[test]
852    fn test_try() {
853        let parse_ast = ast::Suite::parse(
854            r#"try:
855    raise ValueError(1)
856except TypeError as e:
857    print(f'caught {type(e)}')
858except OSError as e:
859    print(f'caught {type(e)}')"#,
860            "<test>",
861        )
862        .unwrap();
863        insta::assert_debug_snapshot!(parse_ast);
864    }
865
866    #[test]
867    fn test_try_star() {
868        let parse_ast = ast::Suite::parse(
869            r#"try:
870    raise ExceptionGroup("eg",
871        [ValueError(1), TypeError(2), OSError(3), OSError(4)])
872except* TypeError as e:
873    print(f'caught {type(e)} with nested {e.exceptions}')
874except* OSError as e:
875    print(f'caught {type(e)} with nested {e.exceptions}')"#,
876            "<test>",
877        )
878        .unwrap();
879        insta::assert_debug_snapshot!(parse_ast);
880    }
881
882    #[test]
883    fn test_dict_unpacking() {
884        let parse_ast = ast::Expr::parse(r#"{"a": "b", **c, "d": "e"}"#, "<test>").unwrap();
885        insta::assert_debug_snapshot!(parse_ast);
886    }
887
888    #[test]
889    fn test_modes() {
890        let source = "a[0][1][2][3][4]";
891
892        assert!(parse(source, Mode::Expression, "<embedded>").is_ok());
893        assert!(parse(source, Mode::Module, "<embedded>").is_ok());
894        assert!(parse(source, Mode::Interactive, "<embedded>").is_ok());
895    }
896
897    #[test]
898    #[cfg(feature = "all-nodes-with-ranges")]
899    fn test_parse_type_declaration() {
900        let source = r#"
901type X = int
902type X = int | str
903type X = int | "ForwardRefY"
904type X[T] = T | list[X[T]]  # recursive
905type X[T] = int
906type X[T] = list[T] | set[T]
907type X[T, *Ts, **P] = (T, Ts, P)
908type X[T: int, *Ts, **P] = (T, Ts, P)
909type X[T: (int, str), *Ts, **P] = (T, Ts, P)
910
911# soft keyword as alias name
912type type = int  
913type match = int
914type case = int
915
916# soft keyword as value
917type foo = type
918type foo = match
919type foo = case
920
921# multine definitions
922type \
923	X = int
924type X \
925	= int
926type X = \
927	int
928type X = (
929    int
930)
931type \
932    X[T] = T
933type X \
934    [T] = T
935type X[T] \
936    = T
937"#;
938        insta::assert_debug_snapshot!(ast::Suite::parse(source, "<test>").unwrap());
939    }
940
941    #[test]
942    #[cfg(feature = "all-nodes-with-ranges")]
943    fn test_type_as_identifier() {
944        let source = r#"\
945type *a + b, c   # ((type * a) + b), c
946type *(a + b), c   # (type * (a + b)), c
947type (*a + b, c)   # type ((*(a + b)), c)
948type -a * b + c   # (type - (a * b)) + c
949type -(a * b) + c   # (type - (a * b)) + c
950type (-a) * b + c   # (type (-(a * b))) + c
951type ().a   # (type()).a
952type (()).a   # (type(())).a
953type ((),).a   # (type(())).a
954type [a].b   # (type[a]).b
955type [a,].b   # (type[(a,)]).b  (not (type[a]).b)
956type [(a,)].b   # (type[(a,)]).b
957type()[a:
958    b]  # (type())[a: b]
959if type := 1: pass
960type = lambda query: query == event
961print(type(12))
962type(type)
963a = (
964	type in C
965)
966a = (
967	type(b)
968)
969type (
970	X = int
971)
972type = 1
973type = x = 1
974x = type = 1
975"#;
976        insta::assert_debug_snapshot!(ast::Suite::parse(source, "<test>").unwrap());
977    }
978
    /// Checks that `match` is parsed as a plain name wherever it does not open a
    /// `match <subject>:` statement: operands of unary/binary operators, call and
    /// subscript targets, attribute bases, walrus and assignment targets, call
    /// arguments, and the subject of a real `match` statement.
    ///
    /// Gated on the `all-nodes-with-ranges` feature; the parsed `ast::Suite` is compared
    /// against a stored `insta` debug snapshot.
    #[test]
    #[cfg(feature = "all-nodes-with-ranges")]
    fn test_match_as_identifier() {
        // The trailing `# ...` Python comments in the fixture record the grouping each
        // line is expected to parse to. `match match:` is the one genuine match
        // statement; its subject is the name `match`.
        // The literal is byte-sensitive: if the snapshot records node ranges (presumably,
        // given the feature gate), any edit — including the leading `\` continuation or a
        // comment's length — shifts offsets and requires regenerating the snapshot.
        // NOTE(review): the comment on `match (-a) * b + c` looks inaccurate: `(-a)` reads
        // as call arguments, so the line presumably parses as `(match(-a) * b) + c`.
        // Confirm against the snapshot before relying on that comment.
        let source = r#"\
match *a + b, c   # ((match * a) + b), c
match *(a + b), c   # (match * (a + b)), c
match (*a + b, c)   # match ((*(a + b)), c)
match -a * b + c   # (match - (a * b)) + c
match -(a * b) + c   # (match - (a * b)) + c
match (-a) * b + c   # (match (-(a * b))) + c
match ().a   # (match()).a
match (()).a   # (match(())).a
match ((),).a   # (match(())).a
match [a].b   # (match[a]).b
match [a,].b   # (match[(a,)]).b  (not (match[a]).b)
match [(a,)].b   # (match[(a,)]).b
match()[a:
    b]  # (match())[a: b]
if match := 1: pass
match match:
    case 1: pass
    case 2:
        pass
match = lambda query: query == event
print(match(12))
"#;
        insta::assert_debug_snapshot!(ast::Suite::parse(source, "<test>").unwrap());
    }
1007
    /// Parses a batch of structural-pattern-matching statements sampled from CPython's
    /// `Lib/test/test_patma.py` and compares the resulting `ast::Suite` against a stored
    /// `insta` debug snapshot.
    ///
    /// The fixture covers literal, complex-number, class, sequence, mapping, or-,
    /// as-, wildcard, star, and value patterns, guards, and open (unparenthesised)
    /// tuple subjects and patterns. Gated on the `all-nodes-with-ranges` feature.
    #[test]
    #[cfg(feature = "all-nodes-with-ranges")]
    fn test_patma() {
        // The `black_check_sequence_then_mapping` case uses `return` at module level.
        // This test unwraps the parse result, so it expects the parser to accept that;
        // rejecting a top-level `return` is presumably left to a later compile stage.
        // NOTE(review): the label `# case test_patma_151` appears twice, on two different
        // cases; one of them is probably mislabelled. It is left as-is because changing
        // comment text can shift the offsets recorded in the snapshot.
        let source = r#"# Cases sampled from Lib/test/test_patma.py

# case test_patma_098
match x:
    case -0j:
        y = 0
# case test_patma_142
match x:
    case bytes(z):
        y = 0
# case test_patma_073
match x:
    case 0 if 0:
        y = 0
    case 0 if 1:
        y = 1
# case test_patma_006
match 3:
    case 0 | 1 | 2 | 3:
        x = True
# case test_patma_049
match x:
    case [0, 1] | [1, 0]:
        y = 0
# case black_check_sequence_then_mapping
match x:
    case [*_]:
        return "seq"
    case {}:
        return "map"
# case test_patma_035
match x:
    case {0: [1, 2, {}]}:
        y = 0
    case {0: [1, 2, {}] | True} | {1: [[]]} | {0: [1, 2, {}]} | [] | "X" | {}:
        y = 1
    case []:
        y = 2
# case test_patma_107
match x:
    case 0.25 + 1.75j:
        y = 0
# case test_patma_097
match x:
    case -0j:
        y = 0
# case test_patma_007
match 4:
    case 0 | 1 | 2 | 3:
        x = True
# case test_patma_154
match x:
    case 0 if x:
        y = 0
# case test_patma_134
match x:
    case {1: 0}:
        y = 0
    case {0: 0}:
        y = 1
    case {**z}:
        y = 2
# case test_patma_185
match Seq():
    case [*_]:
        y = 0
# case test_patma_063
match x:
    case 1:
        y = 0
    case 1:
        y = 1
# case test_patma_248
match x:
    case {"foo": bar}:
        y = bar
# case test_patma_019
match (0, 1, 2):
    case [0, 1, *x, 2]:
        y = 0
# case test_patma_052
match x:
    case [0]:
        y = 0
    case [1, 0] if (x := x[:0]):
        y = 1
    case [1, 0]:
        y = 2
# case test_patma_191
match w:
    case [x, y, *_]:
        z = 0
# case test_patma_110
match x:
    case -0.25 - 1.75j:
        y = 0
# case test_patma_151
match (x,):
    case [y]:
        z = 0
# case test_patma_114
match x:
    case A.B.C.D:
        y = 0
# case test_patma_232
match x:
    case None:
        y = 0
# case test_patma_058
match x:
    case 0:
        y = 0
# case test_patma_233
match x:
    case False:
        y = 0
# case test_patma_078
match x:
    case []:
        y = 0
    case [""]:
        y = 1
    case "":
        y = 2
# case test_patma_156
match x:
    case z:
        y = 0
# case test_patma_189
match w:
    case [x, y, *rest]:
        z = 0
# case test_patma_042
match x:
    case (0 as z) | (1 as z) | (2 as z) if z == x % 2:
        y = 0
# case test_patma_034
match x:
    case {0: [1, 2, {}]}:
        y = 0
    case {0: [1, 2, {}] | False} | {1: [[]]} | {0: [1, 2, {}]} | [] | "X" | {}:
        y = 1
    case []:
        y = 2
# case test_patma_123
match (0, 1, 2):
    case 0, *x:
        y = 0
# case test_patma_126
match (0, 1, 2):
    case *x, 2,:
        y = 0
# case test_patma_151
match x,:
    case y,:
        z = 0
# case test_patma_152
match w, x:
    case y, z:
        v = 0
# case test_patma_153
match w := x,:
    case y as v,:
        z = 0
"#;
        let parse_ast = ast::Suite::parse(source, "<test>").unwrap();
        insta::assert_debug_snapshot!(parse_ast);
    }
1179
1180    #[test]
1181    #[cfg(feature = "all-nodes-with-ranges")]
1182    fn test_match() {
1183        let parse_ast = ast::Suite::parse(
1184            r#"
1185match {"test": 1}:
1186    case {
1187        **rest,
1188    }:
1189        print(rest)
1190match {"label": "test"}:
1191    case {
1192        "label": str() | None as label,
1193    }:
1194        print(label)
1195match x:
1196    case [0, 1,]:
1197        y = 0
1198match x:
1199    case (0, 1,):
1200        y = 0
1201match x:
1202    case (0,):
1203        y = 0
1204"#,
1205            "<test>",
1206        )
1207        .unwrap();
1208        insta::assert_debug_snapshot!(parse_ast);
1209    }
1210
1211    #[test]
1212    #[cfg(feature = "all-nodes-with-ranges")]
1213    fn test_variadic_generics() {
1214        let parse_ast = ast::Suite::parse(
1215            r#"
1216def args_to_tuple(*args: *Ts) -> Tuple[*Ts]: ...
1217"#,
1218            "<test>",
1219        )
1220        .unwrap();
1221        insta::assert_debug_snapshot!(parse_ast);
1222    }
1223
1224    #[test]
1225    fn test_parse_constant() {
1226        use num_traits::ToPrimitive;
1227
1228        let c = ast::Constant::parse_without_path("'string'").unwrap();
1229        assert_eq!(c.str().unwrap(), "string");
1230
1231        let c = ast::Constant::parse_without_path("10").unwrap();
1232        assert_eq!(c.int().unwrap().to_i32().unwrap(), 10);
1233    }
1234
1235    #[test]
1236    fn test_parse_identifier() {
1237        let i = ast::Identifier::parse_without_path("test").unwrap();
1238        assert_eq!(i.as_str(), "test");
1239    }
1240}