selene_lib/lints/mismatched_arg_count.rs

1use super::*;
2use crate::{
3    ast_util::{
4        is_vararg, range,
5        scopes::{Reference, ScopeManager, Variable},
6    },
7    text::plural,
8};
9use std::{
10    collections::HashMap,
11    convert::Infallible,
12    fmt::{self, Display},
13};
14
15use full_moon::{
16    ast::{self, Ast},
17    visitors::Visitor,
18};
19use id_arena::Id;
20
21pub struct MismatchedArgCountLint;
22
23impl Lint for MismatchedArgCountLint {
24    type Config = ();
25    type Error = Infallible;
26
27    const SEVERITY: Severity = Severity::Error;
28    const LINT_TYPE: LintType = LintType::Correctness;
29
30    fn new(_: Self::Config) -> Result<Self, Self::Error> {
31        Ok(MismatchedArgCountLint)
32    }
33
34    fn pass(&self, ast: &Ast, _: &Context, ast_context: &AstContext) -> Vec<Diagnostic> {
35        // Firstly visit the AST so we can map the variables to their required parameter counts
36        let mut definitions = HashMap::new();
37        let mut definitions_visitor = MapFunctionDefinitionVisitor {
38            scope_manager: &ast_context.scope_manager,
39            definitions: &mut definitions,
40        };
41        definitions_visitor.visit_ast(ast);
42
43        let mut visitor = MismatchedArgCountVisitor {
44            mismatched_arg_counts: Vec::new(),
45            scope_manager: &ast_context.scope_manager,
46            definitions,
47        };
48
49        visitor.visit_ast(ast);
50
51        visitor
52            .mismatched_arg_counts
53            .iter()
54            .map(|mismatched_arg| {
55                Diagnostic::new_complete(
56                    "mismatched_arg_count",
57                    mismatched_arg
58                        .parameter_count
59                        .to_message(mismatched_arg.num_provided),
60                    Label::new_with_message(
61                        mismatched_arg.call_range,
62                        mismatched_arg.parameter_count.to_string(),
63                    ),
64                    Vec::new(),
65                    mismatched_arg
66                        .function_definition_ranges
67                        .iter()
68                        .map(|range| {
69                            Label::new_with_message(
70                                *range,
71                                "note: function defined here".to_owned(),
72                            )
73                        })
74                        .collect(),
75                )
76            })
77            .collect()
78    }
79}
80
81struct MismatchedArgCount {
82    parameter_count: ParameterCount,
83    num_provided: PassedArgumentCount,
84    call_range: (usize, usize),
85    function_definition_ranges: Vec<(usize, usize)>,
86}
87
88#[derive(Clone, Copy, Debug)]
89enum ParameterCount {
90    /// A fixed number of parameters are required: `function(a, b, c)`
91    Fixed(usize),
92    /// Some amount of fixed parameters are required, and the rest are variable: `function(a, b, ...)`
93    Minimum(usize),
94    /// A variable number of parameters can be provided: `function(...)`
95    Variable,
96}
97
98impl ParameterCount {
99    /// Calculates the number of required parameters that must be passed to a function
100    fn from_function_body(function_body: &ast::FunctionBody) -> Self {
101        let mut necessary_params = 0;
102
103        for parameter in function_body.parameters() {
104            #[cfg_attr(
105                feature = "force_exhaustive_checks",
106                deny(non_exhaustive_omitted_patterns)
107            )]
108            match parameter {
109                ast::Parameter::Name(_) => necessary_params += 1,
110                ast::Parameter::Ellipsis(_) => {
111                    if necessary_params == 0 {
112                        return Self::Variable;
113                    } else {
114                        return Self::Minimum(necessary_params);
115                    }
116                }
117                _ => {}
118            }
119        }
120
121        Self::Fixed(necessary_params)
122    }
123
124    /// Checks the provided number of arguments to see if it satisfies the number of arguments required
125    /// We will only lint an upper bound. If we have a function(a, b, c) and we call foo(a, b), this will
126    /// pass the lint, since the `nil` could be implicitly provided.
127    fn correct_num_args_provided(self, provided: PassedArgumentCount) -> bool {
128        match self {
129            ParameterCount::Fixed(required) => match provided {
130                PassedArgumentCount::Fixed(provided) => provided <= required,
131                // If we have function(a, b, c), but we provide foo(a, call()), we cannot infer anything
132                // but if we provide foo(a, b, c, call()), we know we have too many
133                PassedArgumentCount::Variable(atleast_provided) => atleast_provided <= required,
134            },
135            // function(a, b, ...) - if we call it through foo(a), b and the varargs could be implicitly nil.
136            // there is no upper bound since foo(a, b, c, d) is valid - therefore any amount of arguments provided is valid
137            ParameterCount::Minimum(_) => true,
138            // Any amount of arguments could be provided
139            ParameterCount::Variable => true,
140        }
141    }
142
143    fn to_message(self, provided: PassedArgumentCount) -> String {
144        match self {
145            ParameterCount::Fixed(required) => {
146                format!(
147                    "this function takes {} {} but {} were supplied",
148                    required,
149                    plural(required, "argument", "arguments"),
150                    provided
151                )
152            }
153            ParameterCount::Minimum(required) => format!(
154                "this function takes at least {} {} but {} were supplied",
155                required,
156                plural(required, "argument", "arguments"),
157                provided
158            ),
159            ParameterCount::Variable => "a variable amount of arguments".to_owned(),
160        }
161    }
162
163    fn overlap_with_other_parameter_count(self, other: ParameterCount) -> ParameterCount {
164        match (self, other) {
165            // If something takes `...`, then it'll always be correct no matter what.
166            (ParameterCount::Variable, _) | (_, ParameterCount::Variable) => {
167                ParameterCount::Variable
168            }
169
170            // Minimum always wins, since it allows for infinite parameters, and fixed will always match.
171            // f(a, b, ...) vs. f(a) is Minimum(1), so that `f(1, 2, 3, 4)` passes.
172            // f(a, b, c) vs. f(a, ...) is Minimum(1) for the same reason.
173            (ParameterCount::Fixed(fixed), ParameterCount::Minimum(minimum))
174            | (ParameterCount::Minimum(minimum), ParameterCount::Fixed(fixed)) => {
175                ParameterCount::Minimum(minimum.min(fixed))
176            }
177
178            // Given `f(a, b)` and `f(c, d)`, just preserve the Fixed(2).
179            // The complication comes with `f(a)` and `f(b, c)`, where we change to Minimum(1).
180            (ParameterCount::Fixed(this_fixed), ParameterCount::Fixed(other_fixed)) => {
181                if this_fixed == other_fixed {
182                    ParameterCount::Fixed(this_fixed)
183                } else {
184                    ParameterCount::Fixed(this_fixed.max(other_fixed))
185                }
186            }
187
188            // `f(a, b, ...)` vs. `f(a, ...)`. Same lints apply, just preserve the smaller minimum.
189            (ParameterCount::Minimum(this_minimum), ParameterCount::Minimum(other_minimum)) => {
190                ParameterCount::Minimum(this_minimum.min(other_minimum))
191            }
192        }
193    }
194}
195
196impl Display for ParameterCount {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        match self {
199            ParameterCount::Fixed(required) => write!(
200                f,
201                "expected {} {}",
202                required,
203                plural(*required, "argument", "arguments")
204            ),
205            ParameterCount::Minimum(required) => {
206                write!(
207                    f,
208                    "expected at least {} {}",
209                    required,
210                    plural(*required, "argument", "arguments")
211                )
212            }
213            ParameterCount::Variable => write!(f, "expected any number of arguments"),
214        }
215    }
216}
217
218#[derive(Clone, Copy, Debug)]
219enum PassedArgumentCount {
220    /// Passed a fixed amount of arguments, such as foo(a, b, c) or foo(a, call(), c) or foo(a, ..., c)
221    Fixed(usize),
222    /// Passed a variable of arguments - but we know the lower bound: e.g. foo(a, b, call()) or foo(a, b, ...)
223    Variable(usize),
224}
225
226impl PassedArgumentCount {
227    fn from_function_args(function_args: &ast::FunctionArgs) -> Self {
228        match function_args {
229            ast::FunctionArgs::Parentheses { arguments, .. } => {
230                // We need to be wary of function calls or ... being the last argument passed
231                // e.g. foo(a, b, call()) or foo(a, b, ...) - we don't know how many arguments were passed.
232                // However, if the call is NOT the last argument, as per Lua semantics, it is only classed as one argument,
233                // e.g. foo(a, call(), b) or foo(a, ..., c)
234
235                let mut passed_argument_count = 0;
236
237                for argument in arguments.pairs() {
238                    passed_argument_count += 1;
239
240                    if let ast::punctuated::Pair::End(expression) = argument {
241                        if matches!(expression, ast::Expression::FunctionCall(_))
242                            || is_vararg(expression)
243                        {
244                            return PassedArgumentCount::Variable(passed_argument_count);
245                        }
246                    }
247                }
248
249                Self::Fixed(passed_argument_count)
250            }
251            ast::FunctionArgs::String(_) => Self::Fixed(1),
252            ast::FunctionArgs::TableConstructor(_) => Self::Fixed(1),
253            _ => Self::Fixed(0),
254        }
255    }
256}
257
258impl Display for PassedArgumentCount {
259    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260        match self {
261            PassedArgumentCount::Fixed(amount) => write!(f, "{amount} arguments"),
262            PassedArgumentCount::Variable(amount) => write!(f, "at least {amount} arguments"),
263        }
264    }
265}
266
267/// A visitor used to map a variable to the necessary number of parameters required
268struct MapFunctionDefinitionVisitor<'a> {
269    scope_manager: &'a ScopeManager,
270    definitions: &'a mut HashMap<Id<Variable>, ParameterCount>,
271}
272
273impl MapFunctionDefinitionVisitor<'_> {
274    fn find_variable(&self, identifier: (usize, usize)) -> Option<Id<Variable>> {
275        self.scope_manager
276            .variables
277            .iter()
278            .find(|variable| variable.1.identifiers.contains(&identifier))
279            .map(|variable| variable.0)
280    }
281
282    fn find_reference(&self, identifier: (usize, usize)) -> Option<&Reference> {
283        self.scope_manager
284            .references
285            .iter()
286            .find(|reference| reference.1.identifier == identifier)
287            .map(|reference| reference.1)
288    }
289
290    /// Checks the provided variable to see if it is blacklisted, or it has already been stored.
291    /// If so, we can no longer verify which function definition is correct for a specific function call
292    /// so we bail out and blacklist it. This does not apply to locally assignment/reassigned variables (i.e. shadowing),
293    /// as that is handled properly.
294    /// If it is safe to use, the function body is stored.
295    fn verify_assignment(&mut self, variable: Id<Variable>, function_body: &ast::FunctionBody) {
296        let parameter_count = ParameterCount::from_function_body(function_body);
297
298        self.definitions
299            .entry(variable)
300            .and_modify(|older_count| {
301                *older_count = parameter_count.overlap_with_other_parameter_count(*older_count)
302            })
303            .or_insert(parameter_count);
304    }
305}
306
307impl Visitor for MapFunctionDefinitionVisitor<'_> {
308    fn visit_local_function(&mut self, function: &ast::LocalFunction) {
309        let identifier = range(function.name());
310
311        if let Some(id) = self.find_variable(identifier) {
312            self.definitions
313                .insert(id, ParameterCount::from_function_body(function.body()));
314        }
315    }
316
317    fn visit_function_declaration(&mut self, function: &ast::FunctionDeclaration) {
318        let identifier = range(function.name());
319
320        if let Some(reference) = self.find_reference(identifier) {
321            if let Some(variable) = reference.resolved {
322                self.verify_assignment(variable, function.body())
323            }
324        }
325    }
326
327    fn visit_local_assignment(&mut self, local_assignment: &ast::LocalAssignment) {
328        let assignment_expressions = local_assignment
329            .names()
330            .iter()
331            .zip(local_assignment.expressions());
332
333        for (name_token, expression) in assignment_expressions {
334            if let ast::Expression::Function(function_box) = expression {
335                let function_body = &function_box.1;
336                let identifier = range(name_token);
337
338                if let Some(id) = self.find_variable(identifier) {
339                    self.definitions
340                        .insert(id, ParameterCount::from_function_body(function_body));
341                }
342            }
343        }
344    }
345
346    fn visit_assignment(&mut self, assignment: &ast::Assignment) {
347        let assignment_expressions = assignment.variables().iter().zip(assignment.expressions());
348
349        for (var, expression) in assignment_expressions {
350            if let ast::Expression::Function(function_box) = expression {
351                let function_body = &function_box.1;
352                let identifier = range(var);
353
354                if let Some(reference) = self.find_reference(identifier) {
355                    if let Some(variable) = reference.resolved {
356                        self.verify_assignment(variable, function_body)
357                    }
358                }
359            }
360        }
361    }
362}
363
364struct MismatchedArgCountVisitor<'a> {
365    mismatched_arg_counts: Vec<MismatchedArgCount>,
366    scope_manager: &'a ScopeManager,
367    definitions: HashMap<Id<Variable>, ParameterCount>,
368}
369
370impl MismatchedArgCountVisitor<'_> {
371    // Split off since the formatter doesn't work inside if_chain.
372    fn get_function_definiton_ranges(&self, defined_variable: Id<Variable>) -> Vec<(usize, usize)> {
373        let variable = self.scope_manager.variables.get(defined_variable).unwrap();
374
375        variable
376            .definitions
377            .iter()
378            .copied()
379            .chain(variable.references.iter().filter_map(|reference_id| {
380                let reference = self.scope_manager.references.get(*reference_id)?;
381                if reference.write.is_some() {
382                    Some(reference.identifier)
383                } else {
384                    None
385                }
386            }))
387            .collect()
388    }
389}
390
391impl Visitor for MismatchedArgCountVisitor<'_> {
392    fn visit_function_call(&mut self, call: &ast::FunctionCall) {
393        if_chain::if_chain! {
394            // Check that we're using a named function call, with an anonymous call suffix
395            if let ast::Prefix::Name(name) = call.prefix();
396            if let Some(ast::Suffix::Call(ast::Call::AnonymousCall(args))) = call.suffixes().next();
397
398            // Find the corresponding function definition
399            let identifier = range(name);
400            if let Some((_, reference)) = self.scope_manager.references.iter().find(|reference| reference.1.identifier == identifier);
401            if let Some(defined_variable) = reference.resolved;
402            if let Some(parameter_count) = self.definitions.get(&defined_variable);
403
404            // Count the number of arguments provided
405            let num_args_provided = PassedArgumentCount::from_function_args(args);
406            if !parameter_count.correct_num_args_provided(num_args_provided);
407
408            then {
409                self.mismatched_arg_counts.push(MismatchedArgCount {
410                    num_provided: num_args_provided,
411                    parameter_count: *parameter_count,
412                    call_range: range(call),
413                    function_definition_ranges: self.get_function_definiton_ranges(defined_variable),
414                });
415            }
416        }
417    }
418}
419
#[cfg(test)]
mod tests {
    use super::{super::test_util::test_lint, *};

    /// Runs `test_lint` for this lint under the `mismatched_arg_count` lint name with the
    /// given test name.
    fn check(test_name: &'static str) {
        test_lint(
            MismatchedArgCountLint::new(()).unwrap(),
            "mismatched_arg_count",
            test_name,
        );
    }

    #[test]
    fn test_mismatched_arg_count() {
        check("mismatched_arg_count");
    }

    #[test]
    fn test_vararg_function_def() {
        check("variable_function_def");
    }

    #[test]
    fn test_call_side_effects() {
        check("call_side_effects");
    }

    #[test]
    fn test_args_alt_definition() {
        check("alternative_function_definition");
    }

    #[test]
    fn test_args_shadowing_variables() {
        check("shadowing_variables");
    }

    #[test]
    fn test_args_reassigned_variables() {
        check("reassigned_variables");
    }

    #[test]
    fn test_args_reassigned_variables_2() {
        check("reassigned_variables_2");
    }

    #[test]
    fn test_definition_location() {
        check("definition_location");
    }

    #[test]
    fn test_multiple_definition_locations() {
        check("multiple_definition_locations");
    }
}