Skip to main content

runmat_static_analysis/lints/
shape.rs

1use runmat_builtins::{BuiltinSemanticKind, ConcatKind, ShapeTransformKind};
2use runmat_hir::{BindingId, HirDiagnostic, HirDiagnosticSeverity, IndexKind, OperatorKind, Span};
3use runmat_mir::analysis::{AnalysisStore, MirLocalKey};
4use std::collections::HashMap;
5
6pub fn lint_shapes(result: &runmat_hir::LoweringResult) -> Vec<HirDiagnostic> {
7    let mir = match runmat_mir::lowering::lower_assembly(&result.assembly) {
8        Ok(mir) => mir,
9        Err(err) => return vec![mir_lowering_diagnostic(err)],
10    };
11    let store = runmat_mir::analysis::analyze_assembly(&mir);
12    let mut ctx = ShapeLintContext::default();
13    ctx.seed_from_analysis(&mir, &store);
14    ctx.walk_mir_assembly(&mir);
15    ctx.diagnostics
16}
17
18pub fn infer_binding_shapes(
19    result: &runmat_hir::LoweringResult,
20) -> HashMap<BindingId, Vec<Option<usize>>> {
21    let Ok(mir) = runmat_mir::lowering::lower_assembly(&result.assembly) else {
22        return HashMap::new();
23    };
24    let store = runmat_mir::analysis::analyze_assembly(&mir);
25    let mut ctx = ShapeLintContext::default();
26    ctx.seed_from_analysis(&mir, &store);
27    ctx.walk_mir_assembly(&mir);
28    ctx.env
29        .into_iter()
30        .map(|(binding, shape)| (binding, shape.0))
31        .collect()
32}
33
34fn mir_lowering_diagnostic(err: runmat_hir::HirError) -> HirDiagnostic {
35    HirDiagnostic::new(
36        "lint.mir.lowering_failed",
37        HirDiagnosticSeverity::Error,
38        format!("MIR lowering failed: {}", err.message),
39        err.span.unwrap_or(runmat_hir::Span { start: 0, end: 0 }),
40    )
41    .with_category("mir-lowering")
42}
43
/// Inferred array dimensions, one entry per axis; `None` marks an extent that is not
/// statically known.
#[derive(Clone, Debug, PartialEq)]
struct Shape(Vec<Option<usize>>);
46
47#[derive(Default)]
48struct ShapeLintContext {
49    env: HashMap<BindingId, Shape>,
50    local_env: HashMap<MirLocalKey, Shape>,
51    number_env: HashMap<MirLocalKey, f64>,
52    int_vector_env: HashMap<MirLocalKey, Vec<usize>>,
53    logical_env: HashMap<MirLocalKey, bool>,
54    diagnostics: Vec<HirDiagnostic>,
55}
56
57#[derive(Default)]
58struct MirShapeValue {
59    shape: Option<Shape>,
60    number: Option<f64>,
61    int_vector: Option<Vec<usize>>,
62    logical: bool,
63}
64
65impl ShapeLintContext {
66    fn seed_from_analysis(&mut self, mir: &runmat_mir::MirAssembly, store: &AnalysisStore) {
67        for body in mir.bodies.values() {
68            for local in &body.locals {
69                let Some(binding) = local.binding else {
70                    continue;
71                };
72                let Some(fact) = store.mir_locals.get(&MirLocalKey {
73                    function: body.function,
74                    local: local.id,
75                }) else {
76                    continue;
77                };
78                if let Some(shape) = shape_from_fact(&fact.shape) {
79                    self.env.insert(binding, shape);
80                }
81            }
82        }
83    }
84
85    fn walk_mir_assembly(&mut self, mir: &runmat_mir::MirAssembly) {
86        for body in mir.bodies.values() {
87            for block in &body.blocks {
88                for stmt in &block.statements {
89                    match &stmt.kind {
90                        runmat_mir::MirStmtKind::Assign { place, value } => {
91                            let value = self.infer_mir_rvalue(body, value, stmt.span);
92                            if let runmat_mir::MirPlace::Local(local) = place {
93                                self.record_mir_value(body, *local, value);
94                            }
95                        }
96                        runmat_mir::MirStmtKind::MultiAssign { value, .. }
97                        | runmat_mir::MirStmtKind::Expr(value) => {
98                            self.infer_mir_rvalue(body, value, stmt.span);
99                        }
100                        runmat_mir::MirStmtKind::PlaceMutation(_)
101                        | runmat_mir::MirStmtKind::WorkspaceEffect { .. }
102                        | runmat_mir::MirStmtKind::EnvironmentEffect(_) => {}
103                    }
104                }
105            }
106        }
107    }
108
109    fn record_mir_value(
110        &mut self,
111        body: &runmat_mir::MirBody,
112        local: runmat_mir::MirLocalId,
113        value: MirShapeValue,
114    ) {
115        let key = MirLocalKey {
116            function: body.function,
117            local,
118        };
119        if let Some(shape) = value.shape {
120            if let Some(binding) = body.locals.get(local.0).and_then(|local| local.binding) {
121                self.env.insert(binding, shape.clone());
122            }
123            self.local_env.insert(key, shape);
124        }
125        if let Some(number) = value.number {
126            self.number_env.insert(key, number);
127        }
128        if let Some(vector) = value.int_vector {
129            self.int_vector_env.insert(key, vector);
130        }
131        if value.logical {
132            self.logical_env.insert(key, true);
133        } else {
134            self.logical_env.remove(&key);
135        }
136    }
137
138    fn infer_mir_rvalue(
139        &mut self,
140        body: &runmat_mir::MirBody,
141        value: &runmat_mir::MirRvalue,
142        span: Span,
143    ) -> MirShapeValue {
144        match value {
145            runmat_mir::MirRvalue::Use(operand) => self.infer_mir_operand(body, operand),
146            runmat_mir::MirRvalue::Unary(op, operand) => {
147                let inner = self.infer_mir_operand(body, operand);
148                let number = match (op, inner.number) {
149                    (OperatorKind::UnaryMinus, Some(value)) => Some(-value),
150                    _ => None,
151                };
152                MirShapeValue {
153                    shape: inner.shape,
154                    number,
155                    int_vector: None,
156                    logical: matches!(op, OperatorKind::Not),
157                }
158            }
159            runmat_mir::MirRvalue::Binary(left, op, right) => {
160                let lhs = self.infer_mir_operand(body, left);
161                let rhs = self.infer_mir_operand(body, right);
162                MirShapeValue {
163                    shape: self.infer_mir_binary(span, lhs.shape.as_ref(), op, rhs.shape.as_ref()),
164                    number: None,
165                    int_vector: None,
166                    logical: is_logical_operator(op),
167                }
168            }
169            runmat_mir::MirRvalue::ShortCircuit {
170                left,
171                right_temps,
172                right,
173                ..
174            } => {
175                let left_value = self.infer_mir_operand(body, left);
176                for stmt in right_temps {
177                    match &stmt.kind {
178                        runmat_mir::MirStmtKind::Assign { place, value } => {
179                            let inferred = self.infer_mir_rvalue(body, value, stmt.span);
180                            if let runmat_mir::MirPlace::Local(local) = place {
181                                self.record_mir_value(body, *local, inferred);
182                            }
183                        }
184                        runmat_mir::MirStmtKind::MultiAssign { value, .. }
185                        | runmat_mir::MirStmtKind::Expr(value) => {
186                            self.infer_mir_rvalue(body, value, stmt.span);
187                        }
188                        runmat_mir::MirStmtKind::PlaceMutation(_)
189                        | runmat_mir::MirStmtKind::WorkspaceEffect { .. }
190                        | runmat_mir::MirStmtKind::EnvironmentEffect(_) => {}
191                    }
192                }
193                let right_value = self.infer_mir_operand(body, right);
194                MirShapeValue {
195                    shape: left_value.shape.or(right_value.shape),
196                    number: None,
197                    int_vector: None,
198                    logical: true,
199                }
200            }
201            runmat_mir::MirRvalue::Range { start, step, end } => {
202                let start = self.infer_mir_operand(body, start).number;
203                let step = step
204                    .as_ref()
205                    .and_then(|step| self.infer_mir_operand(body, step).number)
206                    .unwrap_or(1.0);
207                let end = self.infer_mir_operand(body, end).number;
208                let width = start
209                    .zip(end)
210                    .and_then(|(start, end)| range_width(start, step, end));
211                MirShapeValue {
212                    shape: Some(Shape(vec![Some(1), width])),
213                    number: None,
214                    int_vector: None,
215                    logical: false,
216                }
217            }
218            runmat_mir::MirRvalue::Call(call) => self.infer_mir_call(body, span, call),
219            runmat_mir::MirRvalue::Aggregate {
220                kind,
221                rows,
222                elements,
223                ..
224            } => self.infer_mir_aggregate(body, span, kind, *rows, elements),
225            runmat_mir::MirRvalue::StructLiteral { fields } => {
226                for (_, value) in fields {
227                    self.infer_mir_operand(body, value);
228                }
229                MirShapeValue {
230                    shape: Some(Shape(vec![Some(1), Some(1)])),
231                    number: None,
232                    int_vector: None,
233                    logical: false,
234                }
235            }
236            runmat_mir::MirRvalue::ObjectLiteral { fields, .. } => {
237                for (_, value) in fields {
238                    self.infer_mir_operand(body, value);
239                }
240                MirShapeValue {
241                    shape: Some(Shape(vec![Some(1), Some(1)])),
242                    number: None,
243                    int_vector: None,
244                    logical: false,
245                }
246            }
247            runmat_mir::MirRvalue::Index { base, indexing } => {
248                let base_shape = self.infer_mir_operand(body, base).shape;
249                let mut component_values = Vec::new();
250                for component in &indexing.components {
251                    match component {
252                        runmat_mir::MirIndexComponent::Expr(operand) => {
253                            let idx_value = self.infer_mir_operand(body, operand);
254                            if indexing.components.len() == 1 && idx_value.logical {
255                                self.check_logical_index(
256                                    span,
257                                    base_shape.as_ref(),
258                                    idx_value.shape.as_ref(),
259                                );
260                            }
261                            component_values.push(Some(idx_value));
262                        }
263                        runmat_mir::MirIndexComponent::Colon
264                        | runmat_mir::MirIndexComponent::End { .. } => {
265                            component_values.push(None);
266                        }
267                    }
268                }
269                MirShapeValue {
270                    shape: infer_mir_index_shape(base_shape.as_ref(), indexing, &component_values),
271                    number: None,
272                    int_vector: None,
273                    logical: false,
274                }
275            }
276            runmat_mir::MirRvalue::Member { base, .. } => self.infer_mir_operand(body, base),
277            runmat_mir::MirRvalue::DynamicMember { base, member } => {
278                self.infer_mir_operand(body, member);
279                self.infer_mir_operand(body, base)
280            }
281            runmat_mir::MirRvalue::Future { args, .. } => {
282                for arg in args {
283                    self.infer_mir_operand(body, arg.operand());
284                }
285                MirShapeValue::default()
286            }
287            runmat_mir::MirRvalue::Spawn(operand) => self.infer_mir_operand(body, operand),
288            runmat_mir::MirRvalue::WorkspaceFirstStaticProperty { .. } => MirShapeValue::default(),
289            runmat_mir::MirRvalue::MetaClass(_)
290            | runmat_mir::MirRvalue::Colon
291            | runmat_mir::MirRvalue::End => MirShapeValue::default(),
292        }
293    }
294
295    fn infer_mir_operand(
296        &self,
297        body: &runmat_mir::MirBody,
298        operand: &runmat_mir::MirOperand,
299    ) -> MirShapeValue {
300        match operand {
301            runmat_mir::MirOperand::Constant(runmat_mir::MirConstant::Number(value)) => {
302                MirShapeValue {
303                    shape: Some(Shape(vec![Some(1), Some(1)])),
304                    number: value.parse().ok(),
305                    int_vector: None,
306                    logical: false,
307                }
308            }
309            runmat_mir::MirOperand::Constant(runmat_mir::MirConstant::Bool(_)) => MirShapeValue {
310                shape: Some(Shape(vec![Some(1), Some(1)])),
311                number: None,
312                int_vector: None,
313                logical: true,
314            },
315            runmat_mir::MirOperand::Local(local) => {
316                let key = MirLocalKey {
317                    function: body.function,
318                    local: *local,
319                };
320                MirShapeValue {
321                    shape: self.local_env.get(&key).cloned(),
322                    number: self.number_env.get(&key).copied(),
323                    int_vector: self.int_vector_env.get(&key).cloned(),
324                    logical: self.logical_env.get(&key).copied().unwrap_or(false),
325                }
326            }
327            runmat_mir::MirOperand::Constant(_) | runmat_mir::MirOperand::FunctionHandle(_) => {
328                MirShapeValue::default()
329            }
330        }
331    }
332
333    fn infer_mir_binary(
334        &mut self,
335        span: Span,
336        lhs: Option<&Shape>,
337        op: &OperatorKind,
338        rhs: Option<&Shape>,
339    ) -> Option<Shape> {
340        match op {
341            OperatorKind::MatrixMultiply => {
342                if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
343                    if matrix_dims(lhs)
344                        .zip(matrix_dims(rhs))
345                        .is_some_and(|((_, lc), (rr, _))| lc.is_some() && rr.is_some() && lc != rr)
346                    {
347                        self.warn(
348                            "lint.shape.matmul",
349                            "matrix multiply inner dimensions must match",
350                            span,
351                        );
352                    }
353                }
354                match (lhs.and_then(matrix_dims), rhs.and_then(matrix_dims)) {
355                    (Some((rows, _)), Some((_, cols))) => Some(Shape(vec![rows, cols])),
356                    _ => None,
357                }
358            }
359            OperatorKind::Add
360            | OperatorKind::Subtract
361            | OperatorKind::ElementwiseMultiply
362            | OperatorKind::ElementwiseDivide
363            | OperatorKind::ElementwiseLeftDivide
364            | OperatorKind::ElementwisePower
365            | OperatorKind::Greater
366            | OperatorKind::GreaterEqual
367            | OperatorKind::Less
368            | OperatorKind::LessEqual
369            | OperatorKind::Equal
370            | OperatorKind::NotEqual
371            | OperatorKind::ElementwiseAnd
372            | OperatorKind::ElementwiseOr => {
373                if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
374                    if !broadcast_compatible(lhs, rhs) {
375                        self.warn(
376                            "lint.shape.broadcast",
377                            "array dimensions are not broadcast compatible",
378                            span,
379                        );
380                    }
381                }
382                lhs.cloned().or_else(|| rhs.cloned())
383            }
384            _ => lhs.cloned().or_else(|| rhs.cloned()),
385        }
386    }
387
388    fn infer_mir_aggregate(
389        &mut self,
390        body: &runmat_mir::MirBody,
391        span: Span,
392        kind: &runmat_mir::MirAggregateKind,
393        rows: usize,
394        elements: &[runmat_mir::MirOperand],
395    ) -> MirShapeValue {
396        let values: Vec<_> = elements
397            .iter()
398            .map(|element| self.infer_mir_operand(body, element))
399            .collect();
400        let int_vector = values
401            .iter()
402            .map(|value| value.number.and_then(number_to_int))
403            .collect::<Option<Vec<_>>>();
404        let shape = match kind {
405            runmat_mir::MirAggregateKind::Tensor => {
406                let row_count = rows.max(1);
407                let cols_per_row = if row_count == 0 {
408                    0
409                } else {
410                    elements.len() / row_count
411                };
412                let mut row_dims = Vec::new();
413                for row_idx in 0..row_count {
414                    let start = row_idx * cols_per_row;
415                    let end = start + cols_per_row;
416                    let mut total_cols = 0usize;
417                    let mut expected_rows = None;
418                    for value in &values[start..end] {
419                        if let Some((rows, cols)) = value.shape.as_ref().and_then(matrix_dims) {
420                            if let (Some(expected), Some(rows)) = (expected_rows, rows) {
421                                if expected != rows {
422                                    self.warn(
423                                        "lint.shape.horzcat",
424                                        "horizontal concatenation row dimensions do not agree",
425                                        span,
426                                    );
427                                }
428                            }
429                            expected_rows = expected_rows.or(rows);
430                            total_cols += cols.unwrap_or(1);
431                        } else {
432                            total_cols += 1;
433                        }
434                    }
435                    row_dims.push((expected_rows.unwrap_or(1), total_cols));
436                }
437                if let Some((_, first_cols)) = row_dims.first().copied() {
438                    for (_, cols) in &row_dims {
439                        if *cols != first_cols {
440                            self.warn(
441                                "lint.shape.vertcat",
442                                "vertical concatenation column dimensions do not agree",
443                                span,
444                            );
445                        }
446                    }
447                    Some(Shape(vec![
448                        Some(row_dims.iter().map(|(rows, _)| rows).sum()),
449                        Some(first_cols),
450                    ]))
451                } else {
452                    Some(Shape(vec![Some(0), Some(0)]))
453                }
454            }
455            runmat_mir::MirAggregateKind::Cell => Some(Shape(vec![Some(1), Some(elements.len())])),
456        };
457        MirShapeValue {
458            shape,
459            number: None,
460            int_vector,
461            logical: false,
462        }
463    }
464
465    fn infer_mir_call(
466        &mut self,
467        body: &runmat_mir::MirBody,
468        span: Span,
469        call: &runmat_mir::MirCall,
470    ) -> MirShapeValue {
471        let arg_values: Vec<_> = call
472            .args
473            .iter()
474            .map(|arg| self.infer_mir_operand(body, arg.operand()))
475            .collect();
476        let shape = match call.semantic_kind {
477            BuiltinSemanticKind::Elementwise => {
478                arg_values.first().and_then(|value| value.shape.clone())
479            }
480            BuiltinSemanticKind::ArrayConstructor => sized_constructor_shape(&arg_values),
481            BuiltinSemanticKind::ParameterizedArrayConstructor => {
482                sized_constructor_shape(arg_values.get(1..).unwrap_or(&[]))
483            }
484            BuiltinSemanticKind::PermutationConstructor => Some(Shape(vec![
485                Some(1),
486                arg_values
487                    .first()
488                    .and_then(|value| value.number.and_then(number_to_int)),
489            ])),
490            BuiltinSemanticKind::RangeConstructor => Some(Shape(vec![Some(1), None])),
491            BuiltinSemanticKind::EmptyConstructor => Some(Shape(vec![Some(0), Some(0)])),
492            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Dot) => {
493                let lhs = arg_values.first().and_then(|value| value.shape.as_ref());
494                let rhs = arg_values.get(1).and_then(|value| value.shape.as_ref());
495                if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
496                    if vector_len(lhs)
497                        .zip(vector_len(rhs))
498                        .is_some_and(|(l, r)| l != r)
499                    {
500                        self.warn(
501                            "lint.shape.dot",
502                            "dot product vector lengths do not agree",
503                            span,
504                        );
505                    }
506                }
507                Some(Shape(vec![Some(1), Some(1)]))
508            }
509            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Reshape) => {
510                let input = arg_values.first().and_then(|value| value.shape.as_ref());
511                let dims = mir_parse_dims(&arg_values[1..]);
512                if dims.iter().filter(|dim| matches!(dim, Dim::Infer)).count() > 1
513                    || incompatible_element_count(input, &dims)
514                {
515                    self.warn(
516                        "lint.shape.reshape",
517                        "reshape dimensions are not compatible",
518                        span,
519                    );
520                }
521                Some(Shape(dims.iter().map(|dim| dim.as_shape_dim()).collect()))
522            }
523            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Repmat) => {
524                for arg in &arg_values[1..] {
525                    if !matches!(mir_parse_dim(arg), Dim::Known(_)) {
526                        self.warn(
527                            "lint.shape.repmat",
528                            "repmat dimensions must be non-negative integers",
529                            span,
530                        );
531                    }
532                }
533                arg_values.first().and_then(|value| value.shape.clone())
534            }
535            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Permute) => {
536                let base = arg_values.first().and_then(|value| value.shape.clone());
537                let order = arg_values.get(1).and_then(|value| value.int_vector.clone());
538                if let Some(order) = &order {
539                    let mut sorted = order.clone();
540                    sorted.sort_unstable();
541                    if sorted.windows(2).any(|pair| pair[0] == pair[1])
542                        || base
543                            .as_ref()
544                            .is_some_and(|shape| order.len() != shape.0.len())
545                    {
546                        self.warn(
547                            "lint.shape.permute",
548                            "permute order is invalid for input rank",
549                            span,
550                        );
551                    }
552                }
553                base
554            }
555            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Transpose) => {
556                let base = arg_values.first().and_then(|value| value.shape.clone());
557                base.map(|shape| {
558                    if shape.0.len() >= 2 {
559                        Shape(vec![shape.0[1], shape.0[0]])
560                    } else {
561                        shape
562                    }
563                })
564            }
565            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Concatenate(kind)) => {
566                self.infer_mir_concat(span, kind, &arg_values)
567            }
568            BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::General) => {
569                arg_values.first().and_then(|value| value.shape.clone())
570            }
571            BuiltinSemanticKind::Reduction => {
572                let base = arg_values.first().and_then(|value| value.shape.clone());
573                if let (Some(base_shape), Some(dim)) = (
574                    base.as_ref(),
575                    arg_values
576                        .get(1)
577                        .and_then(|value| value.number.and_then(number_to_int)),
578                ) {
579                    if dim == 0 || dim > base_shape.0.len() {
580                        self.warn(
581                            "lint.shape.reduction",
582                            "reduction dimension is out of range",
583                            span,
584                        );
585                    }
586                }
587                base
588            }
589            _ => None,
590        };
591        MirShapeValue {
592            shape,
593            number: None,
594            int_vector: None,
595            logical: false,
596        }
597    }
598
599    fn infer_mir_concat(
600        &mut self,
601        span: Span,
602        kind: ConcatKind,
603        arg_values: &[MirShapeValue],
604    ) -> Option<Shape> {
605        let (dim, values) = match kind {
606            ConcatKind::Dimension => {
607                let dim = arg_values
608                    .first()
609                    .and_then(|value| value.number.and_then(number_to_int))?;
610                (dim, &arg_values[1..])
611            }
612            ConcatKind::Horizontal => (2, arg_values),
613            ConcatKind::Vertical => (1, arg_values),
614        };
615        let shapes: Vec<_> = values
616            .iter()
617            .filter_map(|value| value.shape.as_ref())
618            .collect();
619        if shapes.is_empty() || dim == 0 {
620            return None;
621        }
622        let rank = shapes
623            .iter()
624            .map(|shape| shape.0.len())
625            .max()
626            .unwrap_or(dim);
627        let axis = dim - 1;
628        if axis >= rank {
629            return None;
630        }
631        let mut out = vec![Some(1); rank];
632        for (idx, out_dim) in out.iter_mut().enumerate().take(rank) {
633            if idx == axis {
634                *out_dim = shapes
635                    .iter()
636                    .map(|shape| shape.0.get(idx).copied().flatten())
637                    .try_fold(0usize, |sum, dim| dim.map(|dim| sum + dim));
638                continue;
639            }
640            let mut expected = None;
641            for shape in &shapes {
642                let dim = shape.0.get(idx).copied().flatten().or(Some(1));
643                if let (Some(expected), Some(dim)) = (expected, dim) {
644                    if expected != dim {
645                        self.warn(
646                            "lint.shape.concat",
647                            "concatenation dimensions do not agree",
648                            span,
649                        );
650                    }
651                }
652                expected = expected.or(dim);
653            }
654            *out_dim = expected;
655        }
656        Some(Shape(out))
657    }
658
659    fn check_logical_index(&mut self, span: Span, base: Option<&Shape>, idx: Option<&Shape>) {
660        if let (Some(base), Some(idx)) = (base, idx) {
661            if element_count(base)
662                .zip(element_count(idx))
663                .is_some_and(|(base, idx)| base != idx)
664            {
665                self.warn(
666                    "lint.shape.logical_index",
667                    "logical index shape does not match indexed value",
668                    span,
669                );
670            }
671        }
672    }
673
674    fn warn(&mut self, code: &'static str, message: &'static str, span: Span) {
675        self.diagnostics.push(
676            HirDiagnostic::new(code, HirDiagnosticSeverity::Warning, message, span)
677                .with_category("shape"),
678        );
679    }
680}
681
/// One requested output extent parsed from a size argument of a shape builtin.
#[derive(Clone, Copy, PartialEq)]
enum Dim {
    /// A concrete non-negative extent.
    Known(usize),
    /// A placeholder asking the builtin to compute this extent.
    Infer,
    /// An argument whose value could not be resolved to an extent.
    Unknown,
}

impl Dim {
    /// Converts to a `Shape` entry; only concrete extents survive.
    fn as_shape_dim(self) -> Option<usize> {
        if let Dim::Known(extent) = self {
            Some(extent)
        } else {
            None
        }
    }
}
697
/// Converts `value` to a non-negative integer when it lies within `1e-9` of a whole number.
///
/// Values just below an integer because of floating-point rounding (e.g.
/// `2.9999999999999996`) are accepted, just like values slightly above one. The result is
/// rounded rather than truncated. Non-finite values, fractional values, and values that round
/// to a negative integer yield `None`.
fn number_to_int(value: f64) -> Option<usize> {
    let rounded = value.round();
    if value.is_finite() && rounded >= 0.0 && (value - rounded).abs() <= 1e-9 {
        Some(rounded as usize)
    } else {
        None
    }
}
705
706fn is_logical_operator(op: &OperatorKind) -> bool {
707    matches!(
708        op,
709        OperatorKind::Not
710            | OperatorKind::Greater
711            | OperatorKind::GreaterEqual
712            | OperatorKind::Less
713            | OperatorKind::LessEqual
714            | OperatorKind::Equal
715            | OperatorKind::NotEqual
716            | OperatorKind::ShortCircuitAnd
717            | OperatorKind::ShortCircuitOr
718            | OperatorKind::ElementwiseAnd
719            | OperatorKind::ElementwiseOr
720    )
721}
722
723fn mir_parse_dim(value: &MirShapeValue) -> Dim {
724    match value.number {
725        Some(-1.0) => Dim::Infer,
726        Some(value) => number_to_int(value).map(Dim::Known).unwrap_or(Dim::Unknown),
727        None => Dim::Unknown,
728    }
729}
730
731fn mir_parse_dims(args: &[MirShapeValue]) -> Vec<Dim> {
732    if args.len() == 1 {
733        if let Some(values) = &args[0].int_vector {
734            return values.iter().copied().map(Dim::Known).collect();
735        }
736    }
737    args.iter().map(mir_parse_dim).collect()
738}
739
740fn sized_constructor_shape(args: &[MirShapeValue]) -> Option<Shape> {
741    let dims: Vec<_> = args
742        .iter()
743        .filter_map(|value| value.number.and_then(number_to_int))
744        .map(Some)
745        .collect();
746    match dims.as_slice() {
747        [] => None,
748        [dim] => Some(Shape(vec![*dim, *dim])),
749        _ => Some(Shape(dims)),
750    }
751}
752
753fn shape_from_fact(shape: &runmat_hir::ShapeFact) -> Option<Shape> {
754    match shape {
755        runmat_hir::ShapeFact::Scalar => Some(Shape(vec![Some(1), Some(1)])),
756        runmat_hir::ShapeFact::Shaped { dims } => Some(Shape(
757            dims.iter()
758                .map(|dim| match dim {
759                    runmat_hir::DimFact::Known(value) => Some(*value),
760                    runmat_hir::DimFact::Symbolic(_) | runmat_hir::DimFact::Unknown => None,
761                })
762                .collect(),
763        )),
764        runmat_hir::ShapeFact::Ranked { .. }
765        | runmat_hir::ShapeFact::Unknown
766        | runmat_hir::ShapeFact::Unreachable => None,
767    }
768}
769
/// Number of elements produced by the colon expression `start:step:end`.
///
/// - A zero step or a non-finite operand gives `None` (unknown).
/// - A step that moves away from `end` gives an empty range.
/// - Otherwise the interval count is rounded to the nearest integer, then reduced by one only
///   if that endpoint overshoots `end` by more than a small relative tolerance.
///
/// The tolerance mirrors MATLAB's reference colon algorithm. It keeps `0:0.1:0.3` at 4
/// elements even though `0.3 / 0.1` evaluates to `2.9999999999999996`. Counts too large
/// for `usize` yield `None`.
fn range_width(start: f64, step: f64, end: f64) -> Option<usize> {
    if step == 0.0 || !start.is_finite() || !step.is_finite() || !end.is_finite() {
        return None;
    }
    let span = end - start;
    if (span > 0.0 && step < 0.0) || (span < 0.0 && step > 0.0) {
        return Some(0);
    }
    let tolerance = 2.0 * f64::EPSILON * start.abs().max(end.abs());
    // `span / step` is non-negative here, so `intervals` never drops below zero.
    let mut intervals = (span / step).round();
    if step.signum() * (start + intervals * step - end) > tolerance {
        intervals -= 1.0;
    }
    (intervals as usize).checked_add(1)
}
780
781fn matrix_dims(shape: &Shape) -> Option<(Option<usize>, Option<usize>)> {
782    Some((*shape.0.first()?, *shape.0.get(1)?))
783}
784
785fn element_count(shape: &Shape) -> Option<usize> {
786    shape
787        .0
788        .iter()
789        .try_fold(1usize, |acc, dim| dim.map(|dim| acc * dim))
790}
791
792fn vector_len(shape: &Shape) -> Option<usize> {
793    let count = element_count(shape)?;
794    if shape.0.len() == 1
795        || (shape.0.len() == 2 && (shape.0[0] == Some(1) || shape.0[1] == Some(1)))
796    {
797        Some(count)
798    } else {
799        None
800    }
801}
802
803fn infer_mir_index_shape(
804    base: Option<&Shape>,
805    indexing: &runmat_mir::MirIndexing,
806    component_values: &[Option<MirShapeValue>],
807) -> Option<Shape> {
808    if indexing.kind != IndexKind::Paren {
809        return None;
810    }
811    if indexing.components.is_empty() {
812        return base.cloned();
813    }
814    if indexing.plan == runmat_mir::MirIndexPlan::Scalar {
815        return Some(Shape(vec![Some(1), Some(1)]));
816    }
817    if indexing.components.len() == 1 {
818        return match indexing.components.first()? {
819            runmat_mir::MirIndexComponent::Colon => base.cloned(),
820            runmat_mir::MirIndexComponent::End { .. } => Some(Shape(vec![Some(1), Some(1)])),
821            runmat_mir::MirIndexComponent::Expr(_) => {
822                let value = component_values.first()?.as_ref()?;
823                if value.logical {
824                    None
825                } else if value.number.is_some() {
826                    Some(Shape(vec![Some(1), Some(1)]))
827                } else {
828                    value.shape.clone()
829                }
830            }
831        };
832    }
833
834    let dims = indexing
835        .components
836        .iter()
837        .enumerate()
838        .map(|(idx, component)| match component {
839            runmat_mir::MirIndexComponent::Colon => {
840                base.and_then(|shape| shape.0.get(idx).copied().flatten())
841            }
842            runmat_mir::MirIndexComponent::End { .. } => Some(1),
843            runmat_mir::MirIndexComponent::Expr(_) => component_values
844                .get(idx)
845                .and_then(|value| value.as_ref())
846                .and_then(numeric_index_component_len),
847        })
848        .collect();
849    Some(Shape(dims))
850}
851
852fn numeric_index_component_len(value: &MirShapeValue) -> Option<usize> {
853    if value.logical {
854        return None;
855    }
856    if value.number.is_some() {
857        return Some(1);
858    }
859    value.shape.as_ref().and_then(vector_len)
860}
861
862fn broadcast_compatible(left: &Shape, right: &Shape) -> bool {
863    let len = left.0.len().max(right.0.len());
864    (0..len).all(|idx| {
865        let l = left.0.iter().rev().nth(idx).copied().flatten().unwrap_or(1);
866        let r = right
867            .0
868            .iter()
869            .rev()
870            .nth(idx)
871            .copied()
872            .flatten()
873            .unwrap_or(1);
874        l == r || l == 1 || r == 1
875    })
876}
877
878fn incompatible_element_count(input: Option<&Shape>, dims: &[Dim]) -> bool {
879    let Some(input_count) = input.and_then(element_count) else {
880        return false;
881    };
882    if dims.iter().any(|dim| matches!(dim, Dim::Unknown)) {
883        return true;
884    }
885    let known_product = dims.iter().fold(1usize, |acc, dim| match dim {
886        Dim::Known(value) => acc * value,
887        Dim::Infer | Dim::Unknown => acc,
888    });
889    if dims.iter().any(|dim| matches!(dim, Dim::Infer)) {
890        known_product == 0 || input_count % known_product != 0
891    } else {
892        known_product != input_count
893    }
894}