1use super::*;
2use crate::{
3 ast_util::{
4 is_vararg, range,
5 scopes::{Reference, ScopeManager, Variable},
6 },
7 text::plural,
8};
9use std::{
10 collections::HashMap,
11 convert::Infallible,
12 fmt::{self, Display},
13};
14
15use full_moon::{
16 ast::{self, Ast},
17 visitors::Visitor,
18};
19use id_arena::Id;
20
21pub struct MismatchedArgCountLint;
22
23impl Lint for MismatchedArgCountLint {
24 type Config = ();
25 type Error = Infallible;
26
27 const SEVERITY: Severity = Severity::Error;
28 const LINT_TYPE: LintType = LintType::Correctness;
29
30 fn new(_: Self::Config) -> Result<Self, Self::Error> {
31 Ok(MismatchedArgCountLint)
32 }
33
34 fn pass(&self, ast: &Ast, _: &Context, ast_context: &AstContext) -> Vec<Diagnostic> {
35 let mut definitions = HashMap::new();
37 let mut definitions_visitor = MapFunctionDefinitionVisitor {
38 scope_manager: &ast_context.scope_manager,
39 definitions: &mut definitions,
40 };
41 definitions_visitor.visit_ast(ast);
42
43 let mut visitor = MismatchedArgCountVisitor {
44 mismatched_arg_counts: Vec::new(),
45 scope_manager: &ast_context.scope_manager,
46 definitions,
47 };
48
49 visitor.visit_ast(ast);
50
51 visitor
52 .mismatched_arg_counts
53 .iter()
54 .map(|mismatched_arg| {
55 Diagnostic::new_complete(
56 "mismatched_arg_count",
57 mismatched_arg
58 .parameter_count
59 .to_message(mismatched_arg.num_provided),
60 Label::new_with_message(
61 mismatched_arg.call_range,
62 mismatched_arg.parameter_count.to_string(),
63 ),
64 Vec::new(),
65 mismatched_arg
66 .function_definition_ranges
67 .iter()
68 .map(|range| {
69 Label::new_with_message(
70 *range,
71 "note: function defined here".to_owned(),
72 )
73 })
74 .collect(),
75 )
76 })
77 .collect()
78 }
79}
80
81struct MismatchedArgCount {
82 parameter_count: ParameterCount,
83 num_provided: PassedArgumentCount,
84 call_range: (usize, usize),
85 function_definition_ranges: Vec<(usize, usize)>,
86}
87
/// How many arguments a function definition accepts.
#[derive(Debug, Clone, Copy)]
enum ParameterCount {
    /// Exactly this many named parameters and no trailing `...`.
    Fixed(usize),
    /// This many named parameters followed by `...`.
    Minimum(usize),
    /// `...` with no named parameters before it, or the merge of definitions
    /// where one of them was already `Variable`.
    Variable,
}
97
98impl ParameterCount {
99 fn from_function_body(function_body: &ast::FunctionBody) -> Self {
101 let mut necessary_params = 0;
102
103 for parameter in function_body.parameters() {
104 #[cfg_attr(
105 feature = "force_exhaustive_checks",
106 deny(non_exhaustive_omitted_patterns)
107 )]
108 match parameter {
109 ast::Parameter::Name(_) => necessary_params += 1,
110 ast::Parameter::Ellipsis(_) => {
111 if necessary_params == 0 {
112 return Self::Variable;
113 } else {
114 return Self::Minimum(necessary_params);
115 }
116 }
117 _ => {}
118 }
119 }
120
121 Self::Fixed(necessary_params)
122 }
123
124 fn correct_num_args_provided(self, provided: PassedArgumentCount) -> bool {
128 match self {
129 ParameterCount::Fixed(required) => match provided {
130 PassedArgumentCount::Fixed(provided) => provided <= required,
131 PassedArgumentCount::Variable(atleast_provided) => atleast_provided <= required,
134 },
135 ParameterCount::Minimum(_) => true,
138 ParameterCount::Variable => true,
140 }
141 }
142
143 fn to_message(self, provided: PassedArgumentCount) -> String {
144 match self {
145 ParameterCount::Fixed(required) => {
146 format!(
147 "this function takes {} {} but {} were supplied",
148 required,
149 plural(required, "argument", "arguments"),
150 provided
151 )
152 }
153 ParameterCount::Minimum(required) => format!(
154 "this function takes at least {} {} but {} were supplied",
155 required,
156 plural(required, "argument", "arguments"),
157 provided
158 ),
159 ParameterCount::Variable => "a variable amount of arguments".to_owned(),
160 }
161 }
162
163 fn overlap_with_other_parameter_count(self, other: ParameterCount) -> ParameterCount {
164 match (self, other) {
165 (ParameterCount::Variable, _) | (_, ParameterCount::Variable) => {
167 ParameterCount::Variable
168 }
169
170 (ParameterCount::Fixed(fixed), ParameterCount::Minimum(minimum))
174 | (ParameterCount::Minimum(minimum), ParameterCount::Fixed(fixed)) => {
175 ParameterCount::Minimum(minimum.min(fixed))
176 }
177
178 (ParameterCount::Fixed(this_fixed), ParameterCount::Fixed(other_fixed)) => {
181 if this_fixed == other_fixed {
182 ParameterCount::Fixed(this_fixed)
183 } else {
184 ParameterCount::Fixed(this_fixed.max(other_fixed))
185 }
186 }
187
188 (ParameterCount::Minimum(this_minimum), ParameterCount::Minimum(other_minimum)) => {
190 ParameterCount::Minimum(this_minimum.min(other_minimum))
191 }
192 }
193 }
194}
195
196impl Display for ParameterCount {
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 match self {
199 ParameterCount::Fixed(required) => write!(
200 f,
201 "expected {} {}",
202 required,
203 plural(*required, "argument", "arguments")
204 ),
205 ParameterCount::Minimum(required) => {
206 write!(
207 f,
208 "expected at least {} {}",
209 required,
210 plural(*required, "argument", "arguments")
211 )
212 }
213 ParameterCount::Variable => write!(f, "expected any number of arguments"),
214 }
215 }
216}
217
/// How many arguments a call site supplies.
#[derive(Debug, Clone, Copy)]
enum PassedArgumentCount {
    /// Exactly this many arguments.
    Fixed(usize),
    /// The last argument is a function call or `...`. The number includes that last
    /// expression, but how many values it actually yields is unknown.
    Variable(usize),
}
225
226impl PassedArgumentCount {
227 fn from_function_args(function_args: &ast::FunctionArgs) -> Self {
228 match function_args {
229 ast::FunctionArgs::Parentheses { arguments, .. } => {
230 let mut passed_argument_count = 0;
236
237 for argument in arguments.pairs() {
238 passed_argument_count += 1;
239
240 if let ast::punctuated::Pair::End(expression) = argument {
241 if matches!(expression, ast::Expression::FunctionCall(_))
242 || is_vararg(expression)
243 {
244 return PassedArgumentCount::Variable(passed_argument_count);
245 }
246 }
247 }
248
249 Self::Fixed(passed_argument_count)
250 }
251 ast::FunctionArgs::String(_) => Self::Fixed(1),
252 ast::FunctionArgs::TableConstructor(_) => Self::Fixed(1),
253 _ => Self::Fixed(0),
254 }
255 }
256}
257
258impl Display for PassedArgumentCount {
259 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260 match self {
261 PassedArgumentCount::Fixed(amount) => write!(f, "{amount} arguments"),
262 PassedArgumentCount::Variable(amount) => write!(f, "at least {amount} arguments"),
263 }
264 }
265}
266
267struct MapFunctionDefinitionVisitor<'a> {
269 scope_manager: &'a ScopeManager,
270 definitions: &'a mut HashMap<Id<Variable>, ParameterCount>,
271}
272
273impl MapFunctionDefinitionVisitor<'_> {
274 fn find_variable(&self, identifier: (usize, usize)) -> Option<Id<Variable>> {
275 self.scope_manager
276 .variables
277 .iter()
278 .find(|variable| variable.1.identifiers.contains(&identifier))
279 .map(|variable| variable.0)
280 }
281
282 fn find_reference(&self, identifier: (usize, usize)) -> Option<&Reference> {
283 self.scope_manager
284 .references
285 .iter()
286 .find(|reference| reference.1.identifier == identifier)
287 .map(|reference| reference.1)
288 }
289
290 fn verify_assignment(&mut self, variable: Id<Variable>, function_body: &ast::FunctionBody) {
296 let parameter_count = ParameterCount::from_function_body(function_body);
297
298 self.definitions
299 .entry(variable)
300 .and_modify(|older_count| {
301 *older_count = parameter_count.overlap_with_other_parameter_count(*older_count)
302 })
303 .or_insert(parameter_count);
304 }
305}
306
307impl Visitor for MapFunctionDefinitionVisitor<'_> {
308 fn visit_local_function(&mut self, function: &ast::LocalFunction) {
309 let identifier = range(function.name());
310
311 if let Some(id) = self.find_variable(identifier) {
312 self.definitions
313 .insert(id, ParameterCount::from_function_body(function.body()));
314 }
315 }
316
317 fn visit_function_declaration(&mut self, function: &ast::FunctionDeclaration) {
318 let identifier = range(function.name());
319
320 if let Some(reference) = self.find_reference(identifier) {
321 if let Some(variable) = reference.resolved {
322 self.verify_assignment(variable, function.body())
323 }
324 }
325 }
326
327 fn visit_local_assignment(&mut self, local_assignment: &ast::LocalAssignment) {
328 let assignment_expressions = local_assignment
329 .names()
330 .iter()
331 .zip(local_assignment.expressions());
332
333 for (name_token, expression) in assignment_expressions {
334 if let ast::Expression::Function(function_box) = expression {
335 let function_body = &function_box.1;
336 let identifier = range(name_token);
337
338 if let Some(id) = self.find_variable(identifier) {
339 self.definitions
340 .insert(id, ParameterCount::from_function_body(function_body));
341 }
342 }
343 }
344 }
345
346 fn visit_assignment(&mut self, assignment: &ast::Assignment) {
347 let assignment_expressions = assignment.variables().iter().zip(assignment.expressions());
348
349 for (var, expression) in assignment_expressions {
350 if let ast::Expression::Function(function_box) = expression {
351 let function_body = &function_box.1;
352 let identifier = range(var);
353
354 if let Some(reference) = self.find_reference(identifier) {
355 if let Some(variable) = reference.resolved {
356 self.verify_assignment(variable, function_body)
357 }
358 }
359 }
360 }
361 }
362}
363
364struct MismatchedArgCountVisitor<'a> {
365 mismatched_arg_counts: Vec<MismatchedArgCount>,
366 scope_manager: &'a ScopeManager,
367 definitions: HashMap<Id<Variable>, ParameterCount>,
368}
369
370impl MismatchedArgCountVisitor<'_> {
371 fn get_function_definiton_ranges(&self, defined_variable: Id<Variable>) -> Vec<(usize, usize)> {
373 let variable = self.scope_manager.variables.get(defined_variable).unwrap();
374
375 variable
376 .definitions
377 .iter()
378 .copied()
379 .chain(variable.references.iter().filter_map(|reference_id| {
380 let reference = self.scope_manager.references.get(*reference_id)?;
381 if reference.write.is_some() {
382 Some(reference.identifier)
383 } else {
384 None
385 }
386 }))
387 .collect()
388 }
389}
390
391impl Visitor for MismatchedArgCountVisitor<'_> {
392 fn visit_function_call(&mut self, call: &ast::FunctionCall) {
393 if_chain::if_chain! {
394 if let ast::Prefix::Name(name) = call.prefix();
396 if let Some(ast::Suffix::Call(ast::Call::AnonymousCall(args))) = call.suffixes().next();
397
398 let identifier = range(name);
400 if let Some((_, reference)) = self.scope_manager.references.iter().find(|reference| reference.1.identifier == identifier);
401 if let Some(defined_variable) = reference.resolved;
402 if let Some(parameter_count) = self.definitions.get(&defined_variable);
403
404 let num_args_provided = PassedArgumentCount::from_function_args(args);
406 if !parameter_count.correct_num_args_provided(num_args_provided);
407
408 then {
409 self.mismatched_arg_counts.push(MismatchedArgCount {
410 num_provided: num_args_provided,
411 parameter_count: *parameter_count,
412 call_range: range(call),
413 function_definition_ranges: self.get_function_definiton_ranges(defined_variable),
414 });
415 }
416 }
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::{super::test_util::test_lint, *};

    /// Runs `test_lint` for this lint with the given test name.
    fn check(test_name: &'static str) {
        test_lint(
            MismatchedArgCountLint::new(()).unwrap(),
            "mismatched_arg_count",
            test_name,
        );
    }

    #[test]
    fn test_mismatched_arg_count() {
        check("mismatched_arg_count");
    }

    #[test]
    fn test_vararg_function_def() {
        check("variable_function_def");
    }

    #[test]
    fn test_call_side_effects() {
        check("call_side_effects");
    }

    #[test]
    fn test_args_alt_definition() {
        check("alternative_function_definition");
    }

    #[test]
    fn test_args_shadowing_variables() {
        check("shadowing_variables");
    }

    #[test]
    fn test_args_reassigned_variables() {
        check("reassigned_variables");
    }

    #[test]
    fn test_args_reassigned_variables_2() {
        check("reassigned_variables_2");
    }

    #[test]
    fn test_definition_location() {
        check("definition_location");
    }

    #[test]
    fn test_multiple_definition_locations() {
        check("multiple_definition_locations");
    }
}