Skip to main content

runmat_static_analysis/lints/
shape.rs

1use runmat_builtins::{BuiltinSemanticKind, ConcatKind, ShapeTransformKind};
2use runmat_hir::{BindingId, HirDiagnostic, HirDiagnosticSeverity, OperatorKind, Span};
3use runmat_mir::analysis::{AnalysisStore, MirLocalKey};
4use std::collections::HashMap;
5
6pub fn lint_shapes(result: &runmat_hir::LoweringResult) -> Vec<HirDiagnostic> {
7    let mir = match runmat_mir::lowering::lower_assembly(&result.assembly) {
8        Ok(mir) => mir,
9        Err(err) => return vec![mir_lowering_diagnostic(err)],
10    };
11    let store = runmat_mir::analysis::analyze_assembly(&mir);
12    let mut ctx = ShapeLintContext::default();
13    ctx.seed_from_analysis(&mir, &store);
14    ctx.walk_mir_assembly(&mir);
15    ctx.diagnostics
16}
17
18pub fn infer_binding_shapes(
19    result: &runmat_hir::LoweringResult,
20) -> HashMap<BindingId, Vec<Option<usize>>> {
21    let Ok(mir) = runmat_mir::lowering::lower_assembly(&result.assembly) else {
22        return HashMap::new();
23    };
24    let store = runmat_mir::analysis::analyze_assembly(&mir);
25    let mut ctx = ShapeLintContext::default();
26    ctx.seed_from_analysis(&mir, &store);
27    ctx.walk_mir_assembly(&mir);
28    ctx.env
29        .into_iter()
30        .map(|(binding, shape)| (binding, shape.0))
31        .collect()
32}
33
34fn mir_lowering_diagnostic(err: runmat_hir::HirError) -> HirDiagnostic {
35    HirDiagnostic::new(
36        "lint.mir.lowering_failed",
37        HirDiagnosticSeverity::Error,
38        format!("MIR lowering failed: {}", err.message),
39        err.span.unwrap_or(runmat_hir::Span { start: 0, end: 0 }),
40    )
41    .with_category("mir-lowering")
42}
43
/// Statically inferred array extents, one entry per dimension.
///
/// `None` marks an extent that is not statically known.
#[derive(Clone, Debug, PartialEq)]
struct Shape(Vec<Option<usize>>);
46
47#[derive(Default)]
48struct ShapeLintContext {
49    env: HashMap<BindingId, Shape>,
50    local_env: HashMap<MirLocalKey, Shape>,
51    number_env: HashMap<MirLocalKey, f64>,
52    int_vector_env: HashMap<MirLocalKey, Vec<usize>>,
53    diagnostics: Vec<HirDiagnostic>,
54}
55
56#[derive(Default)]
57struct MirShapeValue {
58    shape: Option<Shape>,
59    number: Option<f64>,
60    int_vector: Option<Vec<usize>>,
61}
62
63impl ShapeLintContext {
64    fn seed_from_analysis(&mut self, mir: &runmat_mir::MirAssembly, store: &AnalysisStore) {
65        for body in mir.bodies.values() {
66            for local in &body.locals {
67                let Some(binding) = local.binding else {
68                    continue;
69                };
70                let Some(fact) = store.mir_locals.get(&MirLocalKey {
71                    function: body.function,
72                    local: local.id,
73                }) else {
74                    continue;
75                };
76                if let Some(shape) = shape_from_fact(&fact.shape) {
77                    self.env.insert(binding, shape);
78                }
79            }
80        }
81    }
82
83    fn walk_mir_assembly(&mut self, mir: &runmat_mir::MirAssembly) {
84        for body in mir.bodies.values() {
85            for block in &body.blocks {
86                for stmt in &block.statements {
87                    match &stmt.kind {
88                        runmat_mir::MirStmtKind::Assign { place, value } => {
89                            let value = self.infer_mir_rvalue(body, value, stmt.span);
90                            if let runmat_mir::MirPlace::Local(local) = place {
91                                self.record_mir_value(body, *local, value);
92                            }
93                        }
94                        runmat_mir::MirStmtKind::MultiAssign { value, .. }
95                        | runmat_mir::MirStmtKind::Expr(value) => {
96                            self.infer_mir_rvalue(body, value, stmt.span);
97                        }
98                        runmat_mir::MirStmtKind::PlaceMutation(_)
99                        | runmat_mir::MirStmtKind::WorkspaceEffect { .. }
100                        | runmat_mir::MirStmtKind::EnvironmentEffect(_) => {}
101                    }
102                }
103            }
104        }
105    }
106
107    fn record_mir_value(
108        &mut self,
109        body: &runmat_mir::MirBody,
110        local: runmat_mir::MirLocalId,
111        value: MirShapeValue,
112    ) {
113        let key = MirLocalKey {
114            function: body.function,
115            local,
116        };
117        if let Some(shape) = value.shape {
118            if let Some(binding) = body.locals.get(local.0).and_then(|local| local.binding) {
119                self.env.insert(binding, shape.clone());
120            }
121            self.local_env.insert(key, shape);
122        }
123        if let Some(number) = value.number {
124            self.number_env.insert(key, number);
125        }
126        if let Some(vector) = value.int_vector {
127            self.int_vector_env.insert(key, vector);
128        }
129    }
130
131    fn infer_mir_rvalue(
132        &mut self,
133        body: &runmat_mir::MirBody,
134        value: &runmat_mir::MirRvalue,
135        span: Span,
136    ) -> MirShapeValue {
137        match value {
138            runmat_mir::MirRvalue::Use(operand) => self.infer_mir_operand(body, operand),
139            runmat_mir::MirRvalue::Unary(op, operand) => {
140                let inner = self.infer_mir_operand(body, operand);
141                let number = match (op, inner.number) {
142                    (OperatorKind::UnaryMinus, Some(value)) => Some(-value),
143                    _ => None,
144                };
145                MirShapeValue {
146                    shape: inner.shape,
147                    number,
148                    int_vector: None,
149                }
150            }
151            runmat_mir::MirRvalue::Binary(left, op, right) => {
152                let lhs = self.infer_mir_operand(body, left);
153                let rhs = self.infer_mir_operand(body, right);
154                MirShapeValue {
155                    shape: self.infer_mir_binary(span, lhs.shape.as_ref(), op, rhs.shape.as_ref()),
156                    number: None,
157                    int_vector: None,
158                }
159            }
160            runmat_mir::MirRvalue::ShortCircuit {
161                left,
162                right_temps,
163                right,
164                ..
165            } => {
166                self.infer_mir_operand(body, left);
167                for stmt in right_temps {
168                    match &stmt.kind {
169                        runmat_mir::MirStmtKind::Assign { place, value } => {
170                            let inferred = self.infer_mir_rvalue(body, value, stmt.span);
171                            if let runmat_mir::MirPlace::Local(local) = place {
172                                self.record_mir_value(body, *local, inferred);
173                            }
174                        }
175                        runmat_mir::MirStmtKind::MultiAssign { value, .. }
176                        | runmat_mir::MirStmtKind::Expr(value) => {
177                            self.infer_mir_rvalue(body, value, stmt.span);
178                        }
179                        runmat_mir::MirStmtKind::PlaceMutation(_)
180                        | runmat_mir::MirStmtKind::WorkspaceEffect { .. }
181                        | runmat_mir::MirStmtKind::EnvironmentEffect(_) => {}
182                    }
183                }
184                self.infer_mir_operand(body, right);
185                MirShapeValue::default()
186            }
187            runmat_mir::MirRvalue::Range { start, step, end } => {
188                let start = self.infer_mir_operand(body, start).number;
189                let step = step
190                    .as_ref()
191                    .and_then(|step| self.infer_mir_operand(body, step).number)
192                    .unwrap_or(1.0);
193                let end = self.infer_mir_operand(body, end).number;
194                let width = start
195                    .zip(end)
196                    .and_then(|(start, end)| range_width(start, step, end));
197                MirShapeValue {
198                    shape: Some(Shape(vec![Some(1), width])),
199                    number: None,
200                    int_vector: None,
201                }
202            }
203            runmat_mir::MirRvalue::Call(call) => self.infer_mir_call(body, span, call),
204            runmat_mir::MirRvalue::Aggregate {
205                kind,
206                rows,
207                elements,
208                ..
209            } => self.infer_mir_aggregate(body, span, kind, *rows, elements),
210            runmat_mir::MirRvalue::StructLiteral { fields } => {
211                for (_, value) in fields {
212                    self.infer_mir_operand(body, value);
213                }
214                MirShapeValue {
215                    shape: Some(Shape(vec![Some(1), Some(1)])),
216                    number: None,
217                    int_vector: None,
218                }
219            }
220            runmat_mir::MirRvalue::ObjectLiteral { fields, .. } => {
221                for (_, value) in fields {
222                    self.infer_mir_operand(body, value);
223                }
224                MirShapeValue {
225                    shape: Some(Shape(vec![Some(1), Some(1)])),
226                    number: None,
227                    int_vector: None,
228                }
229            }
230            runmat_mir::MirRvalue::Index { base, indexing } => {
231                let base_shape = self.infer_mir_operand(body, base).shape;
232                for component in &indexing.components {
233                    match component {
234                        runmat_mir::MirIndexComponent::Expr(operand) => {
235                            let idx_shape = self.infer_mir_operand(body, operand).shape;
236                            if indexing.components.len() == 1 {
237                                self.check_logical_index(
238                                    span,
239                                    base_shape.as_ref(),
240                                    idx_shape.as_ref(),
241                                );
242                            }
243                        }
244                        runmat_mir::MirIndexComponent::Colon
245                        | runmat_mir::MirIndexComponent::End { .. } => {}
246                    }
247                }
248                MirShapeValue::default()
249            }
250            runmat_mir::MirRvalue::Member { base, .. } => self.infer_mir_operand(body, base),
251            runmat_mir::MirRvalue::DynamicMember { base, member } => {
252                self.infer_mir_operand(body, member);
253                self.infer_mir_operand(body, base)
254            }
255            runmat_mir::MirRvalue::Future { args, .. } => {
256                for arg in args {
257                    self.infer_mir_operand(body, arg.operand());
258                }
259                MirShapeValue::default()
260            }
261            runmat_mir::MirRvalue::Spawn(operand) => self.infer_mir_operand(body, operand),
262            runmat_mir::MirRvalue::WorkspaceFirstStaticProperty { .. } => MirShapeValue::default(),
263            runmat_mir::MirRvalue::MetaClass(_)
264            | runmat_mir::MirRvalue::Colon
265            | runmat_mir::MirRvalue::End => MirShapeValue::default(),
266        }
267    }
268
269    fn infer_mir_operand(
270        &self,
271        body: &runmat_mir::MirBody,
272        operand: &runmat_mir::MirOperand,
273    ) -> MirShapeValue {
274        match operand {
275            runmat_mir::MirOperand::Constant(runmat_mir::MirConstant::Number(value)) => {
276                MirShapeValue {
277                    shape: Some(Shape(vec![Some(1), Some(1)])),
278                    number: value.parse().ok(),
279                    int_vector: None,
280                }
281            }
282            runmat_mir::MirOperand::Local(local) => {
283                let key = MirLocalKey {
284                    function: body.function,
285                    local: *local,
286                };
287                MirShapeValue {
288                    shape: self.local_env.get(&key).cloned(),
289                    number: self.number_env.get(&key).copied(),
290                    int_vector: self.int_vector_env.get(&key).cloned(),
291                }
292            }
293            runmat_mir::MirOperand::Constant(_) | runmat_mir::MirOperand::FunctionHandle(_) => {
294                MirShapeValue::default()
295            }
296        }
297    }
298
299    fn infer_mir_binary(
300        &mut self,
301        span: Span,
302        lhs: Option<&Shape>,
303        op: &OperatorKind,
304        rhs: Option<&Shape>,
305    ) -> Option<Shape> {
306        match op {
307            OperatorKind::MatrixMultiply => {
308                if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
309                    if matrix_dims(lhs)
310                        .zip(matrix_dims(rhs))
311                        .is_some_and(|((_, lc), (rr, _))| lc.is_some() && rr.is_some() && lc != rr)
312                    {
313                        self.warn(
314                            "lint.shape.matmul",
315                            "matrix multiply inner dimensions must match",
316                            span,
317                        );
318                    }
319                }
320                match (lhs.and_then(matrix_dims), rhs.and_then(matrix_dims)) {
321                    (Some((rows, _)), Some((_, cols))) => Some(Shape(vec![rows, cols])),
322                    _ => None,
323                }
324            }
325            OperatorKind::Add
326            | OperatorKind::Subtract
327            | OperatorKind::ElementwiseMultiply
328            | OperatorKind::ElementwiseDivide
329            | OperatorKind::ElementwiseLeftDivide
330            | OperatorKind::ElementwisePower
331            | OperatorKind::Greater
332            | OperatorKind::GreaterEqual
333            | OperatorKind::Less
334            | OperatorKind::LessEqual
335            | OperatorKind::Equal
336            | OperatorKind::NotEqual => {
337                if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
338                    if !broadcast_compatible(lhs, rhs) {
339                        self.warn(
340                            "lint.shape.broadcast",
341                            "array dimensions are not broadcast compatible",
342                            span,
343                        );
344                    }
345                }
346                lhs.cloned().or_else(|| rhs.cloned())
347            }
348            _ => lhs.cloned().or_else(|| rhs.cloned()),
349        }
350    }
351
352    fn infer_mir_aggregate(
353        &mut self,
354        body: &runmat_mir::MirBody,
355        span: Span,
356        kind: &runmat_mir::MirAggregateKind,
357        rows: usize,
358        elements: &[runmat_mir::MirOperand],
359    ) -> MirShapeValue {
360        let values: Vec<_> = elements
361            .iter()
362            .map(|element| self.infer_mir_operand(body, element))
363            .collect();
364        let int_vector = values
365            .iter()
366            .map(|value| value.number.and_then(number_to_int))
367            .collect::<Option<Vec<_>>>();
368        let shape = match kind {
369            runmat_mir::MirAggregateKind::Tensor => {
370                let row_count = rows.max(1);
371                let cols_per_row = if row_count == 0 {
372                    0
373                } else {
374                    elements.len() / row_count
375                };
376                let mut row_dims = Vec::new();
377                for row_idx in 0..row_count {
378                    let start = row_idx * cols_per_row;
379                    let end = start + cols_per_row;
380                    let mut total_cols = 0usize;
381                    let mut expected_rows = None;
382                    for value in &values[start..end] {
383                        if let Some((rows, cols)) = value.shape.as_ref().and_then(matrix_dims) {
384                            if let (Some(expected), Some(rows)) = (expected_rows, rows) {
385                                if expected != rows {
386                                    self.warn(
387                                        "lint.shape.horzcat",
388                                        "horizontal concatenation row dimensions do not agree",
389                                        span,
390                                    );
391                                }
392                            }
393                            expected_rows = expected_rows.or(rows);
394                            total_cols += cols.unwrap_or(1);
395                        } else {
396                            total_cols += 1;
397                        }
398                    }
399                    row_dims.push((expected_rows.unwrap_or(1), total_cols));
400                }
401                if let Some((_, first_cols)) = row_dims.first().copied() {
402                    for (_, cols) in &row_dims {
403                        if *cols != first_cols {
404                            self.warn(
405                                "lint.shape.vertcat",
406                                "vertical concatenation column dimensions do not agree",
407                                span,
408                            );
409                        }
410                    }
411                    Some(Shape(vec![
412                        Some(row_dims.iter().map(|(rows, _)| rows).sum()),
413                        Some(first_cols),
414                    ]))
415                } else {
416                    Some(Shape(vec![Some(0), Some(0)]))
417                }
418            }
419            runmat_mir::MirAggregateKind::Cell => Some(Shape(vec![Some(1), Some(elements.len())])),
420        };
421        MirShapeValue {
422            shape,
423            number: None,
424            int_vector,
425        }
426    }
427
428    fn infer_mir_call(
429        &mut self,
430        body: &runmat_mir::MirBody,
431        span: Span,
432        call: &runmat_mir::MirCall,
433    ) -> MirShapeValue {
434        let arg_values: Vec<_> = call
435            .args
436            .iter()
437            .map(|arg| self.infer_mir_operand(body, arg.operand()))
438            .collect();
439        let shape = match call.semantic_kind {
440            BuiltinSemanticKind::Elementwise => {
441                arg_values.first().and_then(|value| value.shape.clone())
442            }
443            BuiltinSemanticKind::ArrayConstructor => sized_constructor_shape(&arg_values),
444            BuiltinSemanticKind::ParameterizedArrayConstructor => {
445                sized_constructor_shape(arg_values.get(1..).unwrap_or(&[]))
446            }
447            BuiltinSemanticKind::PermutationConstructor => Some(Shape(vec![
448                Some(1),
449                arg_values
450                    .first()
451                    .and_then(|value| value.number.and_then(number_to_int)),
452            ])),
453            BuiltinSemanticKind::RangeConstructor => Some(Shape(vec![Some(1), None])),
454            BuiltinSemanticKind::EmptyConstructor => Some(Shape(vec![Some(0), Some(0)])),
455            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Dot) => {
456                let lhs = arg_values.first().and_then(|value| value.shape.as_ref());
457                let rhs = arg_values.get(1).and_then(|value| value.shape.as_ref());
458                if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
459                    if vector_len(lhs)
460                        .zip(vector_len(rhs))
461                        .is_some_and(|(l, r)| l != r)
462                    {
463                        self.warn(
464                            "lint.shape.dot",
465                            "dot product vector lengths do not agree",
466                            span,
467                        );
468                    }
469                }
470                Some(Shape(vec![Some(1), Some(1)]))
471            }
472            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Reshape) => {
473                let input = arg_values.first().and_then(|value| value.shape.as_ref());
474                let dims = mir_parse_dims(&arg_values[1..]);
475                if dims.iter().filter(|dim| matches!(dim, Dim::Infer)).count() > 1
476                    || incompatible_element_count(input, &dims)
477                {
478                    self.warn(
479                        "lint.shape.reshape",
480                        "reshape dimensions are not compatible",
481                        span,
482                    );
483                }
484                Some(Shape(dims.iter().map(|dim| dim.as_shape_dim()).collect()))
485            }
486            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Repmat) => {
487                for arg in &arg_values[1..] {
488                    if !matches!(mir_parse_dim(arg), Dim::Known(_)) {
489                        self.warn(
490                            "lint.shape.repmat",
491                            "repmat dimensions must be non-negative integers",
492                            span,
493                        );
494                    }
495                }
496                arg_values.first().and_then(|value| value.shape.clone())
497            }
498            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Permute) => {
499                let base = arg_values.first().and_then(|value| value.shape.clone());
500                let order = arg_values.get(1).and_then(|value| value.int_vector.clone());
501                if let Some(order) = &order {
502                    let mut sorted = order.clone();
503                    sorted.sort_unstable();
504                    if sorted.windows(2).any(|pair| pair[0] == pair[1])
505                        || base
506                            .as_ref()
507                            .is_some_and(|shape| order.len() != shape.0.len())
508                    {
509                        self.warn(
510                            "lint.shape.permute",
511                            "permute order is invalid for input rank",
512                            span,
513                        );
514                    }
515                }
516                base
517            }
518            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Transpose) => {
519                let base = arg_values.first().and_then(|value| value.shape.clone());
520                base.map(|shape| {
521                    if shape.0.len() >= 2 {
522                        Shape(vec![shape.0[1], shape.0[0]])
523                    } else {
524                        shape
525                    }
526                })
527            }
528            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Concatenate(kind)) => {
529                self.infer_mir_concat(span, kind, &arg_values)
530            }
531            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::General) => {
532                arg_values.first().and_then(|value| value.shape.clone())
533            }
534            BuiltinSemanticKind::Reduction => {
535                let base = arg_values.first().and_then(|value| value.shape.clone());
536                if let (Some(base_shape), Some(dim)) = (
537                    base.as_ref(),
538                    arg_values
539                        .get(1)
540                        .and_then(|value| value.number.and_then(number_to_int)),
541                ) {
542                    if dim == 0 || dim > base_shape.0.len() {
543                        self.warn(
544                            "lint.shape.reduction",
545                            "reduction dimension is out of range",
546                            span,
547                        );
548                    }
549                }
550                base
551            }
552            _ => None,
553        };
554        MirShapeValue {
555            shape,
556            number: None,
557            int_vector: None,
558        }
559    }
560
561    fn infer_mir_concat(
562        &mut self,
563        span: Span,
564        kind: ConcatKind,
565        arg_values: &[MirShapeValue],
566    ) -> Option<Shape> {
567        let (dim, values) = match kind {
568            ConcatKind::Dimension => {
569                let dim = arg_values
570                    .first()
571                    .and_then(|value| value.number.and_then(number_to_int))?;
572                (dim, &arg_values[1..])
573            }
574            ConcatKind::Horizontal => (2, arg_values),
575            ConcatKind::Vertical => (1, arg_values),
576        };
577        let shapes: Vec<_> = values
578            .iter()
579            .filter_map(|value| value.shape.as_ref())
580            .collect();
581        if shapes.is_empty() || dim == 0 {
582            return None;
583        }
584        let rank = shapes
585            .iter()
586            .map(|shape| shape.0.len())
587            .max()
588            .unwrap_or(dim);
589        let axis = dim - 1;
590        if axis >= rank {
591            return None;
592        }
593        let mut out = vec![Some(1); rank];
594        for (idx, out_dim) in out.iter_mut().enumerate().take(rank) {
595            if idx == axis {
596                *out_dim = shapes
597                    .iter()
598                    .map(|shape| shape.0.get(idx).copied().flatten())
599                    .try_fold(0usize, |sum, dim| dim.map(|dim| sum + dim));
600                continue;
601            }
602            let mut expected = None;
603            for shape in &shapes {
604                let dim = shape.0.get(idx).copied().flatten().or(Some(1));
605                if let (Some(expected), Some(dim)) = (expected, dim) {
606                    if expected != dim {
607                        self.warn(
608                            "lint.shape.concat",
609                            "concatenation dimensions do not agree",
610                            span,
611                        );
612                    }
613                }
614                expected = expected.or(dim);
615            }
616            *out_dim = expected;
617        }
618        Some(Shape(out))
619    }
620
621    fn check_logical_index(&mut self, span: Span, base: Option<&Shape>, idx: Option<&Shape>) {
622        if let (Some(base), Some(idx)) = (base, idx) {
623            if element_count(base)
624                .zip(element_count(idx))
625                .is_some_and(|(base, idx)| base != idx)
626            {
627                self.warn(
628                    "lint.shape.logical_index",
629                    "logical index shape does not match indexed value",
630                    span,
631                );
632            }
633        }
634    }
635
636    fn warn(&mut self, code: &'static str, message: &'static str, span: Span) {
637        self.diagnostics.push(
638            HirDiagnostic::new(code, HirDiagnosticSeverity::Warning, message, span)
639                .with_category("shape"),
640        );
641    }
642}
643
/// One parsed reshape size argument.
#[derive(PartialEq, Clone, Copy)]
enum Dim {
    /// A statically known non-negative extent.
    Known(usize),
    /// A placeholder asking for the extent to be inferred from the element count.
    Infer,
    /// An extent that is not statically known, or not a valid size.
    Unknown,
}

impl Dim {
    /// The extent as a shape entry: `Some(n)` for `Known(n)`, `None` otherwise.
    fn as_shape_dim(self) -> Option<usize> {
        match self {
            Self::Known(extent) => Some(extent),
            Self::Infer => None,
            Self::Unknown => None,
        }
    }
}
659
/// Converts `value` to a non-negative integer when it is one, within a small tolerance.
///
/// Returns `None` for NaN, infinities, negative values and values with a real
/// fractional part. The tolerance is applied symmetrically around the nearest
/// integer, so float noise on either side is accepted. For example, `0.3 / 0.1`
/// evaluates to `2.9999999999999996` and still maps to `3`. Values beyond
/// `usize::MAX` saturate (the behaviour of `as`).
fn number_to_int(value: f64) -> Option<usize> {
    const INTEGER_TOLERANCE: f64 = 1e-9;
    if !value.is_finite() || value < 0.0 {
        return None;
    }
    let nearest = value.round();
    if (value - nearest).abs() > INTEGER_TOLERANCE {
        return None;
    }
    Some(nearest as usize)
}
667
668fn mir_parse_dim(value: &MirShapeValue) -> Dim {
669    match value.number {
670        Some(-1.0) => Dim::Infer,
671        Some(value) => number_to_int(value).map(Dim::Known).unwrap_or(Dim::Unknown),
672        None => Dim::Unknown,
673    }
674}
675
676fn mir_parse_dims(args: &[MirShapeValue]) -> Vec<Dim> {
677    if args.len() == 1 {
678        if let Some(values) = &args[0].int_vector {
679            return values.iter().copied().map(Dim::Known).collect();
680        }
681    }
682    args.iter().map(mir_parse_dim).collect()
683}
684
685fn sized_constructor_shape(args: &[MirShapeValue]) -> Option<Shape> {
686    let dims: Vec<_> = args
687        .iter()
688        .filter_map(|value| value.number.and_then(number_to_int))
689        .map(Some)
690        .collect();
691    match dims.as_slice() {
692        [] => None,
693        [dim] => Some(Shape(vec![*dim, *dim])),
694        _ => Some(Shape(dims)),
695    }
696}
697
698fn shape_from_fact(shape: &runmat_hir::ShapeFact) -> Option<Shape> {
699    match shape {
700        runmat_hir::ShapeFact::Scalar => Some(Shape(vec![Some(1), Some(1)])),
701        runmat_hir::ShapeFact::Shaped { dims } => Some(Shape(
702            dims.iter()
703                .map(|dim| match dim {
704                    runmat_hir::DimFact::Known(value) => Some(*value),
705                    runmat_hir::DimFact::Symbolic(_) | runmat_hir::DimFact::Unknown => None,
706                })
707                .collect(),
708        )),
709        runmat_hir::ShapeFact::Ranked { .. }
710        | runmat_hir::ShapeFact::Unknown
711        | runmat_hir::ShapeFact::Unreachable => None,
712    }
713}
714
/// Number of elements in the colon range `start:step:end`.
///
/// Returns `None` for a zero or non-finite step, non-finite bounds, or a count
/// that does not fit in `usize`. Returns `Some(0)` when the step points away
/// from `end`.
///
/// A small relative tolerance absorbs floating-point round-off, in the spirit
/// of MATLAB's colon operator. For example, `0:0.1:0.3` has four elements even
/// though `0.3 / 0.1` evaluates to `2.9999999999999996`.
fn range_width(start: f64, step: f64, end: f64) -> Option<usize> {
    if step == 0.0 || !start.is_finite() || !step.is_finite() || !end.is_finite() {
        return None;
    }
    let span = end - start;
    if (span > 0.0 && step < 0.0) || (span < 0.0 && step > 0.0) {
        return Some(0);
    }
    // `span` and `step` share a sign here, so `steps` is non-negative.
    let steps = span / step;
    let tolerance =
        4.0 * f64::EPSILON * start.abs().max(end.abs()).max(span.abs()) / step.abs();
    let whole_steps = (steps + tolerance).floor().max(0.0) as usize;
    // `as` saturates on absurdly long ranges; avoid overflowing the `+ 1`.
    whole_steps.checked_add(1)
}
725
726fn matrix_dims(shape: &Shape) -> Option<(Option<usize>, Option<usize>)> {
727    Some((*shape.0.first()?, *shape.0.get(1)?))
728}
729
730fn element_count(shape: &Shape) -> Option<usize> {
731    shape
732        .0
733        .iter()
734        .try_fold(1usize, |acc, dim| dim.map(|dim| acc * dim))
735}
736
737fn vector_len(shape: &Shape) -> Option<usize> {
738    let count = element_count(shape)?;
739    if shape.0.len() == 1
740        || (shape.0.len() == 2 && (shape.0[0] == Some(1) || shape.0[1] == Some(1)))
741    {
742        Some(count)
743    } else {
744        None
745    }
746}
747
748fn broadcast_compatible(left: &Shape, right: &Shape) -> bool {
749    let len = left.0.len().max(right.0.len());
750    (0..len).all(|idx| {
751        let l = left.0.iter().rev().nth(idx).copied().flatten().unwrap_or(1);
752        let r = right
753            .0
754            .iter()
755            .rev()
756            .nth(idx)
757            .copied()
758            .flatten()
759            .unwrap_or(1);
760        l == r || l == 1 || r == 1
761    })
762}
763
764fn incompatible_element_count(input: Option<&Shape>, dims: &[Dim]) -> bool {
765    let Some(input_count) = input.and_then(element_count) else {
766        return false;
767    };
768    if dims.iter().any(|dim| matches!(dim, Dim::Unknown)) {
769        return true;
770    }
771    let known_product = dims.iter().fold(1usize, |acc, dim| match dim {
772        Dim::Known(value) => acc * value,
773        Dim::Infer | Dim::Unknown => acc,
774    });
775    if dims.iter().any(|dim| matches!(dim, Dim::Infer)) {
776        known_product == 0 || input_count % known_product != 0
777    } else {
778        known_product != input_count
779    }
780}