1use runmat_builtins::Type;
2use runmat_hir::{
3 eval_const_num, infer_expr_type_with_env, merge_span, HirClassMember, HirDiagnostic,
4 HirDiagnosticSeverity, HirExpr, HirExprKind, HirStmt, LoweringResult, Span, VarId,
5};
6use runmat_parser as parser;
7
8pub fn lint_shapes(result: &LoweringResult) -> Vec<HirDiagnostic> {
9 fn vector_literal_length(expr: &HirExpr) -> Option<usize> {
10 let shape = tensor_literal_shape(expr)?;
11 match (
12 shape.first().copied().flatten(),
13 shape.get(1).copied().flatten(),
14 ) {
15 (Some(r), Some(c)) => {
16 if r == 1 {
17 Some(c)
18 } else if c == 1 {
19 Some(r)
20 } else {
21 None
22 }
23 }
24 _ => None,
25 }
26 }
27
28 fn concat_dims(ty: &Type) -> Option<(Option<usize>, Option<usize>)> {
29 match ty {
30 Type::Num | Type::Int | Type::Bool => Some((Some(1), Some(1))),
31 Type::Tensor { shape: Some(shape) } | Type::Logical { shape: Some(shape) } => {
32 Some(runmat_builtins::shape_rules::matrix_dims(shape))
33 }
34 _ => None,
35 }
36 }
37
38 fn format_dim(dim: Option<usize>) -> String {
39 dim.map(|v| v.to_string())
40 .unwrap_or_else(|| "unknown".to_string())
41 }
42
43 fn format_shape(shape: &[Option<usize>]) -> String {
44 if shape.len() == 2 {
45 return format!("{} x {}", format_dim(shape[0]), format_dim(shape[1]));
46 }
47 let dims: Vec<String> = shape.iter().map(|d| format_dim(*d)).collect();
48 format!("[{}]", dims.join(", "))
49 }
50
51 fn matrix_dims_from_type(ty: &Type) -> Option<(Option<usize>, Option<usize>)> {
52 match ty {
53 Type::Tensor { shape: Some(shape) } | Type::Logical { shape: Some(shape) } => {
54 Some(runmat_builtins::shape_rules::matrix_dims(shape))
55 }
56 _ => None,
57 }
58 }
59
    /// Total element count of `shape` when every dimension is known
    /// (delegates to the shared shape rules; unknown dims yield `None`).
    fn element_count(shape: &[Option<usize>]) -> Option<usize> {
        runmat_builtins::shape_rules::element_count_if_known(shape)
    }
63
64 fn vector_length(shape: &[Option<usize>]) -> Option<usize> {
65 let count = element_count(shape)?;
66 let is_vector = shape.len() == 1
67 || (shape.len() == 2
68 && (shape[0] == Some(1) || shape[1] == Some(1))
69 && shape.iter().all(|d| d.is_some()));
70 if is_vector {
71 Some(count)
72 } else {
73 None
74 }
75 }
76
77 fn tensor_literal_shape(expr: &HirExpr) -> Option<Vec<Option<usize>>> {
78 let HirExprKind::Tensor(rows) = &expr.kind else {
79 return None;
80 };
81 if rows.is_empty() {
82 return Some(vec![Some(0), Some(0)]);
83 }
84 let cols = rows.iter().map(|r| r.len()).max().unwrap_or(0);
85 Some(vec![Some(rows.len()), Some(cols)])
86 }
87
    /// Classification of a dimension argument after constant evaluation.
    enum DimSpec {
        /// Constant, non-negative integer with this value.
        Known(usize),
        /// Not constant-foldable (or non-finite), so nothing can be checked.
        Unknown,
        /// Constant integer but negative.
        Negative,
        /// Constant but not (within tolerance) an integer.
        NonInteger,
    }
94
95 fn parse_dim(expr: &HirExpr) -> DimSpec {
96 if let Some(value) = eval_const_num(expr) {
97 if value.is_finite() {
98 let rounded = value.round();
99 if (value - rounded).abs() <= 1e-9 {
100 if rounded < 0.0 {
101 return DimSpec::Negative;
102 }
103 return DimSpec::Known(rounded as usize);
104 }
105 return DimSpec::NonInteger;
106 }
107 }
108 DimSpec::Unknown
109 }
110
111 fn type_shape_for_broadcast(ty: &Type) -> Option<Vec<Option<usize>>> {
112 match ty {
113 Type::Tensor { shape: Some(shape) } | Type::Logical { shape: Some(shape) } => {
114 Some(shape.clone())
115 }
116 Type::Num | Type::Int | Type::Bool => Some(vec![Some(1), Some(1)]),
117 _ => None,
118 }
119 }
120
    /// Check a single binary operation for provable shape incompatibilities.
    ///
    /// The `shape_rules` compatibility queries return `Option<bool>`; a
    /// warning is emitted only on a definite `Some(false)` so that unknown
    /// shapes never produce noise.
    fn check_binary(
        op: &parser::BinOp,
        lhs: &HirExpr,
        rhs: &HirExpr,
        env: &std::collections::HashMap<VarId, Type>,
        returns: &std::collections::HashMap<String, Vec<Type>>,
        diags: &mut Vec<HirDiagnostic>,
    ) {
        let lhs_ty = infer_expr_type_with_env(lhs, env, returns);
        let rhs_ty = infer_expr_type_with_env(rhs, env, returns);
        match op {
            // Matrix multiply: inner dimensions (lhs cols vs rhs rows) must match.
            parser::BinOp::Mul => {
                if let Some(false) =
                    runmat_builtins::shape_rules::matmul_compatible(&lhs_ty, &rhs_ty)
                {
                    let detail = match (
                        matrix_dims_from_type(&lhs_ty),
                        matrix_dims_from_type(&rhs_ty),
                    ) {
                        (Some((lrows, lcols)), Some((rrows, rcols))) => format!(
                            "left is {} x {}, right is {} x {} (inner dimensions {} and {})",
                            format_dim(lrows),
                            format_dim(lcols),
                            format_dim(rrows),
                            format_dim(rcols),
                            format_dim(lcols),
                            format_dim(rrows)
                        ),
                        _ => "unknown shapes".to_string(),
                    };
                    diags.push(HirDiagnostic {
                        message: format!(
                            "Matrix multiply dimension mismatch: {detail} (inner dimensions must match)"
                        ),
                        span: merge_span(lhs.span, rhs.span),
                        code: "lint.shape.matmul",
                        severity: HirDiagnosticSeverity::Warning,
                    });
                }
            }
            // Left divide (A \ B): the row dimensions must match.
            parser::BinOp::LeftDiv => {
                if let Some(false) =
                    runmat_builtins::shape_rules::left_divide_compatible(&lhs_ty, &rhs_ty)
                {
                    let detail = match (
                        matrix_dims_from_type(&lhs_ty),
                        matrix_dims_from_type(&rhs_ty),
                    ) {
                        (Some((lrows, _)), Some((rrows, _))) => format!(
                            "left row dimension {}, right row dimension {}",
                            format_dim(lrows),
                            format_dim(rrows)
                        ),
                        _ => "unknown shapes".to_string(),
                    };
                    diags.push(HirDiagnostic {
                        message: format!(
                            "Left divide dimension mismatch: {detail} (row dimensions must match)"
                        ),
                        span: merge_span(lhs.span, rhs.span),
                        code: "lint.shape.ldivide",
                        severity: HirDiagnosticSeverity::Warning,
                    });
                }
            }
            // Right divide (A / B): the column dimensions must match.
            parser::BinOp::Div => {
                if let Some(false) =
                    runmat_builtins::shape_rules::right_divide_compatible(&lhs_ty, &rhs_ty)
                {
                    let detail = match (
                        matrix_dims_from_type(&lhs_ty),
                        matrix_dims_from_type(&rhs_ty),
                    ) {
                        (Some((_, lcols)), Some((_, rcols))) => format!(
                            "left column dimension {}, right column dimension {}",
                            format_dim(lcols),
                            format_dim(rcols)
                        ),
                        _ => "unknown shapes".to_string(),
                    };
                    diags.push(HirDiagnostic {
                        message: format!(
                            "Right divide dimension mismatch: {detail} (column dimensions must match)"
                        ),
                        span: merge_span(lhs.span, rhs.span),
                        code: "lint.shape.rdivide",
                        severity: HirDiagnosticSeverity::Warning,
                    });
                }
            }
            // Elementwise arithmetic and comparisons: operand shapes must be
            // broadcast-compatible. Shapes are only checked when both sides
            // resolve to a concrete shape (scalars count as 1 x 1).
            parser::BinOp::Add
            | parser::BinOp::Sub
            | parser::BinOp::ElemMul
            | parser::BinOp::ElemDiv
            | parser::BinOp::ElemPow
            | parser::BinOp::ElemLeftDiv
            | parser::BinOp::Equal
            | parser::BinOp::NotEqual
            | parser::BinOp::Less
            | parser::BinOp::LessEqual
            | parser::BinOp::Greater
            | parser::BinOp::GreaterEqual => {
                let lhs_shape = type_shape_for_broadcast(&lhs_ty);
                let rhs_shape = type_shape_for_broadcast(&rhs_ty);
                if let (Some(a), Some(b)) = (lhs_shape, rhs_shape) {
                    if let Some(false) = runmat_builtins::shape_rules::broadcast_compatible(&a, &b)
                    {
                        let detail = format!(
                            "left is {}, right is {}",
                            format_shape(&a),
                            format_shape(&b)
                        );
                        diags.push(HirDiagnostic {
                            message: format!(
                                "Elementwise/broadcast dimension mismatch: {detail} (broadcasting failed)"
                            ),
                            span: merge_span(lhs.span, rhs.span),
                            code: "lint.shape.broadcast",
                            severity: HirDiagnosticSeverity::Warning,
                        });
                    }
                }
            }
            // Other operators (e.g. logical ops, Pow) are not shape-checked here.
            _ => {}
        }
    }
247
    /// Recursively lint one expression, emitting shape-mismatch warnings for
    /// binary operators, tensor-literal concatenation, single-subscript
    /// logical indexing, and several shape-sensitive builtins
    /// (dot, reshape, permute/ipermute, repmat, and dim-taking reductions).
    fn walk_expr(
        expr: &HirExpr,
        env: &std::collections::HashMap<VarId, Type>,
        returns: &std::collections::HashMap<String, Vec<Type>>,
        diags: &mut Vec<HirDiagnostic>,
    ) {
        match &expr.kind {
            HirExprKind::Unary(_, inner) => walk_expr(inner, env, returns, diags),
            HirExprKind::Binary(lhs, op, rhs) => {
                check_binary(op, lhs, rhs, env, returns, diags);
                walk_expr(lhs, env, returns, diags);
                walk_expr(rhs, env, returns, diags);
            }
            HirExprKind::Tensor(rows) => {
                // Concatenation checks: within one row every element must have
                // the same row count (horzcat); across rows the total column
                // count must agree (vertcat). Any element with an unknown
                // shape poisons the row's counts so no warning is emitted.
                let mut col_constraint: Option<usize> = None;
                for row in rows {
                    // row_dim: row count shared by the row's elements so far.
                    let mut row_dim: Option<usize> = None;
                    // row_cols: running total of column widths in this row.
                    let mut row_cols: Option<usize> = Some(0);
                    let mut first_span: Option<Span> = None;
                    for e in row {
                        if first_span.is_none() {
                            first_span = Some(e.span);
                        }
                        let ty = infer_expr_type_with_env(e, env, returns);
                        if let Some((rows_dim, cols_dim)) = concat_dims(&ty) {
                            if let (Some(prev), Some(curr)) = (row_dim, rows_dim) {
                                if prev != curr {
                                    diags.push(HirDiagnostic {
                                        message: format!(
                                            "Horizontal concatenation dimension mismatch: left row dimension {prev}, right row dimension {curr} (row dimensions must match)"
                                        ),
                                        span: merge_span(first_span.unwrap_or(e.span), e.span),
                                        code: "lint.shape.horzcat",
                                        severity: HirDiagnosticSeverity::Warning,
                                    });
                                }
                            }
                            // Adopt the first known row count as the constraint.
                            if row_dim.is_none() {
                                row_dim = rows_dim;
                            }
                            // Accumulate column widths; an unknown width makes
                            // the row total unknown.
                            match (row_cols, cols_dim) {
                                (Some(total), Some(value)) => row_cols = Some(total + value),
                                _ => row_cols = None,
                            }
                        } else {
                            // Unknown element type: abandon checks for this row.
                            row_dim = None;
                            row_cols = None;
                        }
                    }

                    if let (Some(prev_cols), Some(curr_cols)) = (col_constraint, row_cols) {
                        if prev_cols != curr_cols {
                            diags.push(HirDiagnostic {
                                message: format!(
                                    "Vertical concatenation dimension mismatch: upper column dimension {prev_cols}, lower column dimension {curr_cols} (column dimensions must match)"
                                ),
                                span: expr.span,
                                code: "lint.shape.vertcat",
                                severity: HirDiagnosticSeverity::Warning,
                            });
                        }
                    }
                    // First row with a known width fixes the vertcat constraint.
                    if col_constraint.is_none() {
                        col_constraint = row_cols;
                    }
                }

                // Recurse into the elements after the concat checks.
                for row in rows {
                    for e in row {
                        walk_expr(e, env, returns, diags);
                    }
                }
            }
            HirExprKind::Cell(rows) => {
                for row in rows {
                    for e in row {
                        walk_expr(e, env, returns, diags);
                    }
                }
            }
            HirExprKind::Index(base, idxs) | HirExprKind::IndexCell(base, idxs) => {
                walk_expr(base, env, returns, diags);
                for idx in idxs {
                    walk_expr(idx, env, returns, diags);
                }
                // Single-subscript tensor indexing: treat the subscript as a
                // mask and require its element count to equal the base's.
                // NOTE(review): the mask arm also matches plain Tensor-typed
                // indices, so a known-shape *numeric* index vector shorter
                // than the base would warn here too — confirm that numeric
                // linear indexing cannot reach this check.
                if matches!(expr.kind, HirExprKind::Index(_, _)) && idxs.len() == 1 {
                    let base_ty = infer_expr_type_with_env(base, env, returns);
                    let idx_ty = infer_expr_type_with_env(&idxs[0], env, returns);
                    let base_shape = match base_ty {
                        Type::Tensor { shape: Some(shape) }
                        | Type::Logical { shape: Some(shape) } => Some(shape),
                        _ => None,
                    };
                    let mask_shape = match idx_ty {
                        Type::Logical { shape: Some(shape) }
                        | Type::Tensor { shape: Some(shape) } => Some(shape),
                        _ => None,
                    };
                    if let (Some(base_shape), Some(mask_shape)) = (base_shape, mask_shape) {
                        if let (Some(base_count), Some(mask_count)) =
                            (element_count(&base_shape), element_count(&mask_shape))
                        {
                            if base_count != mask_count {
                                diags.push(HirDiagnostic {
                                    message: format!(
                                        "Logical index size mismatch: mask has {mask_count}, array has {base_count} (must match)"
                                    ),
                                    span: merge_span(base.span, idxs[0].span),
                                    code: "lint.shape.logical_index",
                                    severity: HirDiagnosticSeverity::Warning,
                                });
                            }
                        }
                    }
                }
            }
            HirExprKind::Range(start, step, end) => {
                walk_expr(start, env, returns, diags);
                if let Some(step) = step.as_ref() {
                    walk_expr(step, env, returns, diags);
                }
                walk_expr(end, env, returns, diags);
            }
            HirExprKind::FuncCall(name, args) => {
                // dot(a, b): both arguments, when vector-shaped with known
                // counts, must have equal length.
                if name.eq_ignore_ascii_case("dot") && args.len() >= 2 {
                    let lhs_ty = infer_expr_type_with_env(&args[0], env, returns);
                    let rhs_ty = infer_expr_type_with_env(&args[1], env, returns);
                    let lhs_len = match lhs_ty {
                        Type::Tensor { shape: Some(shape) }
                        | Type::Logical { shape: Some(shape) } => vector_length(&shape),
                        _ => None,
                    };
                    let rhs_len = match rhs_ty {
                        Type::Tensor { shape: Some(shape) }
                        | Type::Logical { shape: Some(shape) } => vector_length(&shape),
                        _ => None,
                    };
                    if let (Some(a), Some(b)) = (lhs_len, rhs_len) {
                        if a != b {
                            diags.push(HirDiagnostic {
                                message: format!(
                                    "Dot product length mismatch: left length {a}, right length {b} (lengths must match)"
                                ),
                                span: merge_span(args[0].span, args[1].span),
                                code: "lint.shape.dot",
                                severity: HirDiagnosticSeverity::Warning,
                            });
                        }
                    }
                }

                // reshape(x, d1, d2, ...): classify each dimension argument,
                // reject multiple negative dims and non-integer dims, then
                // compare element counts when both shapes are fully known.
                if name.eq_ignore_ascii_case("reshape") && args.len() >= 2 {
                    let input_ty = infer_expr_type_with_env(&args[0], env, returns);
                    let input_shape = match input_ty {
                        Type::Tensor { shape: Some(shape) }
                        | Type::Logical { shape: Some(shape) } => Some(shape),
                        _ => None,
                    };
                    let mut dims: Vec<Option<usize>> = Vec::new();
                    let mut negative_count = 0usize;
                    let mut non_integer = false;
                    for arg in args.iter().skip(1) {
                        match parse_dim(arg) {
                            DimSpec::Known(value) => dims.push(Some(value)),
                            DimSpec::Negative => {
                                negative_count += 1;
                                dims.push(None);
                            }
                            DimSpec::NonInteger => {
                                non_integer = true;
                                dims.push(None);
                            }
                            DimSpec::Unknown => dims.push(None),
                        }
                    }
                    if negative_count > 1 {
                        diags.push(HirDiagnostic {
                            message:
                                "Reshape dimension mismatch: more than one negative dimension (only one allowed)"
                                    .to_string(),
                            span: merge_span(args[0].span, args[1].span),
                            code: "lint.shape.reshape",
                            severity: HirDiagnosticSeverity::Warning,
                        });
                    } else if negative_count == 1 && non_integer {
                        diags.push(HirDiagnostic {
                            message:
                                "Reshape dimension mismatch: negative dimensions require integer sizes"
                                    .to_string(),
                            span: merge_span(args[0].span, args[1].span),
                            code: "lint.shape.reshape",
                            severity: HirDiagnosticSeverity::Warning,
                        });
                    }
                    // NOTE(review): when negative_count == 1 && non_integer,
                    // this fires in addition to the branch above, producing two
                    // diagnostics for the same arguments — confirm intended.
                    if non_integer {
                        diags.push(HirDiagnostic {
                            message: "Reshape dimension mismatch: non-integer dimensions"
                                .to_string(),
                            span: merge_span(args[0].span, args[1].span),
                            code: "lint.shape.reshape",
                            severity: HirDiagnosticSeverity::Warning,
                        });
                    }
                    if let Some(shape) =
                        runmat_builtins::shape_rules::constructor_shape_from_dims(&dims)
                    {
                        if let Some(input_shape) = input_shape {
                            if let (Some(in_count), Some(out_count)) =
                                (element_count(&input_shape), element_count(&shape))
                            {
                                if in_count != out_count {
                                    diags.push(HirDiagnostic {
                                        message: format!(
                                            "Reshape element count mismatch: input has {in_count}, output has {out_count} (must match)"
                                        ),
                                        span: merge_span(args[0].span, args[1].span),
                                        code: "lint.shape.reshape",
                                        severity: HirDiagnosticSeverity::Warning,
                                    });
                                }
                            }
                        }
                    }
                }

                // permute/ipermute(x, order): the order vector's length must
                // equal the input rank, and its (constant) entries must be a
                // set of 1-based dims without duplicates or out-of-range
                // values.
                if (name.eq_ignore_ascii_case("permute") || name.eq_ignore_ascii_case("ipermute"))
                    && args.len() >= 2
                {
                    let input_ty = infer_expr_type_with_env(&args[0], env, returns);
                    let input_rank = match input_ty {
                        Type::Tensor { shape: Some(shape) }
                        | Type::Logical { shape: Some(shape) } => Some(shape.len()),
                        _ => None,
                    };
                    let order_rank = vector_literal_length(&args[1]);
                    if let (Some(in_rank), Some(ord_rank)) = (input_rank, order_rank) {
                        if in_rank != ord_rank {
                            diags.push(HirDiagnostic {
                                message: format!(
                                    "Permute rank mismatch: input rank {in_rank}, order length {ord_rank} (must match)"
                                ),
                                span: merge_span(args[0].span, args[1].span),
                                code: "lint.shape.permute",
                                severity: HirDiagnosticSeverity::Warning,
                            });
                        }
                    }
                    // Inspect a literal order vector element by element; only
                    // const-foldable positive integers participate.
                    if let HirExprKind::Tensor(rows) = &args[1].kind {
                        let mut seen: std::collections::BTreeSet<usize> =
                            std::collections::BTreeSet::new();
                        let mut duplicate = false;
                        let mut max_index = 0usize;
                        for row in rows {
                            for entry in row {
                                if let Some(value) = eval_const_num(entry) {
                                    let rounded = value.round();
                                    if (value - rounded).abs() <= 1e-9 && rounded >= 1.0 {
                                        let idx = rounded as usize;
                                        max_index = max_index.max(idx);
                                        if !seen.insert(idx) {
                                            duplicate = true;
                                        }
                                    }
                                }
                            }
                        }
                        if duplicate {
                            diags.push(HirDiagnostic {
                                message:
                                    "Permute order mismatch: duplicate dimensions in order vector"
                                        .to_string(),
                                span: args[1].span,
                                code: "lint.shape.permute",
                                severity: HirDiagnosticSeverity::Warning,
                            });
                        }
                        if let Some(in_rank) = input_rank {
                            if max_index > in_rank {
                                diags.push(HirDiagnostic {
                                    message: "Permute order mismatch: order references a dimension larger than the input rank"
                                        .to_string(),
                                    span: args[1].span,
                                    code: "lint.shape.permute",
                                    severity: HirDiagnosticSeverity::Warning,
                                });
                            }
                        }
                    }
                }

                // repmat(x, m, n, ...): replication factors must be
                // non-negative integers when constant-foldable.
                if name.eq_ignore_ascii_case("repmat") && args.len() >= 2 {
                    let mut non_integer = false;
                    let mut negative = false;
                    for arg in args.iter().skip(1) {
                        match parse_dim(arg) {
                            DimSpec::Known(_) => {}
                            DimSpec::NonInteger => non_integer = true,
                            DimSpec::Negative => negative = true,
                            _ => {}
                        }
                    }
                    if non_integer || negative {
                        // Non-integer takes precedence when both are present.
                        let reason = if non_integer {
                            "non-integer"
                        } else {
                            "negative"
                        };
                        diags.push(HirDiagnostic {
                            message: format!(
                                "Repmat dimension mismatch: {reason} replication factors"
                            ),
                            span: merge_span(args[0].span, args[1].span),
                            code: "lint.shape.repmat",
                            severity: HirDiagnosticSeverity::Warning,
                        });
                    }
                }

                // Reductions with an explicit dim (sum/mean/prod/min/max):
                // the 1-based dimension must be within the input rank.
                if (name.eq_ignore_ascii_case("sum")
                    || name.eq_ignore_ascii_case("mean")
                    || name.eq_ignore_ascii_case("prod")
                    || name.eq_ignore_ascii_case("min")
                    || name.eq_ignore_ascii_case("max"))
                    && args.len() >= 2
                {
                    let input_ty = infer_expr_type_with_env(&args[0], env, returns);
                    let input_rank = match input_ty {
                        Type::Tensor { shape: Some(shape) }
                        | Type::Logical { shape: Some(shape) } => Some(shape.len()),
                        _ => None,
                    };
                    if let Some(rank) = input_rank {
                        if let DimSpec::Known(dim) = parse_dim(&args[1]) {
                            if dim == 0 || dim > rank {
                                diags.push(HirDiagnostic {
                                    message: format!(
                                        "Reduction dimension mismatch: dimension {dim} is out of range for rank {rank}"
                                    ),
                                    span: args[1].span,
                                    code: "lint.shape.reduction",
                                    severity: HirDiagnosticSeverity::Warning,
                                });
                            }
                        }
                    }
                }

                // Always recurse into all arguments, regardless of builtin.
                for arg in args {
                    walk_expr(arg, env, returns, diags);
                }
            }
            HirExprKind::MethodCall(_, _, args) => {
                for arg in args {
                    walk_expr(arg, env, returns, diags);
                }
            }
            HirExprKind::Member(base, _) | HirExprKind::MemberDynamic(base, _) => {
                walk_expr(base, env, returns, diags);
            }
            HirExprKind::AnonFunc { body, .. } => {
                // NOTE(review): the anon-func body is walked with the caller's
                // env, so captured/parameter shapes are not modeled here.
                walk_expr(body, env, returns, diags);
            }
            // Literals, variables, and other leaf expressions: nothing to lint.
            _ => {}
        }
    }
613
    /// Recursively lint one statement, routing every contained expression to
    /// `walk_expr`. Function bodies are walked with that function's own
    /// inferred variable environment (`func_envs`); class method bodies reuse
    /// the caller's environment.
    fn walk_stmt(
        stmt: &HirStmt,
        env: &std::collections::HashMap<VarId, Type>,
        returns: &std::collections::HashMap<String, Vec<Type>>,
        func_envs: &std::collections::HashMap<String, std::collections::HashMap<VarId, Type>>,
        diags: &mut Vec<HirDiagnostic>,
    ) {
        match stmt {
            HirStmt::Assign(_, expr, _, _)
            | HirStmt::ExprStmt(expr, _, _)
            | HirStmt::MultiAssign(_, expr, _, _) => walk_expr(expr, env, returns, diags),
            HirStmt::If {
                cond,
                then_body,
                elseif_blocks,
                else_body,
                ..
            } => {
                walk_expr(cond, env, returns, diags);
                for s in then_body {
                    walk_stmt(s, env, returns, func_envs, diags);
                }
                for (cond, body) in elseif_blocks {
                    walk_expr(cond, env, returns, diags);
                    for s in body {
                        walk_stmt(s, env, returns, func_envs, diags);
                    }
                }
                if let Some(body) = else_body {
                    for s in body {
                        walk_stmt(s, env, returns, func_envs, diags);
                    }
                }
            }
            HirStmt::While { cond, body, .. } => {
                walk_expr(cond, env, returns, diags);
                for s in body {
                    walk_stmt(s, env, returns, func_envs, diags);
                }
            }
            HirStmt::For { expr, body, .. } => {
                walk_expr(expr, env, returns, diags);
                for s in body {
                    walk_stmt(s, env, returns, func_envs, diags);
                }
            }
            HirStmt::Switch {
                expr,
                cases,
                otherwise,
                ..
            } => {
                walk_expr(expr, env, returns, diags);
                for (case_expr, case_body) in cases {
                    walk_expr(case_expr, env, returns, diags);
                    for s in case_body {
                        walk_stmt(s, env, returns, func_envs, diags);
                    }
                }
                if let Some(body) = otherwise {
                    for s in body {
                        walk_stmt(s, env, returns, func_envs, diags);
                    }
                }
            }
            HirStmt::Function { name, body, .. } => {
                // Switch to this function's inferred environment; unknown
                // functions fall back to an empty env (no shape info).
                let func_env = func_envs.get(name).cloned().unwrap_or_default();
                for s in body {
                    walk_stmt(s, &func_env, returns, func_envs, diags);
                }
            }
            HirStmt::ClassDef { members, .. } => {
                for member in members {
                    if let HirClassMember::Methods { body, .. } = member {
                        for s in body {
                            walk_stmt(s, env, returns, func_envs, diags);
                        }
                    }
                }
            }
            // break/continue/return/global/persistent etc.: nothing to lint.
            _ => {}
        }
    }
697
    // Drive the lint: walk every top-level statement with the inferred global
    // environment; nested function bodies switch to their own envs inside
    // walk_stmt. Collected warnings are returned in traversal order.
    let mut diags = Vec::new();
    let global_env = result.inferred_globals.clone();
    for stmt in &result.hir.body {
        walk_stmt(
            stmt,
            &global_env,
            &result.inferred_function_returns,
            &result.inferred_function_envs,
            &mut diags,
        );
    }
    diags
710}