Skip to main content

runmat_vm/indexing/
plan.rs

1use crate::bytecode::EndExpr;
2use crate::indexing::selectors::{index_scalar_from_value, SliceSelector};
3use crate::interpreter::errors::mex;
4use runmat_builtins::Value;
5use runmat_runtime::{builtins::common::shape::is_scalar_shape, RuntimeError};
6use std::future::Future;
7
8pub type VmResult<T> = Result<T, RuntimeError>;
9
/// Cheap facts about an [`IndexPlan`], derived once from its zero-based indices.
#[derive(Default, Clone, Debug)]
pub struct IndexPlanProperties {
    /// The plan selects no elements at all.
    pub is_empty: bool,
    /// The plan selects exactly one element.
    pub is_scalar: bool,
    /// Two-subscript plans only: `Some(r)` when the indices are exactly the zero-based row `r`
    /// of the base, visited one column after another.
    pub full_row: Option<usize>,
    /// Two-subscript plans only: `Some(c)` when the indices are exactly the zero-based column
    /// `c` of the base, from the first row to the last.
    pub full_column: Option<usize>,
}
17
18#[derive(Debug, Clone)]
19pub struct IndexPlan {
20    pub indices: Vec<u32>,
21    pub output_shape: Vec<usize>,
22    pub selection_lengths: Vec<usize>,
23    pub dims: usize,
24    pub base_shape: Vec<usize>,
25    pub properties: IndexPlanProperties,
26}
27
28impl IndexPlan {
29    pub fn new(
30        indices: Vec<u32>,
31        output_shape: Vec<usize>,
32        selection_lengths: Vec<usize>,
33        dims: usize,
34        base_shape: Vec<usize>,
35    ) -> Self {
36        let properties = derive_plan_properties(&indices, dims, &base_shape);
37        Self {
38            indices,
39            output_shape,
40            selection_lengths,
41            dims,
42            base_shape,
43            properties,
44        }
45    }
46}
47
48fn derive_plan_properties(
49    indices: &[u32],
50    dims: usize,
51    base_shape: &[usize],
52) -> IndexPlanProperties {
53    let is_empty = indices.is_empty();
54    let is_scalar = !is_empty && indices.len() == 1;
55    let mut properties = IndexPlanProperties {
56        is_empty,
57        is_scalar,
58        full_row: None,
59        full_column: None,
60    };
61    if dims != 2 || is_empty {
62        return properties;
63    }
64    let rows = base_shape.first().copied().unwrap_or(1);
65    let cols = base_shape.get(1).copied().unwrap_or(1);
66    if indices.len() == rows {
67        let first = indices[0] as usize;
68        if first.is_multiple_of(rows) {
69            let col = first / rows;
70            if col < cols
71                && indices
72                    .iter()
73                    .enumerate()
74                    .all(|(r, &idx)| idx as usize == col * rows + r)
75            {
76                properties.full_column = Some(col);
77            }
78        }
79    }
80    if indices.len() == cols {
81        let first = indices[0] as usize;
82        let row = first % rows;
83        if row < rows
84            && indices
85                .iter()
86                .enumerate()
87                .all(|(c, &idx)| idx as usize == row + c * rows)
88        {
89            properties.full_row = Some(row);
90        }
91    }
92    properties
93}
94
/// Calls `f` once for every combination drawn from `lists`, varying `lists[0]` fastest
/// (column-major order).
///
/// `f` receives a slice whose `d`-th entry is taken from `lists[d]`. Nothing is produced when
/// `lists` is empty or when any inner list is empty, because the product is then empty. (The
/// inner-list case previously panicked on an out-of-bounds index.)
fn cartesian_product<F: FnMut(&[usize])>(lists: &[Vec<usize>], mut f: F) {
    if lists.is_empty() || lists.iter().any(Vec::is_empty) {
        return;
    }
    let mut positions = vec![0usize; lists.len()];
    // One buffer reused for every combination instead of a fresh allocation per call.
    let mut current: Vec<usize> = lists.iter().map(|list| list[0]).collect();
    loop {
        f(&current);
        // Odometer increment: advance dimension 0 and carry into later dimensions on wrap-around.
        let mut d = 0usize;
        loop {
            if d == lists.len() {
                return;
            }
            positions[d] += 1;
            if positions[d] < lists[d].len() {
                current[d] = lists[d][positions[d]];
                break;
            }
            positions[d] = 0;
            current[d] = lists[d][0];
            d += 1;
        }
    }
}
118
119pub fn total_len_from_shape(shape: &[usize]) -> usize {
120    if is_scalar_shape(shape) {
121        1
122    } else {
123        shape.iter().copied().product()
124    }
125}
126
/// Output shape for a multi-subscript selection.
///
/// Trailing dimensions are dropped while more than two remain and the last one both has length 1
/// and was selected by a scalar subscript (`scalar_mask`; missing entries count as non-scalar).
/// An empty input yields `[1, 1]`.
fn matlab_squeezed_shape(selection_lengths: &[usize], scalar_mask: &[bool]) -> Vec<usize> {
    let is_droppable = |dim: usize| {
        selection_lengths[dim] == 1 && scalar_mask.get(dim).copied().unwrap_or(false)
    };
    let mut keep = selection_lengths.len();
    while keep > 2 && is_droppable(keep - 1) {
        keep -= 1;
    }
    if keep == 0 {
        return vec![1, 1];
    }
    selection_lengths[..keep].to_vec()
}
148
149pub fn build_index_plan(
150    selectors: &[SliceSelector],
151    dims: usize,
152    base_shape: &[usize],
153) -> VmResult<IndexPlan> {
154    let total_len = total_len_from_shape(base_shape);
155    if dims == 1 {
156        let list = selectors
157            .first()
158            .cloned()
159            .unwrap_or(SliceSelector::Indices(Vec::new()));
160        let indices = match &list {
161            SliceSelector::Colon => (1..=total_len).collect::<Vec<usize>>(),
162            SliceSelector::Scalar(i) => vec![*i],
163            SliceSelector::Indices(v) => v.clone(),
164            SliceSelector::LinearIndices { values, .. } => values.clone(),
165        };
166        if indices.iter().any(|&i| i == 0 || i > total_len) {
167            return Err(mex("IndexOutOfBounds", "Index out of bounds"));
168        }
169        let zero_based: Vec<u32> = indices.iter().map(|&i| (i - 1) as u32).collect();
170        let count = zero_based.len();
171        let shape = match list {
172            SliceSelector::LinearIndices { output_shape, .. } => output_shape,
173            _ if count <= 1 => vec![1, 1],
174            _ => vec![count, 1],
175        };
176        return Ok(IndexPlan::new(
177            zero_based,
178            shape,
179            vec![count],
180            dims,
181            base_shape.to_vec(),
182        ));
183    }
184
185    let mut selection_lengths = Vec::with_capacity(dims);
186    let mut per_dim_lists: Vec<Vec<usize>> = Vec::with_capacity(dims);
187    let mut scalar_mask: Vec<bool> = Vec::with_capacity(dims);
188    for (d, sel) in selectors.iter().enumerate().take(dims) {
189        let dim_len = base_shape.get(d).copied().unwrap_or(1);
190        let idxs = match sel {
191            SliceSelector::Colon => (1..=dim_len).collect::<Vec<usize>>(),
192            SliceSelector::Scalar(i) => vec![*i],
193            SliceSelector::Indices(v) => v.clone(),
194            SliceSelector::LinearIndices { values: v, .. } => v.clone(),
195        };
196        if idxs.iter().any(|&i| i == 0 || i > dim_len) {
197            return Err(mex("IndexOutOfBounds", "Index out of bounds"));
198        }
199        selection_lengths.push(idxs.len());
200        per_dim_lists.push(idxs);
201        scalar_mask.push(matches!(sel, SliceSelector::Scalar(_)));
202    }
203
204    let mut out_shape = matlab_squeezed_shape(&selection_lengths, &scalar_mask);
205    if selection_lengths.contains(&0) {
206        let selection_lengths = out_shape.clone();
207        return Ok(IndexPlan::new(
208            Vec::new(),
209            out_shape,
210            selection_lengths,
211            dims,
212            base_shape.to_vec(),
213        ));
214    }
215
216    let mut base_norm = base_shape.to_vec();
217    if base_norm.len() < dims {
218        base_norm.resize(dims, 1);
219    }
220    let mut strides = vec![1usize; dims];
221    for d in 1..dims {
222        strides[d] = strides[d - 1] * base_norm[d - 1].max(1);
223    }
224
225    let mut indices = Vec::new();
226    cartesian_product(&per_dim_lists, |multi| {
227        let mut lin = 0usize;
228        for d in 0..dims {
229            let idx = multi[d] - 1;
230            lin += idx * strides[d];
231        }
232        indices.push(lin as u32);
233    });
234
235    let total_out: usize = selection_lengths.iter().product();
236    if total_out == 1 {
237        out_shape = vec![1, 1];
238    }
239    let selection_lengths = out_shape.clone();
240    Ok(IndexPlan::new(
241        indices,
242        out_shape,
243        selection_lengths,
244        dims,
245        base_shape.to_vec(),
246    ))
247}
248
249#[derive(Clone)]
250enum ExprSel {
251    Colon,
252    Scalar(usize),
253    Indices(Vec<usize>),
254    Range {
255        start: i64,
256        step: i64,
257        end_off: EndExpr,
258    },
259}
260
261pub struct ExprPlanSpec<'a> {
262    pub dims: usize,
263    pub colon_mask: u32,
264    pub end_mask: u32,
265    pub range_dims: &'a [usize],
266    pub range_params: &'a [(f64, f64)],
267    pub range_start_exprs: &'a [Option<EndExpr>],
268    pub range_step_exprs: &'a [Option<EndExpr>],
269    pub range_end_exprs: &'a [EndExpr],
270    pub numeric: &'a [Value],
271    pub shape: &'a [usize],
272}
273
274pub async fn build_expr_index_plan<ResolveEnd, Fut>(
275    spec: ExprPlanSpec<'_>,
276    mut resolve_end: ResolveEnd,
277) -> Result<IndexPlan, RuntimeError>
278where
279    ResolveEnd: FnMut(usize, &EndExpr) -> Fut,
280    Fut: Future<Output = Result<i64, RuntimeError>>,
281{
282    let rank = spec.shape.len();
283    let full_shape: Vec<usize> = if spec.dims == 1 {
284        vec![total_len_from_shape(spec.shape)]
285    } else if rank < spec.dims {
286        let mut s = spec.shape.to_vec();
287        s.resize(spec.dims, 1);
288        s
289    } else {
290        spec.shape.to_vec()
291    };
292
293    let mut selectors: Vec<ExprSel> = Vec::with_capacity(spec.dims);
294    let mut num_iter = 0usize;
295    let mut rp_iter = 0usize;
296    for d in 0..spec.dims {
297        let is_colon = (spec.colon_mask & (1u32 << d)) != 0;
298        let is_end = (spec.end_mask & (1u32 << d)) != 0;
299        if is_colon {
300            selectors.push(ExprSel::Colon);
301        } else if is_end {
302            selectors.push(ExprSel::Scalar(*full_shape.get(d).unwrap_or(&1)));
303        } else if let Some(pos) = spec.range_dims.iter().position(|&rd| rd == d) {
304            let (raw_st, raw_sp) = spec.range_params[rp_iter];
305            let dim_len = *full_shape.get(d).unwrap_or(&1);
306            let st = if let Some(expr) = &spec.range_start_exprs[rp_iter] {
307                resolve_end(dim_len, expr).await? as f64
308            } else {
309                raw_st
310            };
311            let sp = if let Some(expr) = &spec.range_step_exprs[rp_iter] {
312                resolve_end(dim_len, expr).await? as f64
313            } else {
314                raw_sp
315            };
316            rp_iter += 1;
317            let off = spec.range_end_exprs[pos].clone();
318            selectors.push(ExprSel::Range {
319                start: st as i64,
320                step: if sp >= 0.0 {
321                    sp as i64
322                } else {
323                    -(sp.abs() as i64)
324                },
325                end_off: off,
326            });
327        } else {
328            let v = spec
329                .numeric
330                .get(num_iter)
331                .ok_or_else(|| mex("MissingNumericIndex", "missing numeric index"))?;
332            num_iter += 1;
333            if let Some(idx) = index_scalar_from_value(v).await? {
334                if idx < 1 {
335                    return Err(mex("IndexOutOfBounds", "Index out of bounds"));
336                }
337                selectors.push(ExprSel::Scalar(idx as usize));
338            } else {
339                match v {
340                    Value::Tensor(idx_t) => {
341                        let dim_len = *full_shape.get(d).unwrap_or(&1);
342                        let len = idx_t.shape.iter().product::<usize>();
343                        if len == dim_len {
344                            let mut vv = Vec::new();
345                            for (i, &val) in idx_t.data.iter().enumerate() {
346                                if val != 0.0 {
347                                    vv.push(i + 1);
348                                }
349                            }
350                            selectors.push(ExprSel::Indices(vv));
351                        } else {
352                            let mut vv = Vec::with_capacity(len);
353                            for &val in &idx_t.data {
354                                let idx = val as isize;
355                                if idx < 1 {
356                                    return Err(mex("IndexOutOfBounds", "Index out of bounds"));
357                                }
358                                vv.push(idx as usize);
359                            }
360                            selectors.push(ExprSel::Indices(vv));
361                        }
362                    }
363                    _ => return Err(mex("UnsupportedIndexType", "Unsupported index type")),
364                }
365            }
366        }
367    }
368
369    let mut per_dim_indices: Vec<Vec<usize>> = Vec::with_capacity(spec.dims);
370    let mut selection_lengths: Vec<usize> = Vec::with_capacity(spec.dims);
371    let mut scalar_mask: Vec<bool> = Vec::with_capacity(spec.dims);
372    for (d, sel) in selectors.iter().enumerate().take(spec.dims) {
373        let dim_len = full_shape[d] as i64;
374        let idxs: Vec<usize> = match sel {
375            ExprSel::Colon => (1..=full_shape[d]).collect(),
376            ExprSel::Scalar(i) => vec![*i],
377            ExprSel::Indices(v) => v.clone(),
378            ExprSel::Range {
379                start,
380                step,
381                end_off,
382            } => {
383                let mut v = Vec::new();
384                let mut cur = *start;
385                let stp = *step;
386                let end_i = resolve_end(dim_len as usize, end_off).await?;
387                if stp == 0 {
388                    return Err(mex("IndexStepZero", "Index step cannot be zero"));
389                }
390                if stp > 0 {
391                    while cur <= end_i {
392                        if cur < 1 || cur > dim_len {
393                            break;
394                        }
395                        v.push(cur as usize);
396                        cur += stp;
397                    }
398                } else {
399                    while cur >= end_i {
400                        if cur < 1 || cur > dim_len {
401                            break;
402                        }
403                        v.push(cur as usize);
404                        cur += stp;
405                    }
406                }
407                v
408            }
409        };
410        if idxs.iter().any(|&i| i == 0 || i > full_shape[d]) {
411            return Err(mex("IndexOutOfBounds", "Index out of bounds"));
412        }
413        selection_lengths.push(idxs.len());
414        per_dim_indices.push(idxs);
415        scalar_mask.push(matches!(sel, ExprSel::Scalar(_)));
416    }
417
418    let mut strides: Vec<usize> = vec![0; spec.dims];
419    let mut acc = 1usize;
420    for (d, stride) in strides.iter_mut().enumerate().take(spec.dims) {
421        *stride = acc;
422        acc *= full_shape[d];
423    }
424    let total_out: usize = per_dim_indices.iter().map(|v| v.len()).product();
425    if total_out == 0 {
426        let output_shape = if spec.dims == 1 {
427            vec![1, 0]
428        } else {
429            let mut dims_out: Vec<(usize, usize, bool)> = selection_lengths
430                .iter()
431                .enumerate()
432                .map(|(d, &len)| (d, len, scalar_mask.get(d).copied().unwrap_or(false)))
433                .collect();
434            while dims_out.len() > 2
435                && dims_out
436                    .last()
437                    .map(|&(_, len, is_scalar)| len == 1 && is_scalar)
438                    .unwrap_or(false)
439            {
440                dims_out.pop();
441            }
442            if dims_out.is_empty() {
443                vec![1, 1]
444            } else if dims_out.len() == 1 {
445                let (dim, len, _) = dims_out[0];
446                if dim == 1 {
447                    vec![1, len]
448                } else {
449                    vec![len, 1]
450                }
451            } else {
452                dims_out.into_iter().map(|(_, len, _)| len).collect()
453            }
454        };
455        return Ok(IndexPlan::new(
456            Vec::new(),
457            output_shape,
458            selection_lengths,
459            spec.dims,
460            spec.shape.to_vec(),
461        ));
462    }
463
464    let mut indices: Vec<u32> = Vec::with_capacity(total_out);
465    let mut idx = vec![0usize; spec.dims];
466    loop {
467        let mut lin = 0usize;
468        for d in 0..spec.dims {
469            let i0 = per_dim_indices[d][idx[d]] - 1;
470            lin += i0 * strides[d];
471        }
472        indices.push(lin as u32);
473        let mut d = 0usize;
474        while d < spec.dims {
475            idx[d] += 1;
476            if idx[d] < per_dim_indices[d].len() {
477                break;
478            }
479            idx[d] = 0;
480            d += 1;
481        }
482        if d == spec.dims {
483            break;
484        }
485    }
486
487    let output_shape = if spec.dims == 1 {
488        if total_out <= 1 {
489            vec![1, 1]
490        } else {
491            vec![1, total_out]
492        }
493    } else {
494        let mut dims_out: Vec<(usize, usize, bool)> = selection_lengths
495            .iter()
496            .enumerate()
497            .map(|(d, &len)| (d, len, scalar_mask.get(d).copied().unwrap_or(false)))
498            .collect();
499        while dims_out.len() > 2
500            && dims_out
501                .last()
502                .map(|&(_, len, is_scalar)| len == 1 && is_scalar)
503                .unwrap_or(false)
504        {
505            dims_out.pop();
506        }
507        if dims_out.is_empty() {
508            vec![1, 1]
509        } else if dims_out.len() == 1 {
510            let (dim, len, _) = dims_out[0];
511            if dim == 1 {
512                vec![1, len]
513            } else {
514                vec![len, 1]
515            }
516        } else {
517            dims_out.into_iter().map(|(_, len, _)| len).collect()
518        }
519    };
520    Ok(IndexPlan::new(
521        indices,
522        output_shape,
523        selection_lengths,
524        spec.dims,
525        spec.shape.to_vec(),
526    ))
527}
528
#[cfg(test)]
mod tests {
    use super::{build_expr_index_plan, build_index_plan, ExprPlanSpec};
    use crate::bytecode::EndExpr;
    use crate::indexing::selectors::build_slice_selectors;
    use runmat_builtins::{Tensor, Value};

    /// Evaluates the small subset of `EndExpr` these tests use: `end`, constants, and
    /// `lhs - rhs`, where `lhs` is `end` or a constant and `rhs` is a constant.
    fn eval_end(dim_len: usize, expr: &EndExpr) -> i64 {
        match expr {
            EndExpr::End => dim_len as i64,
            EndExpr::Const(value) => *value as i64,
            EndExpr::Sub(lhs, rhs) => {
                let lhs_val = match lhs.as_ref() {
                    EndExpr::End => dim_len as i64,
                    EndExpr::Const(value) => *value as i64,
                    other => panic!("unsupported lhs expr: {other:?}"),
                };
                let rhs_val = match rhs.as_ref() {
                    EndExpr::Const(value) => *value as i64,
                    other => panic!("unsupported rhs expr: {other:?}"),
                };
                lhs_val - rhs_val
            }
            other => panic!("unsupported expr: {other:?}"),
        }
    }

    #[test]
    fn plain_and_expr_linear_range_plans_match() {
        futures::executor::block_on(async {
            let shape = vec![1, 10];
            let evens = Tensor::new(vec![2.0, 4.0, 6.0, 8.0], vec![1, 4]).unwrap();
            let numeric = vec![Value::Tensor(evens)];
            let plain_selectors = build_slice_selectors(1, 0, 0, &numeric, &shape)
                .await
                .unwrap();
            let plain = build_index_plan(&plain_selectors, 1, &shape).unwrap();

            // 2:2:end-1 over a 1x10 value selects the same elements as [2 4 6 8].
            let range_end_exprs = [EndExpr::Sub(
                Box::new(EndExpr::End),
                Box::new(EndExpr::Const(1.0)),
            )];
            let spec = ExprPlanSpec {
                dims: 1,
                colon_mask: 0,
                end_mask: 0,
                range_dims: &[0],
                range_params: &[(2.0, 2.0)],
                range_start_exprs: &[None],
                range_step_exprs: &[None],
                range_end_exprs: &range_end_exprs,
                numeric: &[],
                shape: &shape,
            };
            let expr = build_expr_index_plan(spec, |dim_len, end_expr| {
                let resolved = eval_end(dim_len, end_expr);
                async move { Ok(resolved) }
            })
            .await
            .unwrap();

            assert_eq!(plain.indices, expr.indices);
            assert_eq!(plain.output_shape, expr.output_shape);
            assert_eq!(plain.selection_lengths, expr.selection_lengths);
            assert_eq!(plain.properties.full_row, expr.properties.full_row);
            assert_eq!(plain.properties.full_column, expr.properties.full_column);
        })
    }

    #[test]
    fn plain_and_expr_column_plans_match_properties() {
        futures::executor::block_on(async {
            let shape = vec![3, 4];
            let numeric = vec![Value::Num(3.0)];
            let plain_selectors = build_slice_selectors(2, 1, 0, &numeric, &shape)
                .await
                .unwrap();
            let plain = build_index_plan(&plain_selectors, 2, &shape).unwrap();

            // A(:, 3): colon in subscript 0, numeric scalar in subscript 1, no `end` involved.
            let spec = ExprPlanSpec {
                dims: 2,
                colon_mask: 1,
                end_mask: 0,
                range_dims: &[],
                range_params: &[],
                range_start_exprs: &[],
                range_step_exprs: &[],
                range_end_exprs: &[],
                numeric: &numeric,
                shape: &shape,
            };
            let expr = build_expr_index_plan(spec, |_dim_len, _expr| async move { unreachable!() })
                .await
                .unwrap();

            assert_eq!(plain.indices, expr.indices);
            assert_eq!(plain.properties.full_column, Some(2));
            assert_eq!(plain.properties.full_column, expr.properties.full_column);
            assert_eq!(plain.properties.full_row, expr.properties.full_row);
        })
    }
}
628}