1use runmat_builtins::{BuiltinSemanticKind, ConcatKind, ShapeTransformKind};
2use runmat_hir::{BindingId, HirDiagnostic, HirDiagnosticSeverity, OperatorKind, Span};
3use runmat_mir::analysis::{AnalysisStore, MirLocalKey};
4use std::collections::HashMap;
5
6pub fn lint_shapes(result: &runmat_hir::LoweringResult) -> Vec<HirDiagnostic> {
7 let mir = match runmat_mir::lowering::lower_assembly(&result.assembly) {
8 Ok(mir) => mir,
9 Err(err) => return vec![mir_lowering_diagnostic(err)],
10 };
11 let store = runmat_mir::analysis::analyze_assembly(&mir);
12 let mut ctx = ShapeLintContext::default();
13 ctx.seed_from_analysis(&mir, &store);
14 ctx.walk_mir_assembly(&mir);
15 ctx.diagnostics
16}
17
18pub fn infer_binding_shapes(
19 result: &runmat_hir::LoweringResult,
20) -> HashMap<BindingId, Vec<Option<usize>>> {
21 let Ok(mir) = runmat_mir::lowering::lower_assembly(&result.assembly) else {
22 return HashMap::new();
23 };
24 let store = runmat_mir::analysis::analyze_assembly(&mir);
25 let mut ctx = ShapeLintContext::default();
26 ctx.seed_from_analysis(&mir, &store);
27 ctx.walk_mir_assembly(&mir);
28 ctx.env
29 .into_iter()
30 .map(|(binding, shape)| (binding, shape.0))
31 .collect()
32}
33
34fn mir_lowering_diagnostic(err: runmat_hir::HirError) -> HirDiagnostic {
35 HirDiagnostic::new(
36 "lint.mir.lowering_failed",
37 HirDiagnosticSeverity::Error,
38 format!("MIR lowering failed: {}", err.message),
39 err.span.unwrap_or(runmat_hir::Span { start: 0, end: 0 }),
40 )
41 .with_category("mir-lowering")
42}
43
/// A statically inferred array shape.
///
/// Holds one entry per dimension; `None` marks an extent that could not be
/// determined.
#[derive(Clone, Debug, PartialEq)]
struct Shape(Vec<Option<usize>>);
46
47#[derive(Default)]
48struct ShapeLintContext {
49 env: HashMap<BindingId, Shape>,
50 local_env: HashMap<MirLocalKey, Shape>,
51 number_env: HashMap<MirLocalKey, f64>,
52 int_vector_env: HashMap<MirLocalKey, Vec<usize>>,
53 diagnostics: Vec<HirDiagnostic>,
54}
55
56#[derive(Default)]
57struct MirShapeValue {
58 shape: Option<Shape>,
59 number: Option<f64>,
60 int_vector: Option<Vec<usize>>,
61}
62
63impl ShapeLintContext {
64 fn seed_from_analysis(&mut self, mir: &runmat_mir::MirAssembly, store: &AnalysisStore) {
65 for body in mir.bodies.values() {
66 for local in &body.locals {
67 let Some(binding) = local.binding else {
68 continue;
69 };
70 let Some(fact) = store.mir_locals.get(&MirLocalKey {
71 function: body.function,
72 local: local.id,
73 }) else {
74 continue;
75 };
76 if let Some(shape) = shape_from_fact(&fact.shape) {
77 self.env.insert(binding, shape);
78 }
79 }
80 }
81 }
82
83 fn walk_mir_assembly(&mut self, mir: &runmat_mir::MirAssembly) {
84 for body in mir.bodies.values() {
85 for block in &body.blocks {
86 for stmt in &block.statements {
87 match &stmt.kind {
88 runmat_mir::MirStmtKind::Assign { place, value } => {
89 let value = self.infer_mir_rvalue(body, value, stmt.span);
90 if let runmat_mir::MirPlace::Local(local) = place {
91 self.record_mir_value(body, *local, value);
92 }
93 }
94 runmat_mir::MirStmtKind::MultiAssign { value, .. }
95 | runmat_mir::MirStmtKind::Expr(value) => {
96 self.infer_mir_rvalue(body, value, stmt.span);
97 }
98 runmat_mir::MirStmtKind::PlaceMutation(_)
99 | runmat_mir::MirStmtKind::WorkspaceEffect { .. }
100 | runmat_mir::MirStmtKind::EnvironmentEffect(_) => {}
101 }
102 }
103 }
104 }
105 }
106
107 fn record_mir_value(
108 &mut self,
109 body: &runmat_mir::MirBody,
110 local: runmat_mir::MirLocalId,
111 value: MirShapeValue,
112 ) {
113 let key = MirLocalKey {
114 function: body.function,
115 local,
116 };
117 if let Some(shape) = value.shape {
118 if let Some(binding) = body.locals.get(local.0).and_then(|local| local.binding) {
119 self.env.insert(binding, shape.clone());
120 }
121 self.local_env.insert(key, shape);
122 }
123 if let Some(number) = value.number {
124 self.number_env.insert(key, number);
125 }
126 if let Some(vector) = value.int_vector {
127 self.int_vector_env.insert(key, vector);
128 }
129 }
130
131 fn infer_mir_rvalue(
132 &mut self,
133 body: &runmat_mir::MirBody,
134 value: &runmat_mir::MirRvalue,
135 span: Span,
136 ) -> MirShapeValue {
137 match value {
138 runmat_mir::MirRvalue::Use(operand) => self.infer_mir_operand(body, operand),
139 runmat_mir::MirRvalue::Unary(op, operand) => {
140 let inner = self.infer_mir_operand(body, operand);
141 let number = match (op, inner.number) {
142 (OperatorKind::UnaryMinus, Some(value)) => Some(-value),
143 _ => None,
144 };
145 MirShapeValue {
146 shape: inner.shape,
147 number,
148 int_vector: None,
149 }
150 }
151 runmat_mir::MirRvalue::Binary(left, op, right) => {
152 let lhs = self.infer_mir_operand(body, left);
153 let rhs = self.infer_mir_operand(body, right);
154 MirShapeValue {
155 shape: self.infer_mir_binary(span, lhs.shape.as_ref(), op, rhs.shape.as_ref()),
156 number: None,
157 int_vector: None,
158 }
159 }
160 runmat_mir::MirRvalue::ShortCircuit {
161 left,
162 right_temps,
163 right,
164 ..
165 } => {
166 self.infer_mir_operand(body, left);
167 for stmt in right_temps {
168 match &stmt.kind {
169 runmat_mir::MirStmtKind::Assign { place, value } => {
170 let inferred = self.infer_mir_rvalue(body, value, stmt.span);
171 if let runmat_mir::MirPlace::Local(local) = place {
172 self.record_mir_value(body, *local, inferred);
173 }
174 }
175 runmat_mir::MirStmtKind::MultiAssign { value, .. }
176 | runmat_mir::MirStmtKind::Expr(value) => {
177 self.infer_mir_rvalue(body, value, stmt.span);
178 }
179 runmat_mir::MirStmtKind::PlaceMutation(_)
180 | runmat_mir::MirStmtKind::WorkspaceEffect { .. }
181 | runmat_mir::MirStmtKind::EnvironmentEffect(_) => {}
182 }
183 }
184 self.infer_mir_operand(body, right);
185 MirShapeValue::default()
186 }
187 runmat_mir::MirRvalue::Range { start, step, end } => {
188 let start = self.infer_mir_operand(body, start).number;
189 let step = step
190 .as_ref()
191 .and_then(|step| self.infer_mir_operand(body, step).number)
192 .unwrap_or(1.0);
193 let end = self.infer_mir_operand(body, end).number;
194 let width = start
195 .zip(end)
196 .and_then(|(start, end)| range_width(start, step, end));
197 MirShapeValue {
198 shape: Some(Shape(vec![Some(1), width])),
199 number: None,
200 int_vector: None,
201 }
202 }
203 runmat_mir::MirRvalue::Call(call) => self.infer_mir_call(body, span, call),
204 runmat_mir::MirRvalue::Aggregate {
205 kind,
206 rows,
207 elements,
208 ..
209 } => self.infer_mir_aggregate(body, span, kind, *rows, elements),
210 runmat_mir::MirRvalue::StructLiteral { fields } => {
211 for (_, value) in fields {
212 self.infer_mir_operand(body, value);
213 }
214 MirShapeValue {
215 shape: Some(Shape(vec![Some(1), Some(1)])),
216 number: None,
217 int_vector: None,
218 }
219 }
220 runmat_mir::MirRvalue::ObjectLiteral { fields, .. } => {
221 for (_, value) in fields {
222 self.infer_mir_operand(body, value);
223 }
224 MirShapeValue {
225 shape: Some(Shape(vec![Some(1), Some(1)])),
226 number: None,
227 int_vector: None,
228 }
229 }
230 runmat_mir::MirRvalue::Index { base, indexing } => {
231 let base_shape = self.infer_mir_operand(body, base).shape;
232 for component in &indexing.components {
233 match component {
234 runmat_mir::MirIndexComponent::Expr(operand) => {
235 let idx_shape = self.infer_mir_operand(body, operand).shape;
236 if indexing.components.len() == 1 {
237 self.check_logical_index(
238 span,
239 base_shape.as_ref(),
240 idx_shape.as_ref(),
241 );
242 }
243 }
244 runmat_mir::MirIndexComponent::Colon
245 | runmat_mir::MirIndexComponent::End { .. } => {}
246 }
247 }
248 MirShapeValue::default()
249 }
250 runmat_mir::MirRvalue::Member { base, .. } => self.infer_mir_operand(body, base),
251 runmat_mir::MirRvalue::DynamicMember { base, member } => {
252 self.infer_mir_operand(body, member);
253 self.infer_mir_operand(body, base)
254 }
255 runmat_mir::MirRvalue::Future { args, .. } => {
256 for arg in args {
257 self.infer_mir_operand(body, arg.operand());
258 }
259 MirShapeValue::default()
260 }
261 runmat_mir::MirRvalue::Spawn(operand) => self.infer_mir_operand(body, operand),
262 runmat_mir::MirRvalue::WorkspaceFirstStaticProperty { .. } => MirShapeValue::default(),
263 runmat_mir::MirRvalue::MetaClass(_)
264 | runmat_mir::MirRvalue::Colon
265 | runmat_mir::MirRvalue::End => MirShapeValue::default(),
266 }
267 }
268
269 fn infer_mir_operand(
270 &self,
271 body: &runmat_mir::MirBody,
272 operand: &runmat_mir::MirOperand,
273 ) -> MirShapeValue {
274 match operand {
275 runmat_mir::MirOperand::Constant(runmat_mir::MirConstant::Number(value)) => {
276 MirShapeValue {
277 shape: Some(Shape(vec![Some(1), Some(1)])),
278 number: value.parse().ok(),
279 int_vector: None,
280 }
281 }
282 runmat_mir::MirOperand::Local(local) => {
283 let key = MirLocalKey {
284 function: body.function,
285 local: *local,
286 };
287 MirShapeValue {
288 shape: self.local_env.get(&key).cloned(),
289 number: self.number_env.get(&key).copied(),
290 int_vector: self.int_vector_env.get(&key).cloned(),
291 }
292 }
293 runmat_mir::MirOperand::Constant(_) | runmat_mir::MirOperand::FunctionHandle(_) => {
294 MirShapeValue::default()
295 }
296 }
297 }
298
299 fn infer_mir_binary(
300 &mut self,
301 span: Span,
302 lhs: Option<&Shape>,
303 op: &OperatorKind,
304 rhs: Option<&Shape>,
305 ) -> Option<Shape> {
306 match op {
307 OperatorKind::MatrixMultiply => {
308 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
309 if matrix_dims(lhs)
310 .zip(matrix_dims(rhs))
311 .is_some_and(|((_, lc), (rr, _))| lc.is_some() && rr.is_some() && lc != rr)
312 {
313 self.warn(
314 "lint.shape.matmul",
315 "matrix multiply inner dimensions must match",
316 span,
317 );
318 }
319 }
320 match (lhs.and_then(matrix_dims), rhs.and_then(matrix_dims)) {
321 (Some((rows, _)), Some((_, cols))) => Some(Shape(vec![rows, cols])),
322 _ => None,
323 }
324 }
325 OperatorKind::Add
326 | OperatorKind::Subtract
327 | OperatorKind::ElementwiseMultiply
328 | OperatorKind::ElementwiseDivide
329 | OperatorKind::ElementwiseLeftDivide
330 | OperatorKind::ElementwisePower
331 | OperatorKind::Greater
332 | OperatorKind::GreaterEqual
333 | OperatorKind::Less
334 | OperatorKind::LessEqual
335 | OperatorKind::Equal
336 | OperatorKind::NotEqual => {
337 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
338 if !broadcast_compatible(lhs, rhs) {
339 self.warn(
340 "lint.shape.broadcast",
341 "array dimensions are not broadcast compatible",
342 span,
343 );
344 }
345 }
346 lhs.cloned().or_else(|| rhs.cloned())
347 }
348 _ => lhs.cloned().or_else(|| rhs.cloned()),
349 }
350 }
351
352 fn infer_mir_aggregate(
353 &mut self,
354 body: &runmat_mir::MirBody,
355 span: Span,
356 kind: &runmat_mir::MirAggregateKind,
357 rows: usize,
358 elements: &[runmat_mir::MirOperand],
359 ) -> MirShapeValue {
360 let values: Vec<_> = elements
361 .iter()
362 .map(|element| self.infer_mir_operand(body, element))
363 .collect();
364 let int_vector = values
365 .iter()
366 .map(|value| value.number.and_then(number_to_int))
367 .collect::<Option<Vec<_>>>();
368 let shape = match kind {
369 runmat_mir::MirAggregateKind::Tensor => {
370 let row_count = rows.max(1);
371 let cols_per_row = if row_count == 0 {
372 0
373 } else {
374 elements.len() / row_count
375 };
376 let mut row_dims = Vec::new();
377 for row_idx in 0..row_count {
378 let start = row_idx * cols_per_row;
379 let end = start + cols_per_row;
380 let mut total_cols = 0usize;
381 let mut expected_rows = None;
382 for value in &values[start..end] {
383 if let Some((rows, cols)) = value.shape.as_ref().and_then(matrix_dims) {
384 if let (Some(expected), Some(rows)) = (expected_rows, rows) {
385 if expected != rows {
386 self.warn(
387 "lint.shape.horzcat",
388 "horizontal concatenation row dimensions do not agree",
389 span,
390 );
391 }
392 }
393 expected_rows = expected_rows.or(rows);
394 total_cols += cols.unwrap_or(1);
395 } else {
396 total_cols += 1;
397 }
398 }
399 row_dims.push((expected_rows.unwrap_or(1), total_cols));
400 }
401 if let Some((_, first_cols)) = row_dims.first().copied() {
402 for (_, cols) in &row_dims {
403 if *cols != first_cols {
404 self.warn(
405 "lint.shape.vertcat",
406 "vertical concatenation column dimensions do not agree",
407 span,
408 );
409 }
410 }
411 Some(Shape(vec![
412 Some(row_dims.iter().map(|(rows, _)| rows).sum()),
413 Some(first_cols),
414 ]))
415 } else {
416 Some(Shape(vec![Some(0), Some(0)]))
417 }
418 }
419 runmat_mir::MirAggregateKind::Cell => Some(Shape(vec![Some(1), Some(elements.len())])),
420 };
421 MirShapeValue {
422 shape,
423 number: None,
424 int_vector,
425 }
426 }
427
428 fn infer_mir_call(
429 &mut self,
430 body: &runmat_mir::MirBody,
431 span: Span,
432 call: &runmat_mir::MirCall,
433 ) -> MirShapeValue {
434 let arg_values: Vec<_> = call
435 .args
436 .iter()
437 .map(|arg| self.infer_mir_operand(body, arg.operand()))
438 .collect();
439 let shape = match call.semantic_kind {
440 BuiltinSemanticKind::Elementwise => {
441 arg_values.first().and_then(|value| value.shape.clone())
442 }
443 BuiltinSemanticKind::ArrayConstructor => sized_constructor_shape(&arg_values),
444 BuiltinSemanticKind::ParameterizedArrayConstructor => {
445 sized_constructor_shape(arg_values.get(1..).unwrap_or(&[]))
446 }
447 BuiltinSemanticKind::PermutationConstructor => Some(Shape(vec![
448 Some(1),
449 arg_values
450 .first()
451 .and_then(|value| value.number.and_then(number_to_int)),
452 ])),
453 BuiltinSemanticKind::RangeConstructor => Some(Shape(vec![Some(1), None])),
454 BuiltinSemanticKind::EmptyConstructor => Some(Shape(vec![Some(0), Some(0)])),
455 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Dot) => {
456 let lhs = arg_values.first().and_then(|value| value.shape.as_ref());
457 let rhs = arg_values.get(1).and_then(|value| value.shape.as_ref());
458 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
459 if vector_len(lhs)
460 .zip(vector_len(rhs))
461 .is_some_and(|(l, r)| l != r)
462 {
463 self.warn(
464 "lint.shape.dot",
465 "dot product vector lengths do not agree",
466 span,
467 );
468 }
469 }
470 Some(Shape(vec![Some(1), Some(1)]))
471 }
472 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Reshape) => {
473 let input = arg_values.first().and_then(|value| value.shape.as_ref());
474 let dims = mir_parse_dims(&arg_values[1..]);
475 if dims.iter().filter(|dim| matches!(dim, Dim::Infer)).count() > 1
476 || incompatible_element_count(input, &dims)
477 {
478 self.warn(
479 "lint.shape.reshape",
480 "reshape dimensions are not compatible",
481 span,
482 );
483 }
484 Some(Shape(dims.iter().map(|dim| dim.as_shape_dim()).collect()))
485 }
486 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Repmat) => {
487 for arg in &arg_values[1..] {
488 if !matches!(mir_parse_dim(arg), Dim::Known(_)) {
489 self.warn(
490 "lint.shape.repmat",
491 "repmat dimensions must be non-negative integers",
492 span,
493 );
494 }
495 }
496 arg_values.first().and_then(|value| value.shape.clone())
497 }
498 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Permute) => {
499 let base = arg_values.first().and_then(|value| value.shape.clone());
500 let order = arg_values.get(1).and_then(|value| value.int_vector.clone());
501 if let Some(order) = &order {
502 let mut sorted = order.clone();
503 sorted.sort_unstable();
504 if sorted.windows(2).any(|pair| pair[0] == pair[1])
505 || base
506 .as_ref()
507 .is_some_and(|shape| order.len() != shape.0.len())
508 {
509 self.warn(
510 "lint.shape.permute",
511 "permute order is invalid for input rank",
512 span,
513 );
514 }
515 }
516 base
517 }
518 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Transpose) => {
519 let base = arg_values.first().and_then(|value| value.shape.clone());
520 base.map(|shape| {
521 if shape.0.len() >= 2 {
522 Shape(vec![shape.0[1], shape.0[0]])
523 } else {
524 shape
525 }
526 })
527 }
528 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Concatenate(kind)) => {
529 self.infer_mir_concat(span, kind, &arg_values)
530 }
531 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::General) => {
532 arg_values.first().and_then(|value| value.shape.clone())
533 }
534 BuiltinSemanticKind::Reduction => {
535 let base = arg_values.first().and_then(|value| value.shape.clone());
536 if let (Some(base_shape), Some(dim)) = (
537 base.as_ref(),
538 arg_values
539 .get(1)
540 .and_then(|value| value.number.and_then(number_to_int)),
541 ) {
542 if dim == 0 || dim > base_shape.0.len() {
543 self.warn(
544 "lint.shape.reduction",
545 "reduction dimension is out of range",
546 span,
547 );
548 }
549 }
550 base
551 }
552 _ => None,
553 };
554 MirShapeValue {
555 shape,
556 number: None,
557 int_vector: None,
558 }
559 }
560
561 fn infer_mir_concat(
562 &mut self,
563 span: Span,
564 kind: ConcatKind,
565 arg_values: &[MirShapeValue],
566 ) -> Option<Shape> {
567 let (dim, values) = match kind {
568 ConcatKind::Dimension => {
569 let dim = arg_values
570 .first()
571 .and_then(|value| value.number.and_then(number_to_int))?;
572 (dim, &arg_values[1..])
573 }
574 ConcatKind::Horizontal => (2, arg_values),
575 ConcatKind::Vertical => (1, arg_values),
576 };
577 let shapes: Vec<_> = values
578 .iter()
579 .filter_map(|value| value.shape.as_ref())
580 .collect();
581 if shapes.is_empty() || dim == 0 {
582 return None;
583 }
584 let rank = shapes
585 .iter()
586 .map(|shape| shape.0.len())
587 .max()
588 .unwrap_or(dim);
589 let axis = dim - 1;
590 if axis >= rank {
591 return None;
592 }
593 let mut out = vec![Some(1); rank];
594 for (idx, out_dim) in out.iter_mut().enumerate().take(rank) {
595 if idx == axis {
596 *out_dim = shapes
597 .iter()
598 .map(|shape| shape.0.get(idx).copied().flatten())
599 .try_fold(0usize, |sum, dim| dim.map(|dim| sum + dim));
600 continue;
601 }
602 let mut expected = None;
603 for shape in &shapes {
604 let dim = shape.0.get(idx).copied().flatten().or(Some(1));
605 if let (Some(expected), Some(dim)) = (expected, dim) {
606 if expected != dim {
607 self.warn(
608 "lint.shape.concat",
609 "concatenation dimensions do not agree",
610 span,
611 );
612 }
613 }
614 expected = expected.or(dim);
615 }
616 *out_dim = expected;
617 }
618 Some(Shape(out))
619 }
620
621 fn check_logical_index(&mut self, span: Span, base: Option<&Shape>, idx: Option<&Shape>) {
622 if let (Some(base), Some(idx)) = (base, idx) {
623 if element_count(base)
624 .zip(element_count(idx))
625 .is_some_and(|(base, idx)| base != idx)
626 {
627 self.warn(
628 "lint.shape.logical_index",
629 "logical index shape does not match indexed value",
630 span,
631 );
632 }
633 }
634 }
635
636 fn warn(&mut self, code: &'static str, message: &'static str, span: Span) {
637 self.diagnostics.push(
638 HirDiagnostic::new(code, HirDiagnosticSeverity::Warning, message, span)
639 .with_category("shape"),
640 );
641 }
642}
643
/// One dimension argument as understood by the lint pass.
#[derive(Copy, Clone, PartialEq)]
enum Dim {
    /// A statically known, non-negative integer extent.
    Known(usize),
    /// A placeholder asking for the extent to be inferred.
    Infer,
    /// Anything that could not be resolved to one of the above.
    Unknown,
}
650
651impl Dim {
652 fn as_shape_dim(self) -> Option<usize> {
653 match self {
654 Dim::Known(value) => Some(value),
655 Dim::Infer | Dim::Unknown => None,
656 }
657 }
658}
659
/// Converts a float to a non-negative integer if it is one, within a small tolerance.
///
/// Returns `None` for NaN, infinities, negative values, values further than
/// `1e-9` from the nearest integer, and values too large for `usize`. The
/// tolerance is symmetric, so `2.9999999999` and `3.0000000001` both map to 3.
fn number_to_int(value: f64) -> Option<usize> {
    const TOLERANCE: f64 = 1e-9;

    if !value.is_finite() || value < 0.0 {
        return None;
    }
    let nearest = value.round();
    if (value - nearest).abs() > TOLERANCE {
        return None;
    }
    // A float-to-int `as` cast saturates, which would turn an out-of-range
    // value into a bogus `usize::MAX` extent; reject such values instead.
    // (`usize::MAX as f64` rounds up, hence `>=`.)
    if nearest >= usize::MAX as f64 {
        return None;
    }
    Some(nearest as usize)
}
667
668fn mir_parse_dim(value: &MirShapeValue) -> Dim {
669 match value.number {
670 Some(-1.0) => Dim::Infer,
671 Some(value) => number_to_int(value).map(Dim::Known).unwrap_or(Dim::Unknown),
672 None => Dim::Unknown,
673 }
674}
675
676fn mir_parse_dims(args: &[MirShapeValue]) -> Vec<Dim> {
677 if args.len() == 1 {
678 if let Some(values) = &args[0].int_vector {
679 return values.iter().copied().map(Dim::Known).collect();
680 }
681 }
682 args.iter().map(mir_parse_dim).collect()
683}
684
685fn sized_constructor_shape(args: &[MirShapeValue]) -> Option<Shape> {
686 let dims: Vec<_> = args
687 .iter()
688 .filter_map(|value| value.number.and_then(number_to_int))
689 .map(Some)
690 .collect();
691 match dims.as_slice() {
692 [] => None,
693 [dim] => Some(Shape(vec![*dim, *dim])),
694 _ => Some(Shape(dims)),
695 }
696}
697
698fn shape_from_fact(shape: &runmat_hir::ShapeFact) -> Option<Shape> {
699 match shape {
700 runmat_hir::ShapeFact::Scalar => Some(Shape(vec![Some(1), Some(1)])),
701 runmat_hir::ShapeFact::Shaped { dims } => Some(Shape(
702 dims.iter()
703 .map(|dim| match dim {
704 runmat_hir::DimFact::Known(value) => Some(*value),
705 runmat_hir::DimFact::Symbolic(_) | runmat_hir::DimFact::Unknown => None,
706 })
707 .collect(),
708 )),
709 runmat_hir::ShapeFact::Ranked { .. }
710 | runmat_hir::ShapeFact::Unknown
711 | runmat_hir::ShapeFact::Unreachable => None,
712 }
713}
714
/// Number of elements in the range `start:step:end`.
///
/// Returns `None` when the step is zero, when any input is not finite, or
/// when the count does not fit in `usize`. A range that runs away from `end`
/// is empty (`Some(0)`).
///
/// The step count is snapped to the nearest integer when it lies within a
/// relative `1e-9` of it, so floating-point round-off does not drop the last
/// element (for example `0:0.1:0.3` has 4 elements although `0.3 / 0.1` is
/// slightly below 3).
fn range_width(start: f64, step: f64, end: f64) -> Option<usize> {
    const TOLERANCE: f64 = 1e-9;

    if step == 0.0 || !start.is_finite() || !step.is_finite() || !end.is_finite() {
        return None;
    }
    let span = end - start;
    if (span > 0.0 && step < 0.0) || (span < 0.0 && step > 0.0) {
        return Some(0);
    }
    // Non-negative from here on; it can still overflow to infinity.
    let ratio = span / step;
    if !ratio.is_finite() {
        return None;
    }
    let nearest = ratio.round();
    let steps = if (ratio - nearest).abs() <= TOLERANCE * nearest.max(1.0) {
        nearest
    } else {
        ratio.floor()
    };
    // The cast saturates at `usize::MAX`, and `checked_add` then reports the overflow.
    (steps as usize).checked_add(1)
}
725
726fn matrix_dims(shape: &Shape) -> Option<(Option<usize>, Option<usize>)> {
727 Some((*shape.0.first()?, *shape.0.get(1)?))
728}
729
730fn element_count(shape: &Shape) -> Option<usize> {
731 shape
732 .0
733 .iter()
734 .try_fold(1usize, |acc, dim| dim.map(|dim| acc * dim))
735}
736
737fn vector_len(shape: &Shape) -> Option<usize> {
738 let count = element_count(shape)?;
739 if shape.0.len() == 1
740 || (shape.0.len() == 2 && (shape.0[0] == Some(1) || shape.0[1] == Some(1)))
741 {
742 Some(count)
743 } else {
744 None
745 }
746}
747
748fn broadcast_compatible(left: &Shape, right: &Shape) -> bool {
749 let len = left.0.len().max(right.0.len());
750 (0..len).all(|idx| {
751 let l = left.0.iter().rev().nth(idx).copied().flatten().unwrap_or(1);
752 let r = right
753 .0
754 .iter()
755 .rev()
756 .nth(idx)
757 .copied()
758 .flatten()
759 .unwrap_or(1);
760 l == r || l == 1 || r == 1
761 })
762}
763
764fn incompatible_element_count(input: Option<&Shape>, dims: &[Dim]) -> bool {
765 let Some(input_count) = input.and_then(element_count) else {
766 return false;
767 };
768 if dims.iter().any(|dim| matches!(dim, Dim::Unknown)) {
769 return true;
770 }
771 let known_product = dims.iter().fold(1usize, |acc, dim| match dim {
772 Dim::Known(value) => acc * value,
773 Dim::Infer | Dim::Unknown => acc,
774 });
775 if dims.iter().any(|dim| matches!(dim, Dim::Infer)) {
776 known_product == 0 || input_count % known_product != 0
777 } else {
778 known_product != input_count
779 }
780}