Skip to main content

rustpython_ruff_python_ast/
node.rs

1use ruff_text_size::Ranged;
2
3use crate::visitor::source_order::SourceOrderVisitor;
4use crate::{
5    self as ast, Alias, AnyNodeRef, AnyParameterRef, ArgOrKeyword, MatchCase, PatternKeyword,
6};
7
8impl ast::ElifElseClause {
9    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
10    where
11        V: SourceOrderVisitor<'a> + ?Sized,
12    {
13        let ast::ElifElseClause {
14            range: _,
15            node_index: _,
16            test,
17            body,
18        } = self;
19        if let Some(test) = test {
20            visitor.visit_expr(test);
21        }
22        visitor.visit_body(body);
23    }
24}
25
26impl ast::ExprDict {
27    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
28    where
29        V: SourceOrderVisitor<'a> + ?Sized,
30    {
31        let ast::ExprDict {
32            items,
33            range: _,
34            node_index: _,
35        } = self;
36
37        for ast::DictItem { key, value } in items {
38            if let Some(key) = key {
39                visitor.visit_expr(key);
40            }
41            visitor.visit_expr(value);
42        }
43    }
44}
45
46impl ast::ExprBoolOp {
47    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
48    where
49        V: SourceOrderVisitor<'a> + ?Sized,
50    {
51        let ast::ExprBoolOp {
52            op,
53            values,
54            range: _,
55            node_index: _,
56        } = self;
57        match values.as_slice() {
58            [left, rest @ ..] => {
59                visitor.visit_expr(left);
60                visitor.visit_bool_op(op);
61                for expr in rest {
62                    visitor.visit_expr(expr);
63                }
64            }
65            [] => {
66                visitor.visit_bool_op(op);
67            }
68        }
69    }
70}
71
72impl ast::ExprCompare {
73    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
74    where
75        V: SourceOrderVisitor<'a> + ?Sized,
76    {
77        let ast::ExprCompare {
78            left,
79            ops,
80            comparators,
81            range: _,
82            node_index: _,
83        } = self;
84
85        visitor.visit_expr(left);
86
87        for (op, comparator) in ops.iter().zip(comparators) {
88            visitor.visit_cmp_op(op);
89            visitor.visit_expr(comparator);
90        }
91    }
92}
93
94impl ast::InterpolatedStringFormatSpec {
95    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
96    where
97        V: SourceOrderVisitor<'a> + ?Sized,
98    {
99        for element in &self.elements {
100            visitor.visit_interpolated_string_element(element);
101        }
102    }
103}
104
105impl ast::InterpolatedElement {
106    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
107    where
108        V: SourceOrderVisitor<'a> + ?Sized,
109    {
110        let ast::InterpolatedElement {
111            expression,
112            format_spec,
113            ..
114        } = self;
115        visitor.visit_expr(expression);
116
117        if let Some(format_spec) = format_spec {
118            for spec_part in &format_spec.elements {
119                visitor.visit_interpolated_string_element(spec_part);
120            }
121        }
122    }
123}
124
125impl ast::InterpolatedStringLiteralElement {
126    pub(crate) fn visit_source_order<'a, V>(&'a self, _visitor: &mut V)
127    where
128        V: SourceOrderVisitor<'a> + ?Sized,
129    {
130        let ast::InterpolatedStringLiteralElement {
131            range: _,
132            node_index: _,
133            value: _,
134        } = self;
135    }
136}
137
138impl ast::ExprFString {
139    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
140    where
141        V: SourceOrderVisitor<'a> + ?Sized,
142    {
143        let ast::ExprFString {
144            value,
145            range: _,
146            node_index: _,
147        } = self;
148
149        for f_string_part in value {
150            match f_string_part {
151                ast::FStringPart::Literal(string_literal) => {
152                    visitor.visit_string_literal(string_literal);
153                }
154                ast::FStringPart::FString(f_string) => {
155                    visitor.visit_f_string(f_string);
156                }
157            }
158        }
159    }
160}
161
162impl ast::ExprTString {
163    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
164    where
165        V: SourceOrderVisitor<'a> + ?Sized,
166    {
167        let ast::ExprTString {
168            value,
169            range: _,
170            node_index: _,
171        } = self;
172
173        for t_string in value {
174            visitor.visit_t_string(t_string);
175        }
176    }
177}
178
179impl ast::ExprStringLiteral {
180    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
181    where
182        V: SourceOrderVisitor<'a> + ?Sized,
183    {
184        let ast::ExprStringLiteral {
185            value,
186            range: _,
187            node_index: _,
188        } = self;
189
190        for string_literal in value {
191            visitor.visit_string_literal(string_literal);
192        }
193    }
194}
195
196impl ast::ExprBytesLiteral {
197    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
198    where
199        V: SourceOrderVisitor<'a> + ?Sized,
200    {
201        let ast::ExprBytesLiteral {
202            value,
203            range: _,
204            node_index: _,
205        } = self;
206
207        for bytes_literal in value {
208            visitor.visit_bytes_literal(bytes_literal);
209        }
210    }
211}
212
213impl ast::ExceptHandlerExceptHandler {
214    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
215    where
216        V: SourceOrderVisitor<'a> + ?Sized,
217    {
218        let ast::ExceptHandlerExceptHandler {
219            range: _,
220            node_index: _,
221            type_,
222            name,
223            body,
224        } = self;
225        if let Some(expr) = type_ {
226            visitor.visit_expr(expr);
227        }
228
229        if let Some(name) = name {
230            visitor.visit_identifier(name);
231        }
232
233        visitor.visit_body(body);
234    }
235}
236
237impl ast::PatternMatchMapping {
238    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
239    where
240        V: SourceOrderVisitor<'a> + ?Sized,
241    {
242        let ast::PatternMatchMapping {
243            keys,
244            patterns,
245            rest,
246            range: _,
247            node_index: _,
248        } = self;
249
250        let mut rest = rest.as_ref();
251
252        for (key, pattern) in keys.iter().zip(patterns) {
253            if let Some(rest_identifier) = rest {
254                if rest_identifier.start() < key.start() {
255                    visitor.visit_identifier(rest_identifier);
256                    rest = None;
257                }
258            }
259            visitor.visit_expr(key);
260            visitor.visit_pattern(pattern);
261        }
262
263        if let Some(rest) = rest {
264            visitor.visit_identifier(rest);
265        }
266    }
267}
268
269impl ast::PatternArguments {
270    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
271    where
272        V: SourceOrderVisitor<'a> + ?Sized,
273    {
274        for pattern_or_keyword in self.patterns_source_order() {
275            match pattern_or_keyword {
276                crate::PatternOrKeyword::Pattern(pattern) => visitor.visit_pattern(pattern),
277                crate::PatternOrKeyword::Keyword(keyword) => {
278                    visitor.visit_pattern_keyword(keyword);
279                }
280            }
281        }
282    }
283}
284
285impl ast::PatternKeyword {
286    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
287    where
288        V: SourceOrderVisitor<'a> + ?Sized,
289    {
290        let PatternKeyword {
291            range: _,
292            node_index: _,
293            attr,
294            pattern,
295        } = self;
296
297        visitor.visit_identifier(attr);
298        visitor.visit_pattern(pattern);
299    }
300}
301
302impl ast::Comprehension {
303    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
304    where
305        V: SourceOrderVisitor<'a> + ?Sized,
306    {
307        let ast::Comprehension {
308            range: _,
309            node_index: _,
310            target,
311            iter,
312            ifs,
313            is_async: _,
314        } = self;
315        visitor.visit_expr(target);
316        visitor.visit_expr(iter);
317
318        for expr in ifs {
319            visitor.visit_expr(expr);
320        }
321    }
322}
323
324impl ast::Arguments {
325    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
326    where
327        V: SourceOrderVisitor<'a> + ?Sized,
328    {
329        for arg_or_keyword in self.arguments_source_order() {
330            match arg_or_keyword {
331                ArgOrKeyword::Arg(arg) => visitor.visit_expr(arg),
332                ArgOrKeyword::Keyword(keyword) => visitor.visit_keyword(keyword),
333            }
334        }
335    }
336}
337
338impl ast::Parameters {
339    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
340    where
341        V: SourceOrderVisitor<'a> + ?Sized,
342    {
343        for parameter in self.iter_source_order() {
344            match parameter {
345                AnyParameterRef::NonVariadic(parameter_with_default) => {
346                    visitor.visit_parameter_with_default(parameter_with_default);
347                }
348                AnyParameterRef::Variadic(parameter) => visitor.visit_parameter(parameter),
349            }
350        }
351    }
352}
353
354impl ast::Parameter {
355    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
356    where
357        V: SourceOrderVisitor<'a> + ?Sized,
358    {
359        let ast::Parameter {
360            range: _,
361            node_index: _,
362            name,
363            annotation,
364        } = self;
365
366        visitor.visit_identifier(name);
367        if let Some(expr) = annotation {
368            visitor.visit_annotation(expr);
369        }
370    }
371}
372
373impl ast::ParameterWithDefault {
374    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
375    where
376        V: SourceOrderVisitor<'a> + ?Sized,
377    {
378        let ast::ParameterWithDefault {
379            range: _,
380            node_index: _,
381            parameter,
382            default,
383        } = self;
384        visitor.visit_parameter(parameter);
385        if let Some(expr) = default {
386            visitor.visit_expr(expr);
387        }
388    }
389}
390
391impl ast::Keyword {
392    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
393    where
394        V: SourceOrderVisitor<'a> + ?Sized,
395    {
396        let ast::Keyword {
397            range: _,
398            node_index: _,
399            arg,
400            value,
401        } = self;
402
403        if let Some(arg) = arg {
404            visitor.visit_identifier(arg);
405        }
406        visitor.visit_expr(value);
407    }
408}
409
410impl Alias {
411    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
412    where
413        V: SourceOrderVisitor<'a> + ?Sized,
414    {
415        let ast::Alias {
416            range: _,
417            node_index: _,
418            name,
419            asname,
420        } = self;
421
422        visitor.visit_identifier(name);
423        if let Some(asname) = asname {
424            visitor.visit_identifier(asname);
425        }
426    }
427}
428
429impl ast::WithItem {
430    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
431    where
432        V: SourceOrderVisitor<'a> + ?Sized,
433    {
434        let ast::WithItem {
435            range: _,
436            node_index: _,
437            context_expr,
438            optional_vars,
439        } = self;
440
441        visitor.visit_expr(context_expr);
442
443        if let Some(expr) = optional_vars {
444            visitor.visit_expr(expr);
445        }
446    }
447}
448
449impl ast::MatchCase {
450    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
451    where
452        V: SourceOrderVisitor<'a> + ?Sized,
453    {
454        let ast::MatchCase {
455            range: _,
456            node_index: _,
457            pattern,
458            guard,
459            body,
460        } = self;
461
462        visitor.visit_pattern(pattern);
463        if let Some(expr) = guard {
464            visitor.visit_expr(expr);
465        }
466        visitor.visit_body(body);
467    }
468}
469
470impl ast::Decorator {
471    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
472    where
473        V: SourceOrderVisitor<'a> + ?Sized,
474    {
475        let ast::Decorator {
476            range: _,
477            node_index: _,
478            expression,
479        } = self;
480
481        visitor.visit_expr(expression);
482    }
483}
484
485impl ast::TypeParams {
486    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
487    where
488        V: SourceOrderVisitor<'a> + ?Sized,
489    {
490        let ast::TypeParams {
491            range: _,
492            node_index: _,
493            type_params,
494        } = self;
495
496        for type_param in type_params {
497            visitor.visit_type_param(type_param);
498        }
499    }
500}
501
502impl ast::FString {
503    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
504    where
505        V: SourceOrderVisitor<'a> + ?Sized,
506    {
507        let ast::FString {
508            elements,
509            range: _,
510            node_index: _,
511            flags: _,
512        } = self;
513
514        for fstring_element in elements {
515            visitor.visit_interpolated_string_element(fstring_element);
516        }
517    }
518}
519
520impl ast::TString {
521    pub(crate) fn visit_source_order<'a, V>(&'a self, visitor: &mut V)
522    where
523        V: SourceOrderVisitor<'a> + ?Sized,
524    {
525        let ast::TString {
526            elements,
527            range: _,
528            node_index: _,
529            flags: _,
530        } = self;
531
532        for tstring_element in elements {
533            visitor.visit_interpolated_string_element(tstring_element);
534        }
535    }
536}
537
538impl ast::StringLiteral {
539    #[inline]
540    pub(crate) fn visit_source_order<'a, V>(&'a self, _visitor: &mut V)
541    where
542        V: SourceOrderVisitor<'a> + ?Sized,
543    {
544        let ast::StringLiteral {
545            range: _,
546            node_index: _,
547            value: _,
548            flags: _,
549        } = self;
550    }
551}
552
553impl ast::BytesLiteral {
554    #[inline]
555    pub(crate) fn visit_source_order<'a, V>(&'a self, _visitor: &mut V)
556    where
557        V: SourceOrderVisitor<'a> + ?Sized,
558    {
559        let ast::BytesLiteral {
560            range: _,
561            node_index: _,
562            value: _,
563            flags: _,
564        } = self;
565    }
566}
567
568impl ast::Identifier {
569    #[inline]
570    pub(crate) fn visit_source_order<'a, V>(&'a self, _visitor: &mut V)
571    where
572        V: SourceOrderVisitor<'a> + ?Sized,
573    {
574        let ast::Identifier {
575            range: _,
576            node_index: _,
577            id: _,
578        } = self;
579    }
580}
581
582impl<'a> AnyNodeRef<'a> {
583    /// Compares two any node refs by their pointers (referential equality).
584    pub fn ptr_eq(self, other: AnyNodeRef) -> bool {
585        self.as_ptr().eq(&other.as_ptr()) && self.kind() == other.kind()
586    }
587
588    /// In our AST, only some alternative branches are represented as a node. This has historical
589    /// reasons, e.g. we added a node for elif/else in if statements which was not originally
590    /// present in the parser.
591    pub const fn is_alternative_branch_with_node(self) -> bool {
592        matches!(
593            self,
594            AnyNodeRef::ExceptHandlerExceptHandler(_) | AnyNodeRef::ElifElseClause(_)
595        )
596    }
597
598    /// The last child of the last branch, if the node has multiple branches.
599    pub fn last_child_in_body(&self) -> Option<AnyNodeRef<'a>> {
600        let body =
601            match self {
602                AnyNodeRef::StmtFunctionDef(ast::StmtFunctionDef { body, .. })
603                | AnyNodeRef::StmtClassDef(ast::StmtClassDef { body, .. })
604                | AnyNodeRef::StmtWith(ast::StmtWith { body, .. })
605                | AnyNodeRef::MatchCase(MatchCase { body, .. })
606                | AnyNodeRef::ExceptHandlerExceptHandler(ast::ExceptHandlerExceptHandler {
607                    body,
608                    ..
609                })
610                | AnyNodeRef::ElifElseClause(ast::ElifElseClause { body, .. }) => body,
611                AnyNodeRef::StmtIf(ast::StmtIf {
612                    body,
613                    elif_else_clauses,
614                    ..
615                }) => elif_else_clauses.last().map_or(body, |clause| &clause.body),
616
617                AnyNodeRef::StmtFor(ast::StmtFor { body, orelse, .. })
618                | AnyNodeRef::StmtWhile(ast::StmtWhile { body, orelse, .. }) => {
619                    if orelse.is_empty() { body } else { orelse }
620                }
621
622                AnyNodeRef::StmtMatch(ast::StmtMatch { cases, .. }) => {
623                    return cases.last().map(AnyNodeRef::from);
624                }
625
626                AnyNodeRef::StmtTry(ast::StmtTry {
627                    body,
628                    handlers,
629                    orelse,
630                    finalbody,
631                    ..
632                }) => {
633                    if finalbody.is_empty() {
634                        if orelse.is_empty() {
635                            if handlers.is_empty() {
636                                body
637                            } else {
638                                return handlers.last().map(AnyNodeRef::from);
639                            }
640                        } else {
641                            orelse
642                        }
643                    } else {
644                        finalbody
645                    }
646                }
647
648                // Not a node that contains an indented child node.
649                _ => return None,
650            };
651
652        body.last().map(AnyNodeRef::from)
653    }
654
655    /// Check if the given statement is the first statement after the colon of a branch, be it in if
656    /// statements, for statements, after each part of a try-except-else-finally or function/class
657    /// definitions.
658    ///
659    ///
660    /// ```python
661    /// if True:    <- has body
662    ///     a       <- first statement
663    ///     b
664    /// elif b:     <- has body
665    ///     c       <- first statement
666    ///     d
667    /// else:       <- has body
668    ///     e       <- first statement
669    ///     f
670    ///
671    /// class:      <- has body
672    ///     a: int  <- first statement
673    ///     b: int
674    ///
675    /// ```
676    ///
677    /// For nodes with multiple bodies, we check all bodies that don't have their own node. For
678    /// try-except-else-finally, each except branch has it's own node, so for the `StmtTry`, we check
679    /// the `try:`, `else:` and `finally:`, bodies, while `ExceptHandlerExceptHandler` has it's own
680    /// check. For for-else and while-else, we check both branches for the whole statement.
681    ///
682    /// ```python
683    /// try:        <- has body (a)
684    ///     6/8     <- first statement (a)
685    ///     1/0
686    /// except:     <- has body (b)
687    ///     a       <- first statement (b)
688    ///     b
689    /// else:
690    ///     c       <- first statement (a)
691    ///     d
692    /// finally:
693    ///     e       <- first statement (a)
694    ///     f
695    /// ```
696    pub fn is_first_statement_in_body(&self, body: AnyNodeRef) -> bool {
697        match body {
698            AnyNodeRef::StmtFor(ast::StmtFor { body, orelse, .. })
699            | AnyNodeRef::StmtWhile(ast::StmtWhile { body, orelse, .. }) => {
700                are_same_optional(*self, body.first()) || are_same_optional(*self, orelse.first())
701            }
702
703            AnyNodeRef::StmtTry(ast::StmtTry {
704                body,
705                orelse,
706                finalbody,
707                ..
708            }) => {
709                are_same_optional(*self, body.first())
710                    || are_same_optional(*self, orelse.first())
711                    || are_same_optional(*self, finalbody.first())
712            }
713
714            AnyNodeRef::StmtIf(ast::StmtIf { body, .. })
715            | AnyNodeRef::ElifElseClause(ast::ElifElseClause { body, .. })
716            | AnyNodeRef::StmtWith(ast::StmtWith { body, .. })
717            | AnyNodeRef::ExceptHandlerExceptHandler(ast::ExceptHandlerExceptHandler {
718                body,
719                ..
720            })
721            | AnyNodeRef::MatchCase(MatchCase { body, .. })
722            | AnyNodeRef::StmtFunctionDef(ast::StmtFunctionDef { body, .. })
723            | AnyNodeRef::StmtClassDef(ast::StmtClassDef { body, .. }) => {
724                are_same_optional(*self, body.first())
725            }
726
727            AnyNodeRef::StmtMatch(ast::StmtMatch { cases, .. }) => {
728                are_same_optional(*self, cases.first())
729            }
730
731            _ => false,
732        }
733    }
734
735    /// Returns `true` if `statement` is the first statement in an alternate `body` (e.g. the else of an if statement)
736    pub fn is_first_statement_in_alternate_body(&self, body: AnyNodeRef) -> bool {
737        match body {
738            AnyNodeRef::StmtFor(ast::StmtFor { orelse, .. })
739            | AnyNodeRef::StmtWhile(ast::StmtWhile { orelse, .. }) => {
740                are_same_optional(*self, orelse.first())
741            }
742
743            AnyNodeRef::StmtTry(ast::StmtTry {
744                handlers,
745                orelse,
746                finalbody,
747                ..
748            }) => {
749                are_same_optional(*self, handlers.first())
750                    || are_same_optional(*self, orelse.first())
751                    || are_same_optional(*self, finalbody.first())
752            }
753
754            AnyNodeRef::StmtIf(ast::StmtIf {
755                elif_else_clauses, ..
756            }) => are_same_optional(*self, elif_else_clauses.first()),
757            _ => false,
758        }
759    }
760}
761
762/// Returns `true` if `right` is `Some` and `left` and `right` are referentially equal.
763fn are_same_optional<'a, T>(left: AnyNodeRef, right: Option<T>) -> bool
764where
765    T: Into<AnyNodeRef<'a>>,
766{
767    right.is_some_and(|right| left.ptr_eq(right.into()))
768}