Skip to main content

runmat_static_analysis/lints/shape.rs

1use runmat_builtins::Type;
2use runmat_hir::{
3    eval_const_num, infer_expr_type_with_env, merge_span, HirClassMember, HirDiagnostic,
4    HirDiagnosticSeverity, HirExpr, HirExprKind, HirStmt, LoweringResult, Span, VarId,
5};
6use runmat_parser as parser;
7
8pub fn lint_shapes(result: &LoweringResult) -> Vec<HirDiagnostic> {
9    fn vector_literal_length(expr: &HirExpr) -> Option<usize> {
10        let shape = tensor_literal_shape(expr)?;
11        match (
12            shape.first().copied().flatten(),
13            shape.get(1).copied().flatten(),
14        ) {
15            (Some(r), Some(c)) => {
16                if r == 1 {
17                    Some(c)
18                } else if c == 1 {
19                    Some(r)
20                } else {
21                    None
22                }
23            }
24            _ => None,
25        }
26    }
27
28    fn concat_dims(ty: &Type) -> Option<(Option<usize>, Option<usize>)> {
29        match ty {
30            Type::Num | Type::Int | Type::Bool => Some((Some(1), Some(1))),
31            Type::Tensor { shape: Some(shape) } | Type::Logical { shape: Some(shape) } => {
32                Some(runmat_builtins::shape_rules::matrix_dims(shape))
33            }
34            _ => None,
35        }
36    }
37
38    fn format_dim(dim: Option<usize>) -> String {
39        dim.map(|v| v.to_string())
40            .unwrap_or_else(|| "unknown".to_string())
41    }
42
43    fn format_shape(shape: &[Option<usize>]) -> String {
44        if shape.len() == 2 {
45            return format!("{} x {}", format_dim(shape[0]), format_dim(shape[1]));
46        }
47        let dims: Vec<String> = shape.iter().map(|d| format_dim(*d)).collect();
48        format!("[{}]", dims.join(", "))
49    }
50
51    fn matrix_dims_from_type(ty: &Type) -> Option<(Option<usize>, Option<usize>)> {
52        match ty {
53            Type::Tensor { shape: Some(shape) } | Type::Logical { shape: Some(shape) } => {
54                Some(runmat_builtins::shape_rules::matrix_dims(shape))
55            }
56            _ => None,
57        }
58    }
59
60    fn element_count(shape: &[Option<usize>]) -> Option<usize> {
61        runmat_builtins::shape_rules::element_count_if_known(shape)
62    }
63
64    fn vector_length(shape: &[Option<usize>]) -> Option<usize> {
65        let count = element_count(shape)?;
66        let is_vector = shape.len() == 1
67            || (shape.len() == 2
68                && (shape[0] == Some(1) || shape[1] == Some(1))
69                && shape.iter().all(|d| d.is_some()));
70        if is_vector {
71            Some(count)
72        } else {
73            None
74        }
75    }
76
77    fn tensor_literal_shape(expr: &HirExpr) -> Option<Vec<Option<usize>>> {
78        let HirExprKind::Tensor(rows) = &expr.kind else {
79            return None;
80        };
81        if rows.is_empty() {
82            return Some(vec![Some(0), Some(0)]);
83        }
84        let cols = rows.iter().map(|r| r.len()).max().unwrap_or(0);
85        Some(vec![Some(rows.len()), Some(cols)])
86    }
87
88    enum DimSpec {
89        Known(usize),
90        Unknown,
91        Negative,
92        NonInteger,
93    }
94
95    fn parse_dim(expr: &HirExpr) -> DimSpec {
96        if let Some(value) = eval_const_num(expr) {
97            if value.is_finite() {
98                let rounded = value.round();
99                if (value - rounded).abs() <= 1e-9 {
100                    if rounded < 0.0 {
101                        return DimSpec::Negative;
102                    }
103                    return DimSpec::Known(rounded as usize);
104                }
105                return DimSpec::NonInteger;
106            }
107        }
108        DimSpec::Unknown
109    }
110
111    fn type_shape_for_broadcast(ty: &Type) -> Option<Vec<Option<usize>>> {
112        match ty {
113            Type::Tensor { shape: Some(shape) } | Type::Logical { shape: Some(shape) } => {
114                Some(shape.clone())
115            }
116            Type::Num | Type::Int | Type::Bool => Some(vec![Some(1), Some(1)]),
117            _ => None,
118        }
119    }
120
121    fn check_binary(
122        op: &parser::BinOp,
123        lhs: &HirExpr,
124        rhs: &HirExpr,
125        env: &std::collections::HashMap<VarId, Type>,
126        returns: &std::collections::HashMap<String, Vec<Type>>,
127        diags: &mut Vec<HirDiagnostic>,
128    ) {
129        let lhs_ty = infer_expr_type_with_env(lhs, env, returns);
130        let rhs_ty = infer_expr_type_with_env(rhs, env, returns);
131        match op {
132            parser::BinOp::Mul => {
133                if let Some(false) =
134                    runmat_builtins::shape_rules::matmul_compatible(&lhs_ty, &rhs_ty)
135                {
136                    let detail = match (
137                        matrix_dims_from_type(&lhs_ty),
138                        matrix_dims_from_type(&rhs_ty),
139                    ) {
140                        (Some((lrows, lcols)), Some((rrows, rcols))) => format!(
141                            "left is {} x {}, right is {} x {} (inner dimensions {} and {})",
142                            format_dim(lrows),
143                            format_dim(lcols),
144                            format_dim(rrows),
145                            format_dim(rcols),
146                            format_dim(lcols),
147                            format_dim(rrows)
148                        ),
149                        _ => "unknown shapes".to_string(),
150                    };
151                    diags.push(HirDiagnostic {
152                        message: format!(
153                            "Matrix multiply dimension mismatch: {detail} (inner dimensions must match)"
154                        ),
155                        span: merge_span(lhs.span, rhs.span),
156                        code: "lint.shape.matmul",
157                        severity: HirDiagnosticSeverity::Warning,
158                    });
159                }
160            }
161            parser::BinOp::LeftDiv => {
162                if let Some(false) =
163                    runmat_builtins::shape_rules::left_divide_compatible(&lhs_ty, &rhs_ty)
164                {
165                    let detail = match (
166                        matrix_dims_from_type(&lhs_ty),
167                        matrix_dims_from_type(&rhs_ty),
168                    ) {
169                        (Some((lrows, _)), Some((rrows, _))) => format!(
170                            "left row dimension {}, right row dimension {}",
171                            format_dim(lrows),
172                            format_dim(rrows)
173                        ),
174                        _ => "unknown shapes".to_string(),
175                    };
176                    diags.push(HirDiagnostic {
177                        message: format!(
178                            "Left divide dimension mismatch: {detail} (row dimensions must match)"
179                        ),
180                        span: merge_span(lhs.span, rhs.span),
181                        code: "lint.shape.ldivide",
182                        severity: HirDiagnosticSeverity::Warning,
183                    });
184                }
185            }
186            parser::BinOp::Div => {
187                if let Some(false) =
188                    runmat_builtins::shape_rules::right_divide_compatible(&lhs_ty, &rhs_ty)
189                {
190                    let detail = match (
191                        matrix_dims_from_type(&lhs_ty),
192                        matrix_dims_from_type(&rhs_ty),
193                    ) {
194                        (Some((_, lcols)), Some((_, rcols))) => format!(
195                            "left column dimension {}, right column dimension {}",
196                            format_dim(lcols),
197                            format_dim(rcols)
198                        ),
199                        _ => "unknown shapes".to_string(),
200                    };
201                    diags.push(HirDiagnostic {
202                        message: format!(
203                            "Right divide dimension mismatch: {detail} (column dimensions must match)"
204                        ),
205                        span: merge_span(lhs.span, rhs.span),
206                        code: "lint.shape.rdivide",
207                        severity: HirDiagnosticSeverity::Warning,
208                    });
209                }
210            }
211            parser::BinOp::Add
212            | parser::BinOp::Sub
213            | parser::BinOp::ElemMul
214            | parser::BinOp::ElemDiv
215            | parser::BinOp::ElemPow
216            | parser::BinOp::ElemLeftDiv
217            | parser::BinOp::Equal
218            | parser::BinOp::NotEqual
219            | parser::BinOp::Less
220            | parser::BinOp::LessEqual
221            | parser::BinOp::Greater
222            | parser::BinOp::GreaterEqual => {
223                let lhs_shape = type_shape_for_broadcast(&lhs_ty);
224                let rhs_shape = type_shape_for_broadcast(&rhs_ty);
225                if let (Some(a), Some(b)) = (lhs_shape, rhs_shape) {
226                    if let Some(false) = runmat_builtins::shape_rules::broadcast_compatible(&a, &b)
227                    {
228                        let detail = format!(
229                            "left is {}, right is {}",
230                            format_shape(&a),
231                            format_shape(&b)
232                        );
233                        diags.push(HirDiagnostic {
234                            message: format!(
235                                "Elementwise/broadcast dimension mismatch: {detail} (broadcasting failed)"
236                            ),
237                            span: merge_span(lhs.span, rhs.span),
238                            code: "lint.shape.broadcast",
239                            severity: HirDiagnosticSeverity::Warning,
240                        });
241                    }
242                }
243            }
244            _ => {}
245        }
246    }
247
248    fn walk_expr(
249        expr: &HirExpr,
250        env: &std::collections::HashMap<VarId, Type>,
251        returns: &std::collections::HashMap<String, Vec<Type>>,
252        diags: &mut Vec<HirDiagnostic>,
253    ) {
254        match &expr.kind {
255            HirExprKind::Unary(_, inner) => walk_expr(inner, env, returns, diags),
256            HirExprKind::Binary(lhs, op, rhs) => {
257                check_binary(op, lhs, rhs, env, returns, diags);
258                walk_expr(lhs, env, returns, diags);
259                walk_expr(rhs, env, returns, diags);
260            }
261            HirExprKind::Tensor(rows) => {
262                let mut col_constraint: Option<usize> = None;
263                for row in rows {
264                    let mut row_dim: Option<usize> = None;
265                    let mut row_cols: Option<usize> = Some(0);
266                    let mut first_span: Option<Span> = None;
267                    for e in row {
268                        if first_span.is_none() {
269                            first_span = Some(e.span);
270                        }
271                        let ty = infer_expr_type_with_env(e, env, returns);
272                        if let Some((rows_dim, cols_dim)) = concat_dims(&ty) {
273                            if let (Some(prev), Some(curr)) = (row_dim, rows_dim) {
274                                if prev != curr {
275                                    diags.push(HirDiagnostic {
276                                        message: format!(
277                                            "Horizontal concatenation dimension mismatch: left row dimension {prev}, right row dimension {curr} (row dimensions must match)"
278                                        ),
279                                        span: merge_span(first_span.unwrap_or(e.span), e.span),
280                                        code: "lint.shape.horzcat",
281                                        severity: HirDiagnosticSeverity::Warning,
282                                    });
283                                }
284                            }
285                            if row_dim.is_none() {
286                                row_dim = rows_dim;
287                            }
288                            match (row_cols, cols_dim) {
289                                (Some(total), Some(value)) => row_cols = Some(total + value),
290                                _ => row_cols = None,
291                            }
292                        } else {
293                            row_dim = None;
294                            row_cols = None;
295                        }
296                    }
297
298                    if let (Some(prev_cols), Some(curr_cols)) = (col_constraint, row_cols) {
299                        if prev_cols != curr_cols {
300                            diags.push(HirDiagnostic {
301                                message: format!(
302                                    "Vertical concatenation dimension mismatch: upper column dimension {prev_cols}, lower column dimension {curr_cols} (column dimensions must match)"
303                                ),
304                                span: expr.span,
305                                code: "lint.shape.vertcat",
306                                severity: HirDiagnosticSeverity::Warning,
307                            });
308                        }
309                    }
310                    if col_constraint.is_none() {
311                        col_constraint = row_cols;
312                    }
313                }
314
315                for row in rows {
316                    for e in row {
317                        walk_expr(e, env, returns, diags);
318                    }
319                }
320            }
321            HirExprKind::Cell(rows) => {
322                for row in rows {
323                    for e in row {
324                        walk_expr(e, env, returns, diags);
325                    }
326                }
327            }
328            HirExprKind::Index(base, idxs) | HirExprKind::IndexCell(base, idxs) => {
329                walk_expr(base, env, returns, diags);
330                for idx in idxs {
331                    walk_expr(idx, env, returns, diags);
332                }
333                if matches!(expr.kind, HirExprKind::Index(_, _)) && idxs.len() == 1 {
334                    let base_ty = infer_expr_type_with_env(base, env, returns);
335                    let idx_ty = infer_expr_type_with_env(&idxs[0], env, returns);
336                    let base_shape = match base_ty {
337                        Type::Tensor { shape: Some(shape) }
338                        | Type::Logical { shape: Some(shape) } => Some(shape),
339                        _ => None,
340                    };
341                    let mask_shape = match idx_ty {
342                        Type::Logical { shape: Some(shape) }
343                        | Type::Tensor { shape: Some(shape) } => Some(shape),
344                        _ => None,
345                    };
346                    if let (Some(base_shape), Some(mask_shape)) = (base_shape, mask_shape) {
347                        if let (Some(base_count), Some(mask_count)) =
348                            (element_count(&base_shape), element_count(&mask_shape))
349                        {
350                            if base_count != mask_count {
351                                diags.push(HirDiagnostic {
352                                    message: format!(
353                                        "Logical index size mismatch: mask has {mask_count}, array has {base_count} (must match)"
354                                    ),
355                                    span: merge_span(base.span, idxs[0].span),
356                                    code: "lint.shape.logical_index",
357                                    severity: HirDiagnosticSeverity::Warning,
358                                });
359                            }
360                        }
361                    }
362                }
363            }
364            HirExprKind::Range(start, step, end) => {
365                walk_expr(start, env, returns, diags);
366                if let Some(step) = step.as_ref() {
367                    walk_expr(step, env, returns, diags);
368                }
369                walk_expr(end, env, returns, diags);
370            }
371            HirExprKind::FuncCall(name, args) => {
372                if name.eq_ignore_ascii_case("dot") && args.len() >= 2 {
373                    let lhs_ty = infer_expr_type_with_env(&args[0], env, returns);
374                    let rhs_ty = infer_expr_type_with_env(&args[1], env, returns);
375                    let lhs_len = match lhs_ty {
376                        Type::Tensor { shape: Some(shape) }
377                        | Type::Logical { shape: Some(shape) } => vector_length(&shape),
378                        _ => None,
379                    };
380                    let rhs_len = match rhs_ty {
381                        Type::Tensor { shape: Some(shape) }
382                        | Type::Logical { shape: Some(shape) } => vector_length(&shape),
383                        _ => None,
384                    };
385                    if let (Some(a), Some(b)) = (lhs_len, rhs_len) {
386                        if a != b {
387                            diags.push(HirDiagnostic {
388                                message: format!(
389                                    "Dot product length mismatch: left length {a}, right length {b} (lengths must match)"
390                                ),
391                                span: merge_span(args[0].span, args[1].span),
392                                code: "lint.shape.dot",
393                                severity: HirDiagnosticSeverity::Warning,
394                            });
395                        }
396                    }
397                }
398
399                if name.eq_ignore_ascii_case("reshape") && args.len() >= 2 {
400                    let input_ty = infer_expr_type_with_env(&args[0], env, returns);
401                    let input_shape = match input_ty {
402                        Type::Tensor { shape: Some(shape) }
403                        | Type::Logical { shape: Some(shape) } => Some(shape),
404                        _ => None,
405                    };
406                    let mut dims: Vec<Option<usize>> = Vec::new();
407                    let mut negative_count = 0usize;
408                    let mut non_integer = false;
409                    for arg in args.iter().skip(1) {
410                        match parse_dim(arg) {
411                            DimSpec::Known(value) => dims.push(Some(value)),
412                            DimSpec::Negative => {
413                                negative_count += 1;
414                                dims.push(None);
415                            }
416                            DimSpec::NonInteger => {
417                                non_integer = true;
418                                dims.push(None);
419                            }
420                            DimSpec::Unknown => dims.push(None),
421                        }
422                    }
423                    if negative_count > 1 {
424                        diags.push(HirDiagnostic {
425                            message:
426                                "Reshape dimension mismatch: more than one negative dimension (only one allowed)"
427                                    .to_string(),
428                            span: merge_span(args[0].span, args[1].span),
429                            code: "lint.shape.reshape",
430                            severity: HirDiagnosticSeverity::Warning,
431                        });
432                    } else if negative_count == 1 && non_integer {
433                        diags.push(HirDiagnostic {
434                            message:
435                                "Reshape dimension mismatch: negative dimensions require integer sizes"
436                                    .to_string(),
437                            span: merge_span(args[0].span, args[1].span),
438                            code: "lint.shape.reshape",
439                            severity: HirDiagnosticSeverity::Warning,
440                        });
441                    }
442                    if non_integer {
443                        diags.push(HirDiagnostic {
444                            message: "Reshape dimension mismatch: non-integer dimensions"
445                                .to_string(),
446                            span: merge_span(args[0].span, args[1].span),
447                            code: "lint.shape.reshape",
448                            severity: HirDiagnosticSeverity::Warning,
449                        });
450                    }
451                    if let Some(shape) =
452                        runmat_builtins::shape_rules::constructor_shape_from_dims(&dims)
453                    {
454                        if let Some(input_shape) = input_shape {
455                            if let (Some(in_count), Some(out_count)) =
456                                (element_count(&input_shape), element_count(&shape))
457                            {
458                                if in_count != out_count {
459                                    diags.push(HirDiagnostic {
460                                        message: format!(
461                                            "Reshape element count mismatch: input has {in_count}, output has {out_count} (must match)"
462                                        ),
463                                        span: merge_span(args[0].span, args[1].span),
464                                        code: "lint.shape.reshape",
465                                        severity: HirDiagnosticSeverity::Warning,
466                                    });
467                                }
468                            }
469                        }
470                    }
471                }
472
473                if (name.eq_ignore_ascii_case("permute") || name.eq_ignore_ascii_case("ipermute"))
474                    && args.len() >= 2
475                {
476                    let input_ty = infer_expr_type_with_env(&args[0], env, returns);
477                    let input_rank = match input_ty {
478                        Type::Tensor { shape: Some(shape) }
479                        | Type::Logical { shape: Some(shape) } => Some(shape.len()),
480                        _ => None,
481                    };
482                    let order_rank = vector_literal_length(&args[1]);
483                    if let (Some(in_rank), Some(ord_rank)) = (input_rank, order_rank) {
484                        if in_rank != ord_rank {
485                            diags.push(HirDiagnostic {
486                                message: format!(
487                                    "Permute rank mismatch: input rank {in_rank}, order length {ord_rank} (must match)"
488                                ),
489                                span: merge_span(args[0].span, args[1].span),
490                                code: "lint.shape.permute",
491                                severity: HirDiagnosticSeverity::Warning,
492                            });
493                        }
494                    }
495                    if let HirExprKind::Tensor(rows) = &args[1].kind {
496                        let mut seen: std::collections::BTreeSet<usize> =
497                            std::collections::BTreeSet::new();
498                        let mut duplicate = false;
499                        let mut max_index = 0usize;
500                        for row in rows {
501                            for entry in row {
502                                if let Some(value) = eval_const_num(entry) {
503                                    let rounded = value.round();
504                                    if (value - rounded).abs() <= 1e-9 && rounded >= 1.0 {
505                                        let idx = rounded as usize;
506                                        max_index = max_index.max(idx);
507                                        if !seen.insert(idx) {
508                                            duplicate = true;
509                                        }
510                                    }
511                                }
512                            }
513                        }
514                        if duplicate {
515                            diags.push(HirDiagnostic {
516                                message:
517                                    "Permute order mismatch: duplicate dimensions in order vector"
518                                        .to_string(),
519                                span: args[1].span,
520                                code: "lint.shape.permute",
521                                severity: HirDiagnosticSeverity::Warning,
522                            });
523                        }
524                        if let Some(in_rank) = input_rank {
525                            if max_index > in_rank {
526                                diags.push(HirDiagnostic {
527                                    message: "Permute order mismatch: order references a dimension larger than the input rank"
528                                        .to_string(),
529                                    span: args[1].span,
530                                    code: "lint.shape.permute",
531                                    severity: HirDiagnosticSeverity::Warning,
532                                });
533                            }
534                        }
535                    }
536                }
537
538                if name.eq_ignore_ascii_case("repmat") && args.len() >= 2 {
539                    let mut non_integer = false;
540                    let mut negative = false;
541                    for arg in args.iter().skip(1) {
542                        match parse_dim(arg) {
543                            DimSpec::Known(_) => {}
544                            DimSpec::NonInteger => non_integer = true,
545                            DimSpec::Negative => negative = true,
546                            _ => {}
547                        }
548                    }
549                    if non_integer || negative {
550                        let reason = if non_integer {
551                            "non-integer"
552                        } else {
553                            "negative"
554                        };
555                        diags.push(HirDiagnostic {
556                            message: format!(
557                                "Repmat dimension mismatch: {reason} replication factors"
558                            ),
559                            span: merge_span(args[0].span, args[1].span),
560                            code: "lint.shape.repmat",
561                            severity: HirDiagnosticSeverity::Warning,
562                        });
563                    }
564                }
565
566                if (name.eq_ignore_ascii_case("sum")
567                    || name.eq_ignore_ascii_case("mean")
568                    || name.eq_ignore_ascii_case("prod")
569                    || name.eq_ignore_ascii_case("min")
570                    || name.eq_ignore_ascii_case("max"))
571                    && args.len() >= 2
572                {
573                    let input_ty = infer_expr_type_with_env(&args[0], env, returns);
574                    let input_rank = match input_ty {
575                        Type::Tensor { shape: Some(shape) }
576                        | Type::Logical { shape: Some(shape) } => Some(shape.len()),
577                        _ => None,
578                    };
579                    if let Some(rank) = input_rank {
580                        if let DimSpec::Known(dim) = parse_dim(&args[1]) {
581                            if dim == 0 || dim > rank {
582                                diags.push(HirDiagnostic {
583                                    message: format!(
584                                        "Reduction dimension mismatch: dimension {dim} is out of range for rank {rank}"
585                                    ),
586                                    span: args[1].span,
587                                    code: "lint.shape.reduction",
588                                    severity: HirDiagnosticSeverity::Warning,
589                                });
590                            }
591                        }
592                    }
593                }
594
595                for arg in args {
596                    walk_expr(arg, env, returns, diags);
597                }
598            }
599            HirExprKind::MethodCall(_, _, args) => {
600                for arg in args {
601                    walk_expr(arg, env, returns, diags);
602                }
603            }
604            HirExprKind::Member(base, _) | HirExprKind::MemberDynamic(base, _) => {
605                walk_expr(base, env, returns, diags);
606            }
607            HirExprKind::AnonFunc { body, .. } => {
608                walk_expr(body, env, returns, diags);
609            }
610            _ => {}
611        }
612    }
613
614    fn walk_stmt(
615        stmt: &HirStmt,
616        env: &std::collections::HashMap<VarId, Type>,
617        returns: &std::collections::HashMap<String, Vec<Type>>,
618        func_envs: &std::collections::HashMap<String, std::collections::HashMap<VarId, Type>>,
619        diags: &mut Vec<HirDiagnostic>,
620    ) {
621        match stmt {
622            HirStmt::Assign(_, expr, _, _)
623            | HirStmt::ExprStmt(expr, _, _)
624            | HirStmt::MultiAssign(_, expr, _, _) => walk_expr(expr, env, returns, diags),
625            HirStmt::If {
626                cond,
627                then_body,
628                elseif_blocks,
629                else_body,
630                ..
631            } => {
632                walk_expr(cond, env, returns, diags);
633                for s in then_body {
634                    walk_stmt(s, env, returns, func_envs, diags);
635                }
636                for (cond, body) in elseif_blocks {
637                    walk_expr(cond, env, returns, diags);
638                    for s in body {
639                        walk_stmt(s, env, returns, func_envs, diags);
640                    }
641                }
642                if let Some(body) = else_body {
643                    for s in body {
644                        walk_stmt(s, env, returns, func_envs, diags);
645                    }
646                }
647            }
648            HirStmt::While { cond, body, .. } => {
649                walk_expr(cond, env, returns, diags);
650                for s in body {
651                    walk_stmt(s, env, returns, func_envs, diags);
652                }
653            }
654            HirStmt::For { expr, body, .. } => {
655                walk_expr(expr, env, returns, diags);
656                for s in body {
657                    walk_stmt(s, env, returns, func_envs, diags);
658                }
659            }
660            HirStmt::Switch {
661                expr,
662                cases,
663                otherwise,
664                ..
665            } => {
666                walk_expr(expr, env, returns, diags);
667                for (case_expr, case_body) in cases {
668                    walk_expr(case_expr, env, returns, diags);
669                    for s in case_body {
670                        walk_stmt(s, env, returns, func_envs, diags);
671                    }
672                }
673                if let Some(body) = otherwise {
674                    for s in body {
675                        walk_stmt(s, env, returns, func_envs, diags);
676                    }
677                }
678            }
679            HirStmt::Function { name, body, .. } => {
680                let func_env = func_envs.get(name).cloned().unwrap_or_default();
681                for s in body {
682                    walk_stmt(s, &func_env, returns, func_envs, diags);
683                }
684            }
685            HirStmt::ClassDef { members, .. } => {
686                for member in members {
687                    if let HirClassMember::Methods { body, .. } = member {
688                        for s in body {
689                            walk_stmt(s, env, returns, func_envs, diags);
690                        }
691                    }
692                }
693            }
694            _ => {}
695        }
696    }
697
698    let mut diags = Vec::new();
699    let global_env = result.inferred_globals.clone();
700    for stmt in &result.hir.body {
701        walk_stmt(
702            stmt,
703            &global_env,
704            &result.inferred_function_returns,
705            &result.inferred_function_envs,
706            &mut diags,
707        );
708    }
709    diags
710}