runmat_vm/indexing/selectors.rs

1use crate::indexing::plan::total_len_from_shape;
2use crate::interpreter::errors::mex;
3use runmat_builtins::Value;
4use runmat_runtime::{
5    builtins::common::shape::is_scalar_shape, dispatcher::gather_if_needed_async, RuntimeError,
6};
7
8pub type VmResult<T> = Result<T, RuntimeError>;
9
/// Selection for one subscript of a slicing operation, as produced by
/// `build_slice_selectors` and `selector_from_value_dim`.
///
/// All positions stored in the variants are 1-based (the builders reject
/// indices below 1 and convert mask offsets with `i + 1`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SliceSelector {
    /// Every position along the dimension (a `:` subscript).
    Colon,
    /// A single 1-based position.
    Scalar(usize),
    /// An explicit list of 1-based positions, in the order given (not deduplicated).
    Indices(Vec<usize>),
    /// Linear (single-subscript) selection built from a numeric index tensor.
    LinearIndices {
        /// 1-based linear positions, in the index tensor's storage order.
        values: Vec<usize>,
        /// Shape of the index tensor that produced `values` (presumably the
        /// shape of the indexing result — confirm against consumers).
        output_shape: Vec<usize>,
    },
}
20
21fn index_scalar_from_host_value(value: &Value) -> Option<i64> {
22    match value {
23        Value::Num(n) => Some(*n as i64),
24        Value::Int(int_val) => Some(int_val.to_i64()),
25        Value::Tensor(t) if t.data.len() == 1 && is_scalar_shape(&t.shape) => {
26            Some(t.data[0] as i64)
27        }
28        _ => None,
29    }
30}
31
32pub async fn index_scalar_from_value(value: &Value) -> VmResult<Option<i64>> {
33    if let Value::GpuTensor(handle) = value {
34        let total = total_len_from_shape(&handle.shape);
35        if total != 1 {
36            return Ok(None);
37        }
38        let gathered = gather_if_needed_async(value).await?;
39        return Ok(index_scalar_from_host_value(&gathered));
40    }
41    Ok(index_scalar_from_host_value(value))
42}
43
44pub async fn materialize_index_value(value: &Value) -> VmResult<Value> {
45    if matches!(value, Value::GpuTensor(_)) {
46        return gather_if_needed_async(value)
47            .await
48            .map_err(|e| mex("IndexGather", &format!("Failed to gather index value: {e}")));
49    }
50    Ok(value.clone())
51}
52
53pub async fn indices_from_value_linear(value: &Value, total_len: usize) -> VmResult<Vec<usize>> {
54    if let Value::Bool(b) = value {
55        return Ok(if *b { vec![1] } else { Vec::new() });
56    }
57    if let Value::LogicalArray(la) = value {
58        if la.data.len() == 1 && is_scalar_shape(&la.shape) {
59            return Ok(if la.data[0] != 0 { vec![1] } else { Vec::new() });
60        }
61    }
62    if let Some(idx_val) = index_scalar_from_value(value).await? {
63        if idx_val < 1 || (idx_val as usize) > total_len {
64            return Err(mex("IndexOutOfBounds", "Index out of bounds"));
65        }
66        return Ok(vec![idx_val as usize]);
67    }
68    let materialized;
69    let value = if matches!(value, Value::GpuTensor(_)) {
70        materialized = materialize_index_value(value).await?;
71        &materialized
72    } else {
73        value
74    };
75    match value {
76        Value::Tensor(idx_t) => {
77            let len = idx_t.shape.iter().product::<usize>();
78            let mut indices = Vec::with_capacity(len);
79            for &val in &idx_t.data {
80                let idx = val as isize;
81                if idx < 1 || (idx as usize) > total_len {
82                    return Err(mex("IndexOutOfBounds", "Index out of bounds"));
83                }
84                indices.push(idx as usize);
85            }
86            Ok(indices)
87        }
88        Value::LogicalArray(la) => {
89            if la.data.len() != total_len {
90                return Err(mex(
91                    "IndexShape",
92                    "Logical mask length mismatch for linear indexing",
93                ));
94            }
95            let mut indices = Vec::new();
96            for (i, &b) in la.data.iter().enumerate() {
97                if b != 0 {
98                    indices.push(i + 1);
99                }
100            }
101            Ok(indices)
102        }
103        _ => Err(mex(
104            "UnsupportedIndexType",
105            "Unsupported index type for linear indexing",
106        )),
107    }
108}
109
110pub async fn selector_from_value_dim(value: &Value, dim_len: usize) -> VmResult<SliceSelector> {
111    if let Value::Bool(b) = value {
112        if *b {
113            return Ok(SliceSelector::Indices(vec![1]));
114        }
115        return Ok(SliceSelector::Indices(Vec::new()));
116    }
117    if let Value::LogicalArray(la) = value {
118        if la.data.len() == 1 && is_scalar_shape(&la.shape) {
119            if la.data[0] != 0 {
120                return Ok(SliceSelector::Indices(vec![1]));
121            }
122            return Ok(SliceSelector::Indices(Vec::new()));
123        }
124    }
125    if let Some(idx_val) = index_scalar_from_value(value).await? {
126        if idx_val < 1 || (idx_val as usize) > dim_len {
127            return Err(mex("IndexOutOfBounds", "Index out of bounds"));
128        }
129        return Ok(SliceSelector::Scalar(idx_val as usize));
130    }
131    let materialized;
132    let value = if matches!(value, Value::GpuTensor(_)) {
133        materialized = materialize_index_value(value).await?;
134        &materialized
135    } else {
136        value
137    };
138    match value {
139        Value::Tensor(idx_t) => {
140            let len = idx_t.shape.iter().product::<usize>();
141            let mut indices = Vec::with_capacity(len);
142            for &val in &idx_t.data {
143                let idx = val as isize;
144                if idx < 1 || (idx as usize) > dim_len {
145                    return Err(mex("IndexOutOfBounds", "Index out of bounds"));
146                }
147                indices.push(idx as usize);
148            }
149            Ok(SliceSelector::Indices(indices))
150        }
151        Value::LogicalArray(la) => {
152            if la.data.len() != dim_len {
153                return Err(mex(
154                    "IndexShape",
155                    "Logical mask length mismatch for dimension",
156                ));
157            }
158            let mut indices = Vec::new();
159            for (i, &b) in la.data.iter().enumerate() {
160                if b != 0 {
161                    indices.push(i + 1);
162                }
163            }
164            Ok(SliceSelector::Indices(indices))
165        }
166        _ => Err(mex(
167            "UnsupportedIndexType",
168            "Unsupported index type for slicing",
169        )),
170    }
171}
172
173pub async fn build_slice_selectors(
174    dims: usize,
175    colon_mask: u32,
176    end_mask: u32,
177    numeric: &[Value],
178    base_shape: &[usize],
179) -> VmResult<Vec<SliceSelector>> {
180    let mut selectors = Vec::with_capacity(dims);
181    if dims == 1 {
182        let total_len = total_len_from_shape(base_shape);
183        if (colon_mask & 1u32) != 0 {
184            selectors.push(SliceSelector::Indices((1..=total_len).collect()));
185            return Ok(selectors);
186        }
187        if (end_mask & 1u32) != 0 {
188            selectors.push(SliceSelector::Scalar(total_len.max(1)));
189            return Ok(selectors);
190        }
191        let value = numeric.first().ok_or_else(|| {
192            mex(
193                "MissingNumericIndex",
194                "missing numeric index for linear slice",
195            )
196        })?;
197        let materialized = materialize_index_value(value).await?;
198        if let Value::Tensor(idx_t) = &materialized {
199            let len = idx_t.shape.iter().product::<usize>();
200            let mut indices = Vec::with_capacity(len);
201            for &val in &idx_t.data {
202                let idx = val as isize;
203                if idx < 1 || (idx as usize) > total_len {
204                    return Err(mex("IndexOutOfBounds", "Index out of bounds"));
205                }
206                indices.push(idx as usize);
207            }
208            selectors.push(SliceSelector::LinearIndices {
209                values: indices,
210                output_shape: idx_t.shape.clone(),
211            });
212        } else {
213            let idxs = indices_from_value_linear(&materialized, total_len).await?;
214            selectors.push(SliceSelector::Indices(idxs));
215        }
216        return Ok(selectors);
217    }
218
219    let mut numeric_iter = 0usize;
220    for d in 0..dims {
221        let is_colon = (colon_mask & (1u32 << d)) != 0;
222        if is_colon {
223            selectors.push(SliceSelector::Colon);
224            continue;
225        }
226        let dim_len = base_shape.get(d).copied().unwrap_or(1);
227        let is_end = (end_mask & (1u32 << d)) != 0;
228        if is_end {
229            selectors.push(SliceSelector::Scalar(dim_len));
230            continue;
231        }
232        let value = numeric
233            .get(numeric_iter)
234            .ok_or_else(|| mex("MissingNumericIndex", "missing numeric index for slice"))?;
235        numeric_iter += 1;
236        selectors.push(selector_from_value_dim(value, dim_len).await?);
237    }
238    Ok(selectors)
239}