1use runmat_builtins::{BuiltinSemanticKind, ConcatKind, ShapeTransformKind};
2use runmat_hir::{BindingId, HirDiagnostic, HirDiagnosticSeverity, IndexKind, OperatorKind, Span};
3use runmat_mir::analysis::{AnalysisStore, MirLocalKey};
4use std::collections::HashMap;
5
6pub fn lint_shapes(result: &runmat_hir::LoweringResult) -> Vec<HirDiagnostic> {
7 let mir = match runmat_mir::lowering::lower_assembly(&result.assembly) {
8 Ok(mir) => mir,
9 Err(err) => return vec![mir_lowering_diagnostic(err)],
10 };
11 let store = runmat_mir::analysis::analyze_assembly(&mir);
12 let mut ctx = ShapeLintContext::default();
13 ctx.seed_from_analysis(&mir, &store);
14 ctx.walk_mir_assembly(&mir);
15 ctx.diagnostics
16}
17
18pub fn infer_binding_shapes(
19 result: &runmat_hir::LoweringResult,
20) -> HashMap<BindingId, Vec<Option<usize>>> {
21 let Ok(mir) = runmat_mir::lowering::lower_assembly(&result.assembly) else {
22 return HashMap::new();
23 };
24 let store = runmat_mir::analysis::analyze_assembly(&mir);
25 let mut ctx = ShapeLintContext::default();
26 ctx.seed_from_analysis(&mir, &store);
27 ctx.walk_mir_assembly(&mir);
28 ctx.env
29 .into_iter()
30 .map(|(binding, shape)| (binding, shape.0))
31 .collect()
32}
33
34fn mir_lowering_diagnostic(err: runmat_hir::HirError) -> HirDiagnostic {
35 HirDiagnostic::new(
36 "lint.mir.lowering_failed",
37 HirDiagnosticSeverity::Error,
38 format!("MIR lowering failed: {}", err.message),
39 err.span.unwrap_or(runmat_hir::Span { start: 0, end: 0 }),
40 )
41 .with_category("mir-lowering")
42}
43
/// Statically inferred array shape: one entry per dimension, in order.
///
/// `Some(n)` is a dimension known to have extent `n`; `None` is a dimension
/// whose extent could not be determined.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Shape(Vec<Option<usize>>);
46
/// Mutable state of one shape-lint run.
#[derive(Default)]
struct ShapeLintContext {
    /// Shape per source-level binding. Seeded from the analysis store and
    /// overwritten whenever an assignment to the binding's local yields a
    /// shape; this is what `infer_binding_shapes` returns.
    env: HashMap<BindingId, Shape>,
    /// Shape per MIR local, recorded while walking assignments.
    local_env: HashMap<MirLocalKey, Shape>,
    /// Numeric value recorded for a MIR local when the assigned value had one.
    number_env: HashMap<MirLocalKey, f64>,
    /// Integer vector contents recorded for a MIR local when the assigned
    /// value had them (produced by aggregate inference).
    int_vector_env: HashMap<MirLocalKey, Vec<usize>>,
    /// MIR locals whose most recently recorded value was flagged as logical.
    logical_env: HashMap<MirLocalKey, bool>,
    /// Diagnostics emitted so far, in emission order.
    diagnostics: Vec<HirDiagnostic>,
}
56
/// Everything the linter knows about the value of one operand or rvalue.
///
/// The `Default` value means "nothing known".
#[derive(Default)]
struct MirShapeValue {
    /// Inferred shape, if any.
    shape: Option<Shape>,
    /// Numeric value when statically known (a parsed numeric literal, its
    /// negation, or a local that was assigned such a value).
    number: Option<f64>,
    /// Element values when the value is an aggregate whose elements are all
    /// known non-negative integers.
    int_vector: Option<Vec<usize>>,
    /// Whether the value is known to be a logical (boolean) result.
    logical: bool,
}
64
65impl ShapeLintContext {
66 fn seed_from_analysis(&mut self, mir: &runmat_mir::MirAssembly, store: &AnalysisStore) {
67 for body in mir.bodies.values() {
68 for local in &body.locals {
69 let Some(binding) = local.binding else {
70 continue;
71 };
72 let Some(fact) = store.mir_locals.get(&MirLocalKey {
73 function: body.function,
74 local: local.id,
75 }) else {
76 continue;
77 };
78 if let Some(shape) = shape_from_fact(&fact.shape) {
79 self.env.insert(binding, shape);
80 }
81 }
82 }
83 }
84
85 fn walk_mir_assembly(&mut self, mir: &runmat_mir::MirAssembly) {
86 for body in mir.bodies.values() {
87 for block in &body.blocks {
88 for stmt in &block.statements {
89 match &stmt.kind {
90 runmat_mir::MirStmtKind::Assign { place, value } => {
91 let value = self.infer_mir_rvalue(body, value, stmt.span);
92 if let runmat_mir::MirPlace::Local(local) = place {
93 self.record_mir_value(body, *local, value);
94 }
95 }
96 runmat_mir::MirStmtKind::MultiAssign { value, .. }
97 | runmat_mir::MirStmtKind::Expr(value) => {
98 self.infer_mir_rvalue(body, value, stmt.span);
99 }
100 runmat_mir::MirStmtKind::PlaceMutation(_)
101 | runmat_mir::MirStmtKind::WorkspaceEffect { .. }
102 | runmat_mir::MirStmtKind::EnvironmentEffect(_) => {}
103 }
104 }
105 }
106 }
107 }
108
109 fn record_mir_value(
110 &mut self,
111 body: &runmat_mir::MirBody,
112 local: runmat_mir::MirLocalId,
113 value: MirShapeValue,
114 ) {
115 let key = MirLocalKey {
116 function: body.function,
117 local,
118 };
119 if let Some(shape) = value.shape {
120 if let Some(binding) = body.locals.get(local.0).and_then(|local| local.binding) {
121 self.env.insert(binding, shape.clone());
122 }
123 self.local_env.insert(key, shape);
124 }
125 if let Some(number) = value.number {
126 self.number_env.insert(key, number);
127 }
128 if let Some(vector) = value.int_vector {
129 self.int_vector_env.insert(key, vector);
130 }
131 if value.logical {
132 self.logical_env.insert(key, true);
133 } else {
134 self.logical_env.remove(&key);
135 }
136 }
137
138 fn infer_mir_rvalue(
139 &mut self,
140 body: &runmat_mir::MirBody,
141 value: &runmat_mir::MirRvalue,
142 span: Span,
143 ) -> MirShapeValue {
144 match value {
145 runmat_mir::MirRvalue::Use(operand) => self.infer_mir_operand(body, operand),
146 runmat_mir::MirRvalue::Unary(op, operand) => {
147 let inner = self.infer_mir_operand(body, operand);
148 let number = match (op, inner.number) {
149 (OperatorKind::UnaryMinus, Some(value)) => Some(-value),
150 _ => None,
151 };
152 MirShapeValue {
153 shape: inner.shape,
154 number,
155 int_vector: None,
156 logical: matches!(op, OperatorKind::Not),
157 }
158 }
159 runmat_mir::MirRvalue::Binary(left, op, right) => {
160 let lhs = self.infer_mir_operand(body, left);
161 let rhs = self.infer_mir_operand(body, right);
162 MirShapeValue {
163 shape: self.infer_mir_binary(span, lhs.shape.as_ref(), op, rhs.shape.as_ref()),
164 number: None,
165 int_vector: None,
166 logical: is_logical_operator(op),
167 }
168 }
169 runmat_mir::MirRvalue::ShortCircuit {
170 left,
171 right_temps,
172 right,
173 ..
174 } => {
175 let left_value = self.infer_mir_operand(body, left);
176 for stmt in right_temps {
177 match &stmt.kind {
178 runmat_mir::MirStmtKind::Assign { place, value } => {
179 let inferred = self.infer_mir_rvalue(body, value, stmt.span);
180 if let runmat_mir::MirPlace::Local(local) = place {
181 self.record_mir_value(body, *local, inferred);
182 }
183 }
184 runmat_mir::MirStmtKind::MultiAssign { value, .. }
185 | runmat_mir::MirStmtKind::Expr(value) => {
186 self.infer_mir_rvalue(body, value, stmt.span);
187 }
188 runmat_mir::MirStmtKind::PlaceMutation(_)
189 | runmat_mir::MirStmtKind::WorkspaceEffect { .. }
190 | runmat_mir::MirStmtKind::EnvironmentEffect(_) => {}
191 }
192 }
193 let right_value = self.infer_mir_operand(body, right);
194 MirShapeValue {
195 shape: left_value.shape.or(right_value.shape),
196 number: None,
197 int_vector: None,
198 logical: true,
199 }
200 }
201 runmat_mir::MirRvalue::Range { start, step, end } => {
202 let start = self.infer_mir_operand(body, start).number;
203 let step = step
204 .as_ref()
205 .and_then(|step| self.infer_mir_operand(body, step).number)
206 .unwrap_or(1.0);
207 let end = self.infer_mir_operand(body, end).number;
208 let width = start
209 .zip(end)
210 .and_then(|(start, end)| range_width(start, step, end));
211 MirShapeValue {
212 shape: Some(Shape(vec![Some(1), width])),
213 number: None,
214 int_vector: None,
215 logical: false,
216 }
217 }
218 runmat_mir::MirRvalue::Call(call) => self.infer_mir_call(body, span, call),
219 runmat_mir::MirRvalue::Aggregate {
220 kind,
221 rows,
222 elements,
223 ..
224 } => self.infer_mir_aggregate(body, span, kind, *rows, elements),
225 runmat_mir::MirRvalue::StructLiteral { fields } => {
226 for (_, value) in fields {
227 self.infer_mir_operand(body, value);
228 }
229 MirShapeValue {
230 shape: Some(Shape(vec![Some(1), Some(1)])),
231 number: None,
232 int_vector: None,
233 logical: false,
234 }
235 }
236 runmat_mir::MirRvalue::ObjectLiteral { fields, .. } => {
237 for (_, value) in fields {
238 self.infer_mir_operand(body, value);
239 }
240 MirShapeValue {
241 shape: Some(Shape(vec![Some(1), Some(1)])),
242 number: None,
243 int_vector: None,
244 logical: false,
245 }
246 }
247 runmat_mir::MirRvalue::Index { base, indexing } => {
248 let base_shape = self.infer_mir_operand(body, base).shape;
249 let mut component_values = Vec::new();
250 for component in &indexing.components {
251 match component {
252 runmat_mir::MirIndexComponent::Expr(operand) => {
253 let idx_value = self.infer_mir_operand(body, operand);
254 if indexing.components.len() == 1 && idx_value.logical {
255 self.check_logical_index(
256 span,
257 base_shape.as_ref(),
258 idx_value.shape.as_ref(),
259 );
260 }
261 component_values.push(Some(idx_value));
262 }
263 runmat_mir::MirIndexComponent::Colon
264 | runmat_mir::MirIndexComponent::End { .. } => {
265 component_values.push(None);
266 }
267 }
268 }
269 MirShapeValue {
270 shape: infer_mir_index_shape(base_shape.as_ref(), indexing, &component_values),
271 number: None,
272 int_vector: None,
273 logical: false,
274 }
275 }
276 runmat_mir::MirRvalue::Member { base, .. } => self.infer_mir_operand(body, base),
277 runmat_mir::MirRvalue::DynamicMember { base, member } => {
278 self.infer_mir_operand(body, member);
279 self.infer_mir_operand(body, base)
280 }
281 runmat_mir::MirRvalue::Future { args, .. } => {
282 for arg in args {
283 self.infer_mir_operand(body, arg.operand());
284 }
285 MirShapeValue::default()
286 }
287 runmat_mir::MirRvalue::Spawn(operand) => self.infer_mir_operand(body, operand),
288 runmat_mir::MirRvalue::WorkspaceFirstStaticProperty { .. } => MirShapeValue::default(),
289 runmat_mir::MirRvalue::MetaClass(_)
290 | runmat_mir::MirRvalue::Colon
291 | runmat_mir::MirRvalue::End => MirShapeValue::default(),
292 }
293 }
294
295 fn infer_mir_operand(
296 &self,
297 body: &runmat_mir::MirBody,
298 operand: &runmat_mir::MirOperand,
299 ) -> MirShapeValue {
300 match operand {
301 runmat_mir::MirOperand::Constant(runmat_mir::MirConstant::Number(value)) => {
302 MirShapeValue {
303 shape: Some(Shape(vec![Some(1), Some(1)])),
304 number: value.parse().ok(),
305 int_vector: None,
306 logical: false,
307 }
308 }
309 runmat_mir::MirOperand::Constant(runmat_mir::MirConstant::Bool(_)) => MirShapeValue {
310 shape: Some(Shape(vec![Some(1), Some(1)])),
311 number: None,
312 int_vector: None,
313 logical: true,
314 },
315 runmat_mir::MirOperand::Local(local) => {
316 let key = MirLocalKey {
317 function: body.function,
318 local: *local,
319 };
320 MirShapeValue {
321 shape: self.local_env.get(&key).cloned(),
322 number: self.number_env.get(&key).copied(),
323 int_vector: self.int_vector_env.get(&key).cloned(),
324 logical: self.logical_env.get(&key).copied().unwrap_or(false),
325 }
326 }
327 runmat_mir::MirOperand::Constant(_) | runmat_mir::MirOperand::FunctionHandle(_) => {
328 MirShapeValue::default()
329 }
330 }
331 }
332
333 fn infer_mir_binary(
334 &mut self,
335 span: Span,
336 lhs: Option<&Shape>,
337 op: &OperatorKind,
338 rhs: Option<&Shape>,
339 ) -> Option<Shape> {
340 match op {
341 OperatorKind::MatrixMultiply => {
342 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
343 if matrix_dims(lhs)
344 .zip(matrix_dims(rhs))
345 .is_some_and(|((_, lc), (rr, _))| lc.is_some() && rr.is_some() && lc != rr)
346 {
347 self.warn(
348 "lint.shape.matmul",
349 "matrix multiply inner dimensions must match",
350 span,
351 );
352 }
353 }
354 match (lhs.and_then(matrix_dims), rhs.and_then(matrix_dims)) {
355 (Some((rows, _)), Some((_, cols))) => Some(Shape(vec![rows, cols])),
356 _ => None,
357 }
358 }
359 OperatorKind::Add
360 | OperatorKind::Subtract
361 | OperatorKind::ElementwiseMultiply
362 | OperatorKind::ElementwiseDivide
363 | OperatorKind::ElementwiseLeftDivide
364 | OperatorKind::ElementwisePower
365 | OperatorKind::Greater
366 | OperatorKind::GreaterEqual
367 | OperatorKind::Less
368 | OperatorKind::LessEqual
369 | OperatorKind::Equal
370 | OperatorKind::NotEqual
371 | OperatorKind::ElementwiseAnd
372 | OperatorKind::ElementwiseOr => {
373 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
374 if !broadcast_compatible(lhs, rhs) {
375 self.warn(
376 "lint.shape.broadcast",
377 "array dimensions are not broadcast compatible",
378 span,
379 );
380 }
381 }
382 lhs.cloned().or_else(|| rhs.cloned())
383 }
384 _ => lhs.cloned().or_else(|| rhs.cloned()),
385 }
386 }
387
388 fn infer_mir_aggregate(
389 &mut self,
390 body: &runmat_mir::MirBody,
391 span: Span,
392 kind: &runmat_mir::MirAggregateKind,
393 rows: usize,
394 elements: &[runmat_mir::MirOperand],
395 ) -> MirShapeValue {
396 let values: Vec<_> = elements
397 .iter()
398 .map(|element| self.infer_mir_operand(body, element))
399 .collect();
400 let int_vector = values
401 .iter()
402 .map(|value| value.number.and_then(number_to_int))
403 .collect::<Option<Vec<_>>>();
404 let shape = match kind {
405 runmat_mir::MirAggregateKind::Tensor => {
406 let row_count = rows.max(1);
407 let cols_per_row = if row_count == 0 {
408 0
409 } else {
410 elements.len() / row_count
411 };
412 let mut row_dims = Vec::new();
413 for row_idx in 0..row_count {
414 let start = row_idx * cols_per_row;
415 let end = start + cols_per_row;
416 let mut total_cols = 0usize;
417 let mut expected_rows = None;
418 for value in &values[start..end] {
419 if let Some((rows, cols)) = value.shape.as_ref().and_then(matrix_dims) {
420 if let (Some(expected), Some(rows)) = (expected_rows, rows) {
421 if expected != rows {
422 self.warn(
423 "lint.shape.horzcat",
424 "horizontal concatenation row dimensions do not agree",
425 span,
426 );
427 }
428 }
429 expected_rows = expected_rows.or(rows);
430 total_cols += cols.unwrap_or(1);
431 } else {
432 total_cols += 1;
433 }
434 }
435 row_dims.push((expected_rows.unwrap_or(1), total_cols));
436 }
437 if let Some((_, first_cols)) = row_dims.first().copied() {
438 for (_, cols) in &row_dims {
439 if *cols != first_cols {
440 self.warn(
441 "lint.shape.vertcat",
442 "vertical concatenation column dimensions do not agree",
443 span,
444 );
445 }
446 }
447 Some(Shape(vec![
448 Some(row_dims.iter().map(|(rows, _)| rows).sum()),
449 Some(first_cols),
450 ]))
451 } else {
452 Some(Shape(vec![Some(0), Some(0)]))
453 }
454 }
455 runmat_mir::MirAggregateKind::Cell => Some(Shape(vec![Some(1), Some(elements.len())])),
456 };
457 MirShapeValue {
458 shape,
459 number: None,
460 int_vector,
461 logical: false,
462 }
463 }
464
465 fn infer_mir_call(
466 &mut self,
467 body: &runmat_mir::MirBody,
468 span: Span,
469 call: &runmat_mir::MirCall,
470 ) -> MirShapeValue {
471 let arg_values: Vec<_> = call
472 .args
473 .iter()
474 .map(|arg| self.infer_mir_operand(body, arg.operand()))
475 .collect();
476 let shape = match call.semantic_kind {
477 BuiltinSemanticKind::Elementwise => {
478 arg_values.first().and_then(|value| value.shape.clone())
479 }
480 BuiltinSemanticKind::ArrayConstructor => sized_constructor_shape(&arg_values),
481 BuiltinSemanticKind::ParameterizedArrayConstructor => {
482 sized_constructor_shape(arg_values.get(1..).unwrap_or(&[]))
483 }
484 BuiltinSemanticKind::PermutationConstructor => Some(Shape(vec![
485 Some(1),
486 arg_values
487 .first()
488 .and_then(|value| value.number.and_then(number_to_int)),
489 ])),
490 BuiltinSemanticKind::RangeConstructor => Some(Shape(vec![Some(1), None])),
491 BuiltinSemanticKind::EmptyConstructor => Some(Shape(vec![Some(0), Some(0)])),
492 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Dot) => {
493 let lhs = arg_values.first().and_then(|value| value.shape.as_ref());
494 let rhs = arg_values.get(1).and_then(|value| value.shape.as_ref());
495 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
496 if vector_len(lhs)
497 .zip(vector_len(rhs))
498 .is_some_and(|(l, r)| l != r)
499 {
500 self.warn(
501 "lint.shape.dot",
502 "dot product vector lengths do not agree",
503 span,
504 );
505 }
506 }
507 Some(Shape(vec![Some(1), Some(1)]))
508 }
509 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Reshape) => {
510 let input = arg_values.first().and_then(|value| value.shape.as_ref());
511 let dims = mir_parse_dims(&arg_values[1..]);
512 if dims.iter().filter(|dim| matches!(dim, Dim::Infer)).count() > 1
513 || incompatible_element_count(input, &dims)
514 {
515 self.warn(
516 "lint.shape.reshape",
517 "reshape dimensions are not compatible",
518 span,
519 );
520 }
521 Some(Shape(dims.iter().map(|dim| dim.as_shape_dim()).collect()))
522 }
523 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Repmat) => {
524 for arg in &arg_values[1..] {
525 if !matches!(mir_parse_dim(arg), Dim::Known(_)) {
526 self.warn(
527 "lint.shape.repmat",
528 "repmat dimensions must be non-negative integers",
529 span,
530 );
531 }
532 }
533 arg_values.first().and_then(|value| value.shape.clone())
534 }
535 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Permute) => {
536 let base = arg_values.first().and_then(|value| value.shape.clone());
537 let order = arg_values.get(1).and_then(|value| value.int_vector.clone());
538 if let Some(order) = &order {
539 let mut sorted = order.clone();
540 sorted.sort_unstable();
541 if sorted.windows(2).any(|pair| pair[0] == pair[1])
542 || base
543 .as_ref()
544 .is_some_and(|shape| order.len() != shape.0.len())
545 {
546 self.warn(
547 "lint.shape.permute",
548 "permute order is invalid for input rank",
549 span,
550 );
551 }
552 }
553 base
554 }
555 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Transpose) => {
556 let base = arg_values.first().and_then(|value| value.shape.clone());
557 base.map(|shape| {
558 if shape.0.len() >= 2 {
559 Shape(vec![shape.0[1], shape.0[0]])
560 } else {
561 shape
562 }
563 })
564 }
565 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::Concatenate(kind)) => {
566 self.infer_mir_concat(span, kind, &arg_values)
567 }
568 BuiltinSemanticKind::ShapeTransform(ShapeTransformKind::General) => {
569 arg_values.first().and_then(|value| value.shape.clone())
570 }
571 BuiltinSemanticKind::Reduction => {
572 let base = arg_values.first().and_then(|value| value.shape.clone());
573 if let (Some(base_shape), Some(dim)) = (
574 base.as_ref(),
575 arg_values
576 .get(1)
577 .and_then(|value| value.number.and_then(number_to_int)),
578 ) {
579 if dim == 0 || dim > base_shape.0.len() {
580 self.warn(
581 "lint.shape.reduction",
582 "reduction dimension is out of range",
583 span,
584 );
585 }
586 }
587 base
588 }
589 _ => None,
590 };
591 MirShapeValue {
592 shape,
593 number: None,
594 int_vector: None,
595 logical: false,
596 }
597 }
598
599 fn infer_mir_concat(
600 &mut self,
601 span: Span,
602 kind: ConcatKind,
603 arg_values: &[MirShapeValue],
604 ) -> Option<Shape> {
605 let (dim, values) = match kind {
606 ConcatKind::Dimension => {
607 let dim = arg_values
608 .first()
609 .and_then(|value| value.number.and_then(number_to_int))?;
610 (dim, &arg_values[1..])
611 }
612 ConcatKind::Horizontal => (2, arg_values),
613 ConcatKind::Vertical => (1, arg_values),
614 };
615 let shapes: Vec<_> = values
616 .iter()
617 .filter_map(|value| value.shape.as_ref())
618 .collect();
619 if shapes.is_empty() || dim == 0 {
620 return None;
621 }
622 let rank = shapes
623 .iter()
624 .map(|shape| shape.0.len())
625 .max()
626 .unwrap_or(dim);
627 let axis = dim - 1;
628 if axis >= rank {
629 return None;
630 }
631 let mut out = vec![Some(1); rank];
632 for (idx, out_dim) in out.iter_mut().enumerate().take(rank) {
633 if idx == axis {
634 *out_dim = shapes
635 .iter()
636 .map(|shape| shape.0.get(idx).copied().flatten())
637 .try_fold(0usize, |sum, dim| dim.map(|dim| sum + dim));
638 continue;
639 }
640 let mut expected = None;
641 for shape in &shapes {
642 let dim = shape.0.get(idx).copied().flatten().or(Some(1));
643 if let (Some(expected), Some(dim)) = (expected, dim) {
644 if expected != dim {
645 self.warn(
646 "lint.shape.concat",
647 "concatenation dimensions do not agree",
648 span,
649 );
650 }
651 }
652 expected = expected.or(dim);
653 }
654 *out_dim = expected;
655 }
656 Some(Shape(out))
657 }
658
659 fn check_logical_index(&mut self, span: Span, base: Option<&Shape>, idx: Option<&Shape>) {
660 if let (Some(base), Some(idx)) = (base, idx) {
661 if element_count(base)
662 .zip(element_count(idx))
663 .is_some_and(|(base, idx)| base != idx)
664 {
665 self.warn(
666 "lint.shape.logical_index",
667 "logical index shape does not match indexed value",
668 span,
669 );
670 }
671 }
672 }
673
674 fn warn(&mut self, code: &'static str, message: &'static str, span: Span) {
675 self.diagnostics.push(
676 HirDiagnostic::new(code, HirDiagnosticSeverity::Warning, message, span)
677 .with_category("shape"),
678 );
679 }
680}
681
/// One requested dimension in a size argument list (as used by reshape/repmat).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Dim {
    /// A concrete, non-negative extent.
    Known(usize),
    /// The `-1` placeholder: the extent is to be inferred from the element count.
    Infer,
    /// The extent is not statically known or is not a valid non-negative integer.
    Unknown,
}
688
689impl Dim {
690 fn as_shape_dim(self) -> Option<usize> {
691 match self {
692 Dim::Known(value) => Some(value),
693 Dim::Infer | Dim::Unknown => None,
694 }
695 }
696}
697
/// Converts a numeric constant to a non-negative integer, if it is one.
///
/// A value counts as an integer when it lies within `1e-9` of the nearest
/// whole number, on either side (so both `2.9999999999` and `3.0000000001`
/// map to 3). Returns `None` for non-finite values, for values that round to
/// a negative number, and for values too large to be represented as `usize`
/// (instead of silently saturating).
fn number_to_int(value: f64) -> Option<usize> {
    const TOLERANCE: f64 = 1e-9;

    if !value.is_finite() {
        return None;
    }
    let nearest = value.round();
    if (value - nearest).abs() > TOLERANCE {
        return None;
    }
    // `usize::MAX as f64` rounds up to 2^64 on 64-bit targets, so `>=` rejects
    // every value the cast below could not represent exactly.
    if nearest < 0.0 || nearest >= usize::MAX as f64 {
        return None;
    }
    Some(nearest as usize)
}
705
706fn is_logical_operator(op: &OperatorKind) -> bool {
707 matches!(
708 op,
709 OperatorKind::Not
710 | OperatorKind::Greater
711 | OperatorKind::GreaterEqual
712 | OperatorKind::Less
713 | OperatorKind::LessEqual
714 | OperatorKind::Equal
715 | OperatorKind::NotEqual
716 | OperatorKind::ShortCircuitAnd
717 | OperatorKind::ShortCircuitOr
718 | OperatorKind::ElementwiseAnd
719 | OperatorKind::ElementwiseOr
720 )
721}
722
723fn mir_parse_dim(value: &MirShapeValue) -> Dim {
724 match value.number {
725 Some(-1.0) => Dim::Infer,
726 Some(value) => number_to_int(value).map(Dim::Known).unwrap_or(Dim::Unknown),
727 None => Dim::Unknown,
728 }
729}
730
731fn mir_parse_dims(args: &[MirShapeValue]) -> Vec<Dim> {
732 if args.len() == 1 {
733 if let Some(values) = &args[0].int_vector {
734 return values.iter().copied().map(Dim::Known).collect();
735 }
736 }
737 args.iter().map(mir_parse_dim).collect()
738}
739
740fn sized_constructor_shape(args: &[MirShapeValue]) -> Option<Shape> {
741 let dims: Vec<_> = args
742 .iter()
743 .filter_map(|value| value.number.and_then(number_to_int))
744 .map(Some)
745 .collect();
746 match dims.as_slice() {
747 [] => None,
748 [dim] => Some(Shape(vec![*dim, *dim])),
749 _ => Some(Shape(dims)),
750 }
751}
752
753fn shape_from_fact(shape: &runmat_hir::ShapeFact) -> Option<Shape> {
754 match shape {
755 runmat_hir::ShapeFact::Scalar => Some(Shape(vec![Some(1), Some(1)])),
756 runmat_hir::ShapeFact::Shaped { dims } => Some(Shape(
757 dims.iter()
758 .map(|dim| match dim {
759 runmat_hir::DimFact::Known(value) => Some(*value),
760 runmat_hir::DimFact::Symbolic(_) | runmat_hir::DimFact::Unknown => None,
761 })
762 .collect(),
763 )),
764 runmat_hir::ShapeFact::Ranked { .. }
765 | runmat_hir::ShapeFact::Unknown
766 | runmat_hir::ShapeFact::Unreachable => None,
767 }
768}
769
/// Number of elements in the range `start:step:end`.
///
/// Returns `None` when the step is zero, when any input is not finite, or
/// when the element count does not fit in `usize`. Returns `Some(0)` when the
/// step points away from `end`. Otherwise the count is
/// `floor((end - start) / step) + 1`.
fn range_width(start: f64, step: f64, end: f64) -> Option<usize> {
    if step == 0.0 || !start.is_finite() || !step.is_finite() || !end.is_finite() {
        return None;
    }
    let span = end - start;
    if (span > 0.0 && step < 0.0) || (span < 0.0 && step > 0.0) {
        return Some(0);
    }
    let steps = (span / step).floor().abs();
    // Written as a negated `<` so that NaN and infinity are rejected too. Past this
    // point the cast is exact and the `+ 1` below cannot overflow.
    if !(steps < usize::MAX as f64) {
        return None;
    }
    Some(steps as usize + 1)
}
780
781fn matrix_dims(shape: &Shape) -> Option<(Option<usize>, Option<usize>)> {
782 Some((*shape.0.first()?, *shape.0.get(1)?))
783}
784
785fn element_count(shape: &Shape) -> Option<usize> {
786 shape
787 .0
788 .iter()
789 .try_fold(1usize, |acc, dim| dim.map(|dim| acc * dim))
790}
791
792fn vector_len(shape: &Shape) -> Option<usize> {
793 let count = element_count(shape)?;
794 if shape.0.len() == 1
795 || (shape.0.len() == 2 && (shape.0[0] == Some(1) || shape.0[1] == Some(1)))
796 {
797 Some(count)
798 } else {
799 None
800 }
801}
802
803fn infer_mir_index_shape(
804 base: Option<&Shape>,
805 indexing: &runmat_mir::MirIndexing,
806 component_values: &[Option<MirShapeValue>],
807) -> Option<Shape> {
808 if indexing.kind != IndexKind::Paren {
809 return None;
810 }
811 if indexing.components.is_empty() {
812 return base.cloned();
813 }
814 if indexing.plan == runmat_mir::MirIndexPlan::Scalar {
815 return Some(Shape(vec![Some(1), Some(1)]));
816 }
817 if indexing.components.len() == 1 {
818 return match indexing.components.first()? {
819 runmat_mir::MirIndexComponent::Colon => base.cloned(),
820 runmat_mir::MirIndexComponent::End { .. } => Some(Shape(vec![Some(1), Some(1)])),
821 runmat_mir::MirIndexComponent::Expr(_) => {
822 let value = component_values.first()?.as_ref()?;
823 if value.logical {
824 None
825 } else if value.number.is_some() {
826 Some(Shape(vec![Some(1), Some(1)]))
827 } else {
828 value.shape.clone()
829 }
830 }
831 };
832 }
833
834 let dims = indexing
835 .components
836 .iter()
837 .enumerate()
838 .map(|(idx, component)| match component {
839 runmat_mir::MirIndexComponent::Colon => {
840 base.and_then(|shape| shape.0.get(idx).copied().flatten())
841 }
842 runmat_mir::MirIndexComponent::End { .. } => Some(1),
843 runmat_mir::MirIndexComponent::Expr(_) => component_values
844 .get(idx)
845 .and_then(|value| value.as_ref())
846 .and_then(numeric_index_component_len),
847 })
848 .collect();
849 Some(Shape(dims))
850}
851
852fn numeric_index_component_len(value: &MirShapeValue) -> Option<usize> {
853 if value.logical {
854 return None;
855 }
856 if value.number.is_some() {
857 return Some(1);
858 }
859 value.shape.as_ref().and_then(vector_len)
860}
861
862fn broadcast_compatible(left: &Shape, right: &Shape) -> bool {
863 let len = left.0.len().max(right.0.len());
864 (0..len).all(|idx| {
865 let l = left.0.iter().rev().nth(idx).copied().flatten().unwrap_or(1);
866 let r = right
867 .0
868 .iter()
869 .rev()
870 .nth(idx)
871 .copied()
872 .flatten()
873 .unwrap_or(1);
874 l == r || l == 1 || r == 1
875 })
876}
877
878fn incompatible_element_count(input: Option<&Shape>, dims: &[Dim]) -> bool {
879 let Some(input_count) = input.and_then(element_count) else {
880 return false;
881 };
882 if dims.iter().any(|dim| matches!(dim, Dim::Unknown)) {
883 return true;
884 }
885 let known_product = dims.iter().fold(1usize, |acc, dim| match dim {
886 Dim::Known(value) => acc * value,
887 Dim::Infer | Dim::Unknown => acc,
888 });
889 if dims.iter().any(|dim| matches!(dim, Dim::Infer)) {
890 known_product == 0 || input_count % known_product != 0
891 } else {
892 known_product != input_count
893 }
894}