1use crate::bytecode::EndExpr;
2use crate::indexing::selectors::{index_scalar_from_value, SliceSelector};
3use crate::interpreter::errors::mex;
4use runmat_builtins::Value;
5use runmat_runtime::{builtins::common::shape::is_scalar_shape, RuntimeError};
6use std::future::Future;
7
/// Result alias for the index-planning routines in this module; the error side
/// is always a [`RuntimeError`].
pub type VmResult<T> = Result<T, RuntimeError>;
9
/// Structural facts about an index plan that are derived once from its linear
/// offsets so consumers do not have to re-scan them.
///
/// `PartialEq`/`Eq` are derived so two sets of properties can be compared as a
/// whole instead of field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct IndexPlanProperties {
    /// `true` when the plan selects no elements.
    pub is_empty: bool,
    /// `true` when the plan selects exactly one element.
    pub is_scalar: bool,
    /// Zero-based row number when a two-subscript plan selects exactly one
    /// complete row of the base array, in column order.
    pub full_row: Option<usize>,
    /// Zero-based column number when a two-subscript plan selects exactly one
    /// complete column of the base array, in row order.
    pub full_column: Option<usize>,
}
17
/// A resolved indexing operation: which elements of a base array are selected
/// and what shape the result takes.
#[derive(Debug, Clone)]
pub struct IndexPlan {
    /// Zero-based linear offsets of the selected elements into the base array.
    pub indices: Vec<u32>,
    /// Shape of the indexing result.
    pub output_shape: Vec<usize>,
    /// Selection lengths as recorded by the builder that produced the plan.
    pub selection_lengths: Vec<usize>,
    /// Number of subscripts the indexing expression used.
    pub dims: usize,
    /// Shape of the array being indexed, as supplied by the caller.
    pub base_shape: Vec<usize>,
    /// Facts derived from `indices`, `dims` and `base_shape` when the plan is
    /// created through [`IndexPlan::new`].
    pub properties: IndexPlanProperties,
}
27
28impl IndexPlan {
29 pub fn new(
30 indices: Vec<u32>,
31 output_shape: Vec<usize>,
32 selection_lengths: Vec<usize>,
33 dims: usize,
34 base_shape: Vec<usize>,
35 ) -> Self {
36 let properties = derive_plan_properties(&indices, dims, &base_shape);
37 Self {
38 indices,
39 output_shape,
40 selection_lengths,
41 dims,
42 base_shape,
43 properties,
44 }
45 }
46}
47
48fn derive_plan_properties(
49 indices: &[u32],
50 dims: usize,
51 base_shape: &[usize],
52) -> IndexPlanProperties {
53 let is_empty = indices.is_empty();
54 let is_scalar = !is_empty && indices.len() == 1;
55 let mut properties = IndexPlanProperties {
56 is_empty,
57 is_scalar,
58 full_row: None,
59 full_column: None,
60 };
61 if dims != 2 || is_empty {
62 return properties;
63 }
64 let rows = base_shape.first().copied().unwrap_or(1);
65 let cols = base_shape.get(1).copied().unwrap_or(1);
66 if indices.len() == rows {
67 let first = indices[0] as usize;
68 if first.is_multiple_of(rows) {
69 let col = first / rows;
70 if col < cols
71 && indices
72 .iter()
73 .enumerate()
74 .all(|(r, &idx)| idx as usize == col * rows + r)
75 {
76 properties.full_column = Some(col);
77 }
78 }
79 }
80 if indices.len() == cols {
81 let first = indices[0] as usize;
82 let row = first % rows;
83 if row < rows
84 && indices
85 .iter()
86 .enumerate()
87 .all(|(c, &idx)| idx as usize == row + c * rows)
88 {
89 properties.full_row = Some(row);
90 }
91 }
92 properties
93}
94
/// Calls `f` once for every tuple in the Cartesian product of `lists`.
///
/// Tuples are produced with the first list varying fastest, and each tuple is
/// passed as a slice with one entry per list. Nothing is produced when `lists`
/// is empty or when any single list is empty (an empty factor makes the whole
/// product empty).
fn cartesian_product<F: FnMut(&[usize])>(lists: &[Vec<usize>], mut f: F) {
    let dims = lists.len();
    if dims == 0 || lists.iter().any(|list| list.is_empty()) {
        return;
    }

    // `cursor[d]` is the position inside `lists[d]`; `current` mirrors the
    // values at those positions and is reused for every tuple.
    let mut cursor = vec![0usize; dims];
    let mut current: Vec<usize> = lists.iter().map(|list| list[0]).collect();

    loop {
        f(&current);

        // Advance like an odometer: bump the fastest dimension, carrying into
        // the next one whenever a dimension wraps around.
        let mut d = 0usize;
        loop {
            if d == dims {
                // Every dimension wrapped: the product is exhausted.
                return;
            }
            cursor[d] += 1;
            if cursor[d] < lists[d].len() {
                current[d] = lists[d][cursor[d]];
                break;
            }
            cursor[d] = 0;
            current[d] = lists[d][0];
            d += 1;
        }
    }
}
118
119pub fn total_len_from_shape(shape: &[usize]) -> usize {
120 if is_scalar_shape(shape) {
121 1
122 } else {
123 shape.iter().copied().product()
124 }
125}
126
/// Computes the result shape for a multi-subscript selection.
///
/// Starting from the per-dimension `selection_lengths`, trailing dimensions
/// are dropped while more than two remain and the last one has length 1 and
/// was addressed by a scalar subscript (`scalar_mask[d]`; a missing mask entry
/// counts as "not scalar"). An empty input yields `[1, 1]`.
fn matlab_squeezed_shape(selection_lengths: &[usize], scalar_mask: &[bool]) -> Vec<usize> {
    let droppable =
        |d: usize| selection_lengths[d] == 1 && scalar_mask.get(d).copied().unwrap_or(false);

    // Number of leading dimensions that survive the trailing-singleton trim.
    let mut kept = selection_lengths.len();
    while kept > 2 && droppable(kept - 1) {
        kept -= 1;
    }

    if kept == 0 {
        return vec![1, 1];
    }
    selection_lengths[..kept].to_vec()
}
148
149pub fn build_index_plan(
150 selectors: &[SliceSelector],
151 dims: usize,
152 base_shape: &[usize],
153) -> VmResult<IndexPlan> {
154 let total_len = total_len_from_shape(base_shape);
155 if dims == 1 {
156 let list = selectors
157 .first()
158 .cloned()
159 .unwrap_or(SliceSelector::Indices(Vec::new()));
160 let indices = match &list {
161 SliceSelector::Colon => (1..=total_len).collect::<Vec<usize>>(),
162 SliceSelector::Scalar(i) => vec![*i],
163 SliceSelector::Indices(v) => v.clone(),
164 SliceSelector::LinearIndices { values, .. } => values.clone(),
165 };
166 if indices.iter().any(|&i| i == 0 || i > total_len) {
167 return Err(mex("IndexOutOfBounds", "Index out of bounds"));
168 }
169 let zero_based: Vec<u32> = indices.iter().map(|&i| (i - 1) as u32).collect();
170 let count = zero_based.len();
171 let shape = match list {
172 SliceSelector::LinearIndices { output_shape, .. } => output_shape,
173 _ if count <= 1 => vec![1, 1],
174 _ => vec![count, 1],
175 };
176 return Ok(IndexPlan::new(
177 zero_based,
178 shape,
179 vec![count],
180 dims,
181 base_shape.to_vec(),
182 ));
183 }
184
185 let mut selection_lengths = Vec::with_capacity(dims);
186 let mut per_dim_lists: Vec<Vec<usize>> = Vec::with_capacity(dims);
187 let mut scalar_mask: Vec<bool> = Vec::with_capacity(dims);
188 for (d, sel) in selectors.iter().enumerate().take(dims) {
189 let dim_len = base_shape.get(d).copied().unwrap_or(1);
190 let idxs = match sel {
191 SliceSelector::Colon => (1..=dim_len).collect::<Vec<usize>>(),
192 SliceSelector::Scalar(i) => vec![*i],
193 SliceSelector::Indices(v) => v.clone(),
194 SliceSelector::LinearIndices { values: v, .. } => v.clone(),
195 };
196 if idxs.iter().any(|&i| i == 0 || i > dim_len) {
197 return Err(mex("IndexOutOfBounds", "Index out of bounds"));
198 }
199 selection_lengths.push(idxs.len());
200 per_dim_lists.push(idxs);
201 scalar_mask.push(matches!(sel, SliceSelector::Scalar(_)));
202 }
203
204 let mut out_shape = matlab_squeezed_shape(&selection_lengths, &scalar_mask);
205 if selection_lengths.contains(&0) {
206 let selection_lengths = out_shape.clone();
207 return Ok(IndexPlan::new(
208 Vec::new(),
209 out_shape,
210 selection_lengths,
211 dims,
212 base_shape.to_vec(),
213 ));
214 }
215
216 let mut base_norm = base_shape.to_vec();
217 if base_norm.len() < dims {
218 base_norm.resize(dims, 1);
219 }
220 let mut strides = vec![1usize; dims];
221 for d in 1..dims {
222 strides[d] = strides[d - 1] * base_norm[d - 1].max(1);
223 }
224
225 let mut indices = Vec::new();
226 cartesian_product(&per_dim_lists, |multi| {
227 let mut lin = 0usize;
228 for d in 0..dims {
229 let idx = multi[d] - 1;
230 lin += idx * strides[d];
231 }
232 indices.push(lin as u32);
233 });
234
235 let total_out: usize = selection_lengths.iter().product();
236 if total_out == 1 {
237 out_shape = vec![1, 1];
238 }
239 let selection_lengths = out_shape.clone();
240 Ok(IndexPlan::new(
241 indices,
242 out_shape,
243 selection_lengths,
244 dims,
245 base_shape.to_vec(),
246 ))
247}
248
/// Per-dimension selector used while building an expression index plan,
/// before ranges have been expanded into explicit index lists.
#[derive(Clone)]
enum ExprSel {
    /// Select every index along the dimension.
    Colon,
    /// A single one-based index.
    Scalar(usize),
    /// An explicit list of one-based indices.
    Indices(Vec<usize>),
    /// An arithmetic range whose end bound is still an expression; it is
    /// resolved against the dimension length when the range is expanded.
    Range {
        /// First value of the range (one-based).
        start: i64,
        /// Increment between consecutive values; may be negative.
        step: i64,
        /// Expression that yields the inclusive end bound.
        end_off: EndExpr,
    },
}
260
/// Borrowed description of an indexing expression, as consumed by
/// `build_expr_index_plan`.
pub struct ExprPlanSpec<'a> {
    /// Number of subscripts in the expression.
    pub dims: usize,
    /// Bit `d` set means subscript `d` is a bare colon.
    pub colon_mask: u32,
    /// Bit `d` set means subscript `d` is a bare `end` (the last index of
    /// that dimension).
    pub end_mask: u32,
    /// Zero-based dimensions whose subscript is a range.
    pub range_dims: &'a [usize],
    /// `(start, step)` values for the ranges, read in the order the range
    /// dimensions are encountered; used when no expression overrides them.
    pub range_params: &'a [(f64, f64)],
    /// Optional expression that replaces a range's start value.
    pub range_start_exprs: &'a [Option<EndExpr>],
    /// Optional expression that replaces a range's step value.
    pub range_step_exprs: &'a [Option<EndExpr>],
    /// End-bound expression for each range, looked up by the range's
    /// position in `range_dims`.
    pub range_end_exprs: &'a [EndExpr],
    /// Index values for the remaining subscripts (neither colon, `end` nor
    /// range), consumed in dimension order.
    pub numeric: &'a [Value],
    /// Shape of the array being indexed.
    pub shape: &'a [usize],
}
273
274pub async fn build_expr_index_plan<ResolveEnd, Fut>(
275 spec: ExprPlanSpec<'_>,
276 mut resolve_end: ResolveEnd,
277) -> Result<IndexPlan, RuntimeError>
278where
279 ResolveEnd: FnMut(usize, &EndExpr) -> Fut,
280 Fut: Future<Output = Result<i64, RuntimeError>>,
281{
282 let rank = spec.shape.len();
283 let full_shape: Vec<usize> = if spec.dims == 1 {
284 vec![total_len_from_shape(spec.shape)]
285 } else if rank < spec.dims {
286 let mut s = spec.shape.to_vec();
287 s.resize(spec.dims, 1);
288 s
289 } else {
290 spec.shape.to_vec()
291 };
292
293 let mut selectors: Vec<ExprSel> = Vec::with_capacity(spec.dims);
294 let mut num_iter = 0usize;
295 let mut rp_iter = 0usize;
296 for d in 0..spec.dims {
297 let is_colon = (spec.colon_mask & (1u32 << d)) != 0;
298 let is_end = (spec.end_mask & (1u32 << d)) != 0;
299 if is_colon {
300 selectors.push(ExprSel::Colon);
301 } else if is_end {
302 selectors.push(ExprSel::Scalar(*full_shape.get(d).unwrap_or(&1)));
303 } else if let Some(pos) = spec.range_dims.iter().position(|&rd| rd == d) {
304 let (raw_st, raw_sp) = spec.range_params[rp_iter];
305 let dim_len = *full_shape.get(d).unwrap_or(&1);
306 let st = if let Some(expr) = &spec.range_start_exprs[rp_iter] {
307 resolve_end(dim_len, expr).await? as f64
308 } else {
309 raw_st
310 };
311 let sp = if let Some(expr) = &spec.range_step_exprs[rp_iter] {
312 resolve_end(dim_len, expr).await? as f64
313 } else {
314 raw_sp
315 };
316 rp_iter += 1;
317 let off = spec.range_end_exprs[pos].clone();
318 selectors.push(ExprSel::Range {
319 start: st as i64,
320 step: if sp >= 0.0 {
321 sp as i64
322 } else {
323 -(sp.abs() as i64)
324 },
325 end_off: off,
326 });
327 } else {
328 let v = spec
329 .numeric
330 .get(num_iter)
331 .ok_or_else(|| mex("MissingNumericIndex", "missing numeric index"))?;
332 num_iter += 1;
333 if let Some(idx) = index_scalar_from_value(v).await? {
334 if idx < 1 {
335 return Err(mex("IndexOutOfBounds", "Index out of bounds"));
336 }
337 selectors.push(ExprSel::Scalar(idx as usize));
338 } else {
339 match v {
340 Value::Tensor(idx_t) => {
341 let dim_len = *full_shape.get(d).unwrap_or(&1);
342 let len = idx_t.shape.iter().product::<usize>();
343 if len == dim_len {
344 let mut vv = Vec::new();
345 for (i, &val) in idx_t.data.iter().enumerate() {
346 if val != 0.0 {
347 vv.push(i + 1);
348 }
349 }
350 selectors.push(ExprSel::Indices(vv));
351 } else {
352 let mut vv = Vec::with_capacity(len);
353 for &val in &idx_t.data {
354 let idx = val as isize;
355 if idx < 1 {
356 return Err(mex("IndexOutOfBounds", "Index out of bounds"));
357 }
358 vv.push(idx as usize);
359 }
360 selectors.push(ExprSel::Indices(vv));
361 }
362 }
363 _ => return Err(mex("UnsupportedIndexType", "Unsupported index type")),
364 }
365 }
366 }
367 }
368
369 let mut per_dim_indices: Vec<Vec<usize>> = Vec::with_capacity(spec.dims);
370 let mut selection_lengths: Vec<usize> = Vec::with_capacity(spec.dims);
371 let mut scalar_mask: Vec<bool> = Vec::with_capacity(spec.dims);
372 for (d, sel) in selectors.iter().enumerate().take(spec.dims) {
373 let dim_len = full_shape[d] as i64;
374 let idxs: Vec<usize> = match sel {
375 ExprSel::Colon => (1..=full_shape[d]).collect(),
376 ExprSel::Scalar(i) => vec![*i],
377 ExprSel::Indices(v) => v.clone(),
378 ExprSel::Range {
379 start,
380 step,
381 end_off,
382 } => {
383 let mut v = Vec::new();
384 let mut cur = *start;
385 let stp = *step;
386 let end_i = resolve_end(dim_len as usize, end_off).await?;
387 if stp == 0 {
388 return Err(mex("IndexStepZero", "Index step cannot be zero"));
389 }
390 if stp > 0 {
391 while cur <= end_i {
392 if cur < 1 || cur > dim_len {
393 break;
394 }
395 v.push(cur as usize);
396 cur += stp;
397 }
398 } else {
399 while cur >= end_i {
400 if cur < 1 || cur > dim_len {
401 break;
402 }
403 v.push(cur as usize);
404 cur += stp;
405 }
406 }
407 v
408 }
409 };
410 if idxs.iter().any(|&i| i == 0 || i > full_shape[d]) {
411 return Err(mex("IndexOutOfBounds", "Index out of bounds"));
412 }
413 selection_lengths.push(idxs.len());
414 per_dim_indices.push(idxs);
415 scalar_mask.push(matches!(sel, ExprSel::Scalar(_)));
416 }
417
418 let mut strides: Vec<usize> = vec![0; spec.dims];
419 let mut acc = 1usize;
420 for (d, stride) in strides.iter_mut().enumerate().take(spec.dims) {
421 *stride = acc;
422 acc *= full_shape[d];
423 }
424 let total_out: usize = per_dim_indices.iter().map(|v| v.len()).product();
425 if total_out == 0 {
426 let output_shape = if spec.dims == 1 {
427 vec![1, 0]
428 } else {
429 let mut dims_out: Vec<(usize, usize, bool)> = selection_lengths
430 .iter()
431 .enumerate()
432 .map(|(d, &len)| (d, len, scalar_mask.get(d).copied().unwrap_or(false)))
433 .collect();
434 while dims_out.len() > 2
435 && dims_out
436 .last()
437 .map(|&(_, len, is_scalar)| len == 1 && is_scalar)
438 .unwrap_or(false)
439 {
440 dims_out.pop();
441 }
442 if dims_out.is_empty() {
443 vec![1, 1]
444 } else if dims_out.len() == 1 {
445 let (dim, len, _) = dims_out[0];
446 if dim == 1 {
447 vec![1, len]
448 } else {
449 vec![len, 1]
450 }
451 } else {
452 dims_out.into_iter().map(|(_, len, _)| len).collect()
453 }
454 };
455 return Ok(IndexPlan::new(
456 Vec::new(),
457 output_shape,
458 selection_lengths,
459 spec.dims,
460 spec.shape.to_vec(),
461 ));
462 }
463
464 let mut indices: Vec<u32> = Vec::with_capacity(total_out);
465 let mut idx = vec![0usize; spec.dims];
466 loop {
467 let mut lin = 0usize;
468 for d in 0..spec.dims {
469 let i0 = per_dim_indices[d][idx[d]] - 1;
470 lin += i0 * strides[d];
471 }
472 indices.push(lin as u32);
473 let mut d = 0usize;
474 while d < spec.dims {
475 idx[d] += 1;
476 if idx[d] < per_dim_indices[d].len() {
477 break;
478 }
479 idx[d] = 0;
480 d += 1;
481 }
482 if d == spec.dims {
483 break;
484 }
485 }
486
487 let output_shape = if spec.dims == 1 {
488 if total_out <= 1 {
489 vec![1, 1]
490 } else {
491 vec![1, total_out]
492 }
493 } else {
494 let mut dims_out: Vec<(usize, usize, bool)> = selection_lengths
495 .iter()
496 .enumerate()
497 .map(|(d, &len)| (d, len, scalar_mask.get(d).copied().unwrap_or(false)))
498 .collect();
499 while dims_out.len() > 2
500 && dims_out
501 .last()
502 .map(|&(_, len, is_scalar)| len == 1 && is_scalar)
503 .unwrap_or(false)
504 {
505 dims_out.pop();
506 }
507 if dims_out.is_empty() {
508 vec![1, 1]
509 } else if dims_out.len() == 1 {
510 let (dim, len, _) = dims_out[0];
511 if dim == 1 {
512 vec![1, len]
513 } else {
514 vec![len, 1]
515 }
516 } else {
517 dims_out.into_iter().map(|(_, len, _)| len).collect()
518 }
519 };
520 Ok(IndexPlan::new(
521 indices,
522 output_shape,
523 selection_lengths,
524 spec.dims,
525 spec.shape.to_vec(),
526 ))
527}
528
// Cross-checks: the selector-based builder and the expression-based builder
// must agree when given equivalent subscripts.
#[cfg(test)]
mod tests {
    use super::{build_expr_index_plan, build_index_plan, ExprPlanSpec};
    use crate::bytecode::EndExpr;
    use crate::indexing::selectors::build_slice_selectors;
    use runmat_builtins::{Tensor, Value};

    /// Linear indexing of a 1x10 array: the explicit index list `[2 4 6 8]`
    /// must yield the same plan as the range `2:2:end-1`.
    #[test]
    fn plain_and_expr_linear_range_plans_match() {
        futures::executor::block_on(async {
            let shape = vec![1, 10];
            let numeric = vec![Value::Tensor(
                Tensor::new(vec![2.0, 4.0, 6.0, 8.0], vec![1, 4]).unwrap(),
            )];
            let plain_selectors = build_slice_selectors(1, 0, 0, &numeric, &shape)
                .await
                .unwrap();
            let plain = build_index_plan(&plain_selectors, 1, &shape).unwrap();
            let expr = build_expr_index_plan(
                ExprPlanSpec {
                    dims: 1,
                    colon_mask: 0,
                    end_mask: 0,
                    range_dims: &[0],
                    range_params: &[(2.0, 2.0)],
                    range_start_exprs: &[None],
                    range_step_exprs: &[None],
                    range_end_exprs: &[EndExpr::Sub(
                        Box::new(EndExpr::End),
                        Box::new(EndExpr::Const(1.0)),
                    )],
                    numeric: &[],
                    shape: &shape,
                },
                // Tiny evaluator covering only the expression forms this test
                // uses: `end`, a constant, and one level of subtraction.
                |dim_len, expr| {
                    // Cloned so the returned future does not borrow `expr`.
                    let expr = expr.clone();
                    async move {
                        Ok(match &expr {
                            EndExpr::End => dim_len as i64,
                            EndExpr::Const(value) => *value as i64,
                            EndExpr::Sub(lhs, rhs) => {
                                let lhs_val = match lhs.as_ref() {
                                    EndExpr::End => dim_len as i64,
                                    EndExpr::Const(value) => *value as i64,
                                    other => panic!("unsupported lhs expr: {other:?}"),
                                };
                                let rhs_val = match rhs.as_ref() {
                                    EndExpr::Const(value) => *value as i64,
                                    other => panic!("unsupported rhs expr: {other:?}"),
                                };
                                lhs_val - rhs_val
                            }
                            other => panic!("unsupported expr: {other:?}"),
                        })
                    }
                },
            )
            .await
            .unwrap();
            assert_eq!(plain.indices, expr.indices);
            assert_eq!(plain.output_shape, expr.output_shape);
            assert_eq!(plain.selection_lengths, expr.selection_lengths);
            assert_eq!(plain.properties.full_row, expr.properties.full_row);
            assert_eq!(plain.properties.full_column, expr.properties.full_column);
        })
    }

    /// Selecting `(:, 3)` of a 3x4 array: both builders must report the
    /// selection as the full column with zero-based index 2.
    #[test]
    fn plain_and_expr_column_plans_match_properties() {
        futures::executor::block_on(async {
            let shape = vec![3, 4];
            let numeric = vec![Value::Num(3.0)];
            let plain_selectors = build_slice_selectors(2, 1, 0, &numeric, &shape)
                .await
                .unwrap();
            let plain = build_index_plan(&plain_selectors, 2, &shape).unwrap();
            let expr = build_expr_index_plan(
                ExprPlanSpec {
                    dims: 2,
                    // Bit 0 set: the first subscript is a colon.
                    colon_mask: 1,
                    end_mask: 0,
                    range_dims: &[],
                    range_params: &[],
                    range_start_exprs: &[],
                    range_step_exprs: &[],
                    range_end_exprs: &[],
                    numeric: &numeric,
                    shape: &shape,
                },
                // No ranges are present, so the resolver must never run.
                |_dim_len, _expr| async move { unreachable!() },
            )
            .await
            .unwrap();
            assert_eq!(plain.indices, expr.indices);
            assert_eq!(plain.properties.full_column, Some(2));
            assert_eq!(plain.properties.full_column, expr.properties.full_column);
            assert_eq!(plain.properties.full_row, expr.properties.full_row);
        })
    }
}