rustpython_parser/
function.rs

// Validation and parsing helpers for function parameters (in definitions)
// and for call arguments. The checks apply to both functions and lambdas.
3use crate::text_size::TextRange;
4use crate::{
5    ast,
6    lexer::{LexicalError, LexicalErrorType},
7    text_size::TextSize,
8};
9use rustc_hash::FxHashSet;
10use rustpython_ast::Ranged;
11
12pub(crate) struct ArgumentList {
13    pub args: Vec<ast::Expr>,
14    pub keywords: Vec<ast::Keyword>,
15}
16
17// Perform validation of function/lambda arguments in a function definition.
18pub(crate) fn validate_arguments(arguments: &ast::Arguments) -> Result<(), LexicalError> {
19    let mut all_arg_names = FxHashSet::with_hasher(Default::default());
20
21    let posonlyargs = arguments.posonlyargs.iter();
22    let args = arguments.args.iter();
23    let kwonlyargs = arguments.kwonlyargs.iter();
24
25    let vararg: Option<&ast::Arg> = arguments.vararg.as_deref();
26    let kwarg: Option<&ast::Arg> = arguments.kwarg.as_deref();
27
28    for arg in posonlyargs
29        .chain(args)
30        .chain(kwonlyargs)
31        .map(|arg| &arg.def)
32        .chain(vararg)
33        .chain(kwarg)
34    {
35        let range = arg.range;
36        let arg_name = arg.arg.as_str();
37        if !all_arg_names.insert(arg_name) {
38            return Err(LexicalError {
39                error: LexicalErrorType::DuplicateArgumentError(arg_name.to_string()),
40                location: range.start(),
41            });
42        }
43    }
44
45    Ok(())
46}
47
48pub(crate) fn validate_pos_params(
49    args: &(Vec<ast::ArgWithDefault>, Vec<ast::ArgWithDefault>),
50) -> Result<(), LexicalError> {
51    let (posonlyargs, args) = args;
52    #[allow(clippy::skip_while_next)]
53    let first_invalid = posonlyargs
54        .iter()
55        .chain(args.iter()) // for all args
56        .skip_while(|arg| arg.default.is_none()) // starting with args without default
57        .skip_while(|arg| arg.default.is_some()) // and then args with default
58        .next(); // there must not be any more args without default
59    if let Some(invalid) = first_invalid {
60        return Err(LexicalError {
61            error: LexicalErrorType::DefaultArgumentError,
62            location: invalid.def.range.start(),
63        });
64    }
65    Ok(())
66}
67
68type FunctionArgument = (
69    Option<(TextSize, TextSize, Option<ast::Identifier>)>,
70    ast::Expr,
71);
72
73// Parse arguments as supplied during a function/lambda *call*.
74pub(crate) fn parse_args(func_args: Vec<FunctionArgument>) -> Result<ArgumentList, LexicalError> {
75    let mut args = vec![];
76    let mut keywords = vec![];
77
78    let mut keyword_names =
79        FxHashSet::with_capacity_and_hasher(func_args.len(), Default::default());
80    let mut double_starred = false;
81    for (name, value) in func_args {
82        match name {
83            Some((start, end, name)) => {
84                // Check for duplicate keyword arguments in the call.
85                if let Some(keyword_name) = &name {
86                    if keyword_names.contains(keyword_name) {
87                        return Err(LexicalError {
88                            error: LexicalErrorType::DuplicateKeywordArgumentError(
89                                keyword_name.to_string(),
90                            ),
91                            location: start,
92                        });
93                    }
94
95                    keyword_names.insert(keyword_name.clone());
96                } else {
97                    double_starred = true;
98                }
99
100                keywords.push(ast::Keyword {
101                    arg: name.map(ast::Identifier::new),
102                    value,
103                    range: TextRange::new(start, end),
104                });
105            }
106            None => {
107                // Positional arguments mustn't follow keyword arguments.
108                if !keywords.is_empty() && !is_starred(&value) {
109                    return Err(LexicalError {
110                        error: LexicalErrorType::PositionalArgumentError,
111                        location: value.start(),
112                    });
113                // Allow starred arguments after keyword arguments but
114                // not after double-starred arguments.
115                } else if double_starred {
116                    return Err(LexicalError {
117                        error: LexicalErrorType::UnpackedArgumentError,
118                        location: value.start(),
119                    });
120                }
121
122                args.push(value);
123            }
124        }
125    }
126    Ok(ArgumentList { args, keywords })
127}
128
129// Check if an expression is a starred expression.
130const fn is_starred(exp: &ast::Expr) -> bool {
131    exp.is_starred_expr()
132}
133
134#[cfg(test)]
135mod tests {
136    use super::*;
137    use crate::{ast, parser::ParseErrorType, Parse};
138
139    #[cfg(feature = "all-nodes-with-ranges")]
140    macro_rules! function_and_lambda {
141        ($($name:ident: $code:expr,)*) => {
142            $(
143                #[test]
144                fn $name() {
145                    let parse_ast = ast::Suite::parse($code, "<test>");
146                    insta::assert_debug_snapshot!(parse_ast);
147                }
148            )*
149        }
150    }
151
152    #[cfg(feature = "all-nodes-with-ranges")]
153    function_and_lambda! {
154        test_function_no_args: "def f(): pass",
155        test_function_pos_args: "def f(a, b, c): pass",
156        test_function_pos_args_with_defaults: "def f(a, b=20, c=30): pass",
157        test_function_kw_only_args: "def f(*, a, b, c): pass",
158        test_function_kw_only_args_with_defaults: "def f(*, a, b=20, c=30): pass",
159        test_function_pos_and_kw_only_args: "def f(a, b, c, *, d, e, f): pass",
160        test_function_pos_and_kw_only_args_with_defaults: "def f(a, b, c, *, d, e=20, f=30): pass",
161        test_function_pos_and_kw_only_args_with_defaults_and_varargs: "def f(a, b, c, *args, d, e=20, f=30): pass",
162        test_function_pos_and_kw_only_args_with_defaults_and_varargs_and_kwargs: "def f(a, b, c, *args, d, e=20, f=30, **kwargs): pass",
163        test_lambda_no_args: "lambda: 1",
164        test_lambda_pos_args: "lambda a, b, c: 1",
165        test_lambda_pos_args_with_defaults: "lambda a, b=20, c=30: 1",
166        test_lambda_kw_only_args: "lambda *, a, b, c: 1",
167        test_lambda_kw_only_args_with_defaults: "lambda *, a, b=20, c=30: 1",
168        test_lambda_pos_and_kw_only_args: "lambda a, b, c, *, d, e: 0",
169    }
170
171    fn function_parse_error(src: &str) -> LexicalErrorType {
172        let parse_ast = ast::Suite::parse(src, "<test>");
173        parse_ast
174            .map_err(|e| match e.error {
175                ParseErrorType::Lexical(e) => e,
176                _ => panic!("Expected LexicalError"),
177            })
178            .expect_err("Expected error")
179    }
180
181    macro_rules! function_and_lambda_error {
182        ($($name:ident: $code:expr, $error:expr,)*) => {
183            $(
184                #[test]
185                fn $name() {
186                    let error = function_parse_error($code);
187                    assert_eq!(error, $error);
188                }
189            )*
190        }
191    }
192
193    function_and_lambda_error! {
194        // Check definitions
195        test_duplicates_f1: "def f(a, a): pass", LexicalErrorType::DuplicateArgumentError("a".to_string()),
196        test_duplicates_f2: "def f(a, *, a): pass", LexicalErrorType::DuplicateArgumentError("a".to_string()),
197        test_duplicates_f3: "def f(a, a=20): pass", LexicalErrorType::DuplicateArgumentError("a".to_string()),
198        test_duplicates_f4: "def f(a, *a): pass", LexicalErrorType::DuplicateArgumentError("a".to_string()),
199        test_duplicates_f5: "def f(a, *, **a): pass", LexicalErrorType::DuplicateArgumentError("a".to_string()),
200        test_duplicates_l1: "lambda a, a: 1", LexicalErrorType::DuplicateArgumentError("a".to_string()),
201        test_duplicates_l2: "lambda a, *, a: 1", LexicalErrorType::DuplicateArgumentError("a".to_string()),
202        test_duplicates_l3: "lambda a, a=20: 1", LexicalErrorType::DuplicateArgumentError("a".to_string()),
203        test_duplicates_l4: "lambda a, *a: 1", LexicalErrorType::DuplicateArgumentError("a".to_string()),
204        test_duplicates_l5: "lambda a, *, **a: 1", LexicalErrorType::DuplicateArgumentError("a".to_string()),
205        test_default_arg_error_f: "def f(a, b=20, c): pass", LexicalErrorType::DefaultArgumentError,
206        test_default_arg_error_l: "lambda a, b=20, c: 1", LexicalErrorType::DefaultArgumentError,
207
208        // Check some calls.
209        test_positional_arg_error_f: "f(b=20, c)", LexicalErrorType::PositionalArgumentError,
210        test_unpacked_arg_error_f: "f(**b, *c)", LexicalErrorType::UnpackedArgumentError,
211        test_duplicate_kw_f1: "f(a=20, a=30)", LexicalErrorType::DuplicateKeywordArgumentError("a".to_string()),
212    }
213}