// runmat_vm/indexing/write_slice.rs
1use crate::indexing::plan::IndexPlan;
2use crate::indexing::selectors::SliceSelector;
3use crate::interpreter::errors::mex;
4use runmat_builtins::{CellArray, ComplexTensor, StringArray, Tensor, Value};
5use runmat_runtime::RuntimeError;
6
7pub fn build_subsasgn_paren_cell(numeric: &[Value]) -> Result<Value, RuntimeError> {
8    let cell = CellArray::new(numeric.to_vec(), 1, numeric.len())
9        .map_err(|e| format!("subsasgn build error: {e}"))?;
10    Ok(Value::Cell(cell))
11}
12
13pub async fn object_subsasgn_paren(
14    base: Value,
15    numeric: &[Value],
16    rhs: Value,
17) -> Result<Value, RuntimeError> {
18    let cell = build_subsasgn_paren_cell(numeric)?;
19    match base {
20        Value::Object(obj) => {
21            let args = vec![
22                Value::Object(obj),
23                Value::String("subsasgn".to_string()),
24                Value::String("()".to_string()),
25                cell,
26                rhs,
27            ];
28            runmat_runtime::call_builtin_async("call_method", &args).await
29        }
30        Value::HandleObject(handle) => {
31            let args = vec![
32                Value::HandleObject(handle),
33                Value::String("subsasgn".to_string()),
34                Value::String("()".to_string()),
35                cell,
36                rhs,
37            ];
38            runmat_runtime::call_builtin_async("call_method", &args).await
39        }
40        other => Err(format!("slice subsasgn requires object/handle, got {other:?}").into()),
41    }
42}
43
/// Broadcast-ready view of a complex right-hand side for slice assignment.
pub enum ComplexRhsView {
    /// A single `(re, im)` value broadcast to every selected element.
    Scalar((f64, f64)),
    /// A complex tensor normalized to the selection's dimensionality.
    /// `strides` are column-major over `shape`; dimensions of length 1
    /// broadcast (their stride contribution is pinned to 0 by the scatter).
    Tensor {
        data: Vec<(f64, f64)>,
        shape: Vec<usize>,
        strides: Vec<usize>,
    },
}
52
53pub fn build_complex_rhs_view(
54    rhs: &Value,
55    selection_lengths: &[usize],
56) -> Result<ComplexRhsView, RuntimeError> {
57    match rhs {
58        Value::Complex(re, im) => Ok(ComplexRhsView::Scalar((*re, *im))),
59        Value::Num(n) => Ok(ComplexRhsView::Scalar((*n, 0.0))),
60        Value::ComplexTensor(rt) => {
61            let dims = selection_lengths.len();
62            let mut shape = rt.shape.clone();
63            if shape.len() < dims {
64                shape.resize(dims, 1);
65            }
66            if shape.len() > dims {
67                if shape.iter().skip(dims).any(|&s| s != 1) {
68                    return Err("shape mismatch for slice assign".to_string().into());
69                }
70                shape.truncate(dims);
71            }
72            for d in 0..dims {
73                let out_len = selection_lengths[d];
74                let rhs_len = shape[d];
75                if !(rhs_len == 1 || rhs_len == out_len) {
76                    return Err("shape mismatch for slice assign".to_string().into());
77                }
78            }
79            let mut rstrides = vec![0usize; dims];
80            let mut racc = 1usize;
81            for d in 0..dims {
82                rstrides[d] = racc;
83                racc *= shape[d];
84            }
85            Ok(ComplexRhsView::Tensor {
86                data: rt.data.clone(),
87                shape,
88                strides: rstrides,
89            })
90        }
91        _ => Err("rhs must be numeric or tensor".to_string().into()),
92    }
93}
94
95pub fn scatter_complex_with_plan(
96    t: &mut ComplexTensor,
97    plan: &IndexPlan,
98    rhs_view: &ComplexRhsView,
99) -> Result<(), RuntimeError> {
100    let dims = plan.dims;
101    let mut idx = vec![0usize; dims];
102    if plan.indices.is_empty() {
103        return Ok(());
104    }
105    let selection_lengths = if plan.selection_lengths.is_empty() {
106        plan.output_shape.clone()
107    } else {
108        plan.selection_lengths.clone()
109    };
110    loop {
111        let mut rlin = 0usize;
112        match rhs_view {
113            ComplexRhsView::Scalar(val) => {
114                let pos = plan.indices[rlin] as usize;
115                t.data[pos] = *val;
116            }
117            ComplexRhsView::Tensor {
118                data,
119                shape,
120                strides,
121            } => {
122                for d in 0..dims {
123                    let rhs_len = shape[d];
124                    let pos = if rhs_len == 1 { 0 } else { idx[d] };
125                    rlin += pos * strides[d];
126                }
127                let lin_pos = {
128                    let mut p = 0usize;
129                    let mut mul = 1usize;
130                    for d in 0..dims {
131                        p += idx[d] * mul;
132                        mul *= selection_lengths[d].max(1);
133                    }
134                    p
135                };
136                let dst = plan.indices[lin_pos] as usize;
137                t.data[dst] = data[rlin];
138            }
139        }
140        let mut d = 0usize;
141        while d < dims {
142            idx[d] += 1;
143            if idx[d] < selection_lengths[d].max(1) {
144                break;
145            }
146            idx[d] = 0;
147            d += 1;
148        }
149        if d == dims {
150            break;
151        }
152    }
153    Ok(())
154}
155
/// Broadcast-ready view of a string right-hand side for slice assignment.
pub enum StringRhsView {
    /// A single string broadcast to every selected element.
    Scalar(String),
    /// A string array normalized to the selection's dimensionality.
    /// `strides` are column-major over `shape`; dimensions of length 1
    /// broadcast (their stride contribution is pinned to 0 by the scatter).
    Tensor {
        data: Vec<String>,
        shape: Vec<usize>,
        strides: Vec<usize>,
    },
}
164
165pub fn build_string_rhs_view(
166    rhs: &Value,
167    selection_lengths: &[usize],
168) -> Result<StringRhsView, RuntimeError> {
169    let scalar = match rhs {
170        Value::String(s) => Some(s.clone()),
171        Value::CharArray(ca) => Some(ca.to_string()),
172        _ => None,
173    };
174    if let Some(s) = scalar {
175        return Ok(StringRhsView::Scalar(s));
176    }
177    if let Value::StringArray(rt) = rhs {
178        let dims = selection_lengths.len();
179        let mut shape = rt.shape.clone();
180        if shape.len() < dims {
181            shape.resize(dims, 1);
182        }
183        if shape.len() > dims {
184            if shape.iter().skip(dims).any(|&s| s != 1) {
185                return Err("shape mismatch for slice assign".to_string().into());
186            }
187            shape.truncate(dims);
188        }
189        for d in 0..dims {
190            let out_len = selection_lengths[d];
191            let rhs_len = shape[d];
192            if !(rhs_len == 1 || rhs_len == out_len) {
193                return Err("shape mismatch for slice assign".to_string().into());
194            }
195        }
196        let mut rstrides = vec![0usize; dims];
197        let mut racc = 1usize;
198        for d in 0..dims {
199            rstrides[d] = racc;
200            racc *= shape[d];
201        }
202        return Ok(StringRhsView::Tensor {
203            data: rt.data.clone(),
204            shape,
205            strides: rstrides,
206        });
207    }
208    Err("rhs must be string or string array".to_string().into())
209}
210
211pub fn scatter_string_with_plan(
212    sa: &mut StringArray,
213    plan: &IndexPlan,
214    rhs_view: &StringRhsView,
215) -> Result<(), RuntimeError> {
216    let dims = plan.dims;
217    let mut idx = vec![0usize; dims];
218    if plan.indices.is_empty() {
219        return Ok(());
220    }
221    let selection_lengths = if plan.selection_lengths.is_empty() {
222        plan.output_shape.clone()
223    } else {
224        plan.selection_lengths.clone()
225    };
226    loop {
227        match rhs_view {
228            StringRhsView::Scalar(val) => {
229                let lin_pos = {
230                    let mut p = 0usize;
231                    let mut mul = 1usize;
232                    for d in 0..dims {
233                        p += idx[d] * mul;
234                        mul *= selection_lengths[d].max(1);
235                    }
236                    p
237                };
238                let dst = plan.indices[lin_pos] as usize;
239                sa.data[dst] = val.clone();
240            }
241            StringRhsView::Tensor {
242                data,
243                shape,
244                strides,
245            } => {
246                let mut rlin = 0usize;
247                for d in 0..dims {
248                    let rhs_len = shape[d];
249                    let pos = if rhs_len == 1 { 0 } else { idx[d] };
250                    rlin += pos * strides[d];
251                }
252                let lin_pos = {
253                    let mut p = 0usize;
254                    let mut mul = 1usize;
255                    for d in 0..dims {
256                        p += idx[d] * mul;
257                        mul *= selection_lengths[d].max(1);
258                    }
259                    p
260                };
261                let dst = plan.indices[lin_pos] as usize;
262                sa.data[dst] = data[rlin].clone();
263            }
264        }
265        let mut d = 0usize;
266        while d < dims {
267            idx[d] += 1;
268            if idx[d] < selection_lengths[d].max(1) {
269                break;
270            }
271            idx[d] = 0;
272            d += 1;
273        }
274        if d == dims {
275            break;
276        }
277    }
278    Ok(())
279}
280
281pub async fn materialize_rhs_real_for_plan(
282    rhs: &Value,
283    plan: &IndexPlan,
284) -> Result<Vec<f64>, RuntimeError> {
285    if plan.dims == 1 {
286        let count = plan.selection_lengths.first().copied().unwrap_or(0);
287        materialize_rhs_linear_real(rhs, count).await
288    } else {
289        materialize_rhs_nd_real(rhs, &plan.selection_lengths).await
290    }
291}
292
293pub fn scatter_real_with_plan(
294    t: &mut Tensor,
295    plan: &IndexPlan,
296    rhs_values: &[f64],
297) -> Result<(), RuntimeError> {
298    if rhs_values.len() != plan.indices.len() {
299        return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
300    }
301    for (&dst, &value) in plan.indices.iter().zip(rhs_values.iter()) {
302        t.data[dst as usize] = value;
303    }
304    Ok(())
305}
306
307pub async fn assign_tensor_with_plan(
308    mut t: Tensor,
309    plan: &IndexPlan,
310    rhs: &Value,
311) -> Result<Value, RuntimeError> {
312    if plan.indices.is_empty() {
313        return Ok(Value::Tensor(t));
314    }
315    let rhs_values = materialize_rhs_real_for_plan(rhs, plan).await?;
316    scatter_real_with_plan(&mut t, plan, &rhs_values)?;
317    Ok(Value::Tensor(t))
318}
319
/// Assign `rhs` into the slice of a GPU tensor described by `plan`,
/// preferring device-side scatter paths and falling back to a
/// download / host-scatter / re-upload round trip.
pub async fn assign_gpu_slice_with_plan(
    handle: &runmat_accelerate_api::GpuTensorHandle,
    plan: &IndexPlan,
    rhs: &Value,
) -> Result<Value, RuntimeError> {
    // Empty selection: nothing to write; hand back the original handle.
    if plan.indices.is_empty() {
        return Ok(Value::GpuTensor(handle.clone()));
    }
    let provider = runmat_accelerate_api::provider().ok_or_else(|| {
        mex(
            "AccelerationProviderUnavailable",
            "No acceleration provider registered",
        )
    })?;
    // Fast paths for a GPU-resident RHS: full-column / full-row scatters
    // stay entirely on the device. Provider failures fall through silently
    // to the general path below (best-effort, not an error).
    if let Value::GpuTensor(vh) = rhs {
        let rows = plan.base_shape.first().copied().unwrap_or(1);
        let cols = plan.base_shape.get(1).copied().unwrap_or(1);
        if let Some(col) = plan.properties.full_column {
            if col < cols {
                // Accept 1-D vectors or 2-D shapes (first dim = rows);
                // other ranks yield 0 and fail the length check below.
                let v_rows = match vh.shape.len() {
                    1 | 2 => vh.shape[0],
                    _ => 0,
                };
                if v_rows == rows {
                    if let Ok(new_h) = provider.scatter_column(handle, col, vh) {
                        return Ok(Value::GpuTensor(new_h));
                    }
                }
            }
        }
        if let Some(row) = plan.properties.full_row {
            if row < rows {
                let v_cols = match vh.shape.len() {
                    1 => vh.shape[0],
                    2 => vh.shape[1],
                    _ => 0,
                };
                if v_cols == cols {
                    if let Ok(new_h) = provider.scatter_row(handle, row, vh) {
                        return Ok(Value::GpuTensor(new_h));
                    }
                }
            }
        }
    }
    // General path: materialize the RHS on the host, upload it as a column
    // vector, and ask the provider to scatter by linear indices.
    let rhs_values = materialize_rhs_real_for_plan(rhs, plan).await?;
    let value_shape = vec![rhs_values.len().max(1), 1];
    let upload_result = if rhs_values.is_empty() {
        provider.zeros(&[0, 1])
    } else {
        provider.upload(&runmat_accelerate_api::HostTensorView {
            data: &rhs_values,
            shape: &value_shape,
        })
    };
    if let Ok(values_handle) = upload_result {
        if provider
            .scatter_linear(handle, &plan.indices, &values_handle)
            .is_ok()
        {
            // NOTE(review): returning the original handle assumes
            // scatter_linear mutates the destination in place — confirm
            // against the provider contract.
            return Ok(Value::GpuTensor(handle.clone()));
        }
    }

    // Last resort: download the tensor, apply the scatter on the host, and
    // re-upload the result as a fresh handle.
    let host = provider
        .download(handle)
        .await
        .map_err(|e| format!("gather for slice assign: {e}"))?;
    let mut t = Tensor::new(host.data, host.shape).map_err(|e| format!("slice assign: {e}"))?;
    scatter_real_with_plan(&mut t, plan, &rhs_values)?;
    upload_tensor_to_gpu(&t)
}
392
393pub async fn materialize_rhs_linear_real(
394    rhs: &Value,
395    count: usize,
396) -> Result<Vec<f64>, RuntimeError> {
397    let host_rhs = runmat_runtime::dispatcher::gather_if_needed_async(rhs).await?;
398    match host_rhs {
399        Value::Num(n) => Ok(vec![n; count]),
400        Value::Int(int_val) => Ok(vec![int_val.to_f64(); count]),
401        Value::Bool(b) => Ok(vec![if b { 1.0 } else { 0.0 }; count]),
402        Value::Tensor(t) => {
403            if t.data.len() == count {
404                Ok(t.data)
405            } else if t.data.len() == 1 {
406                Ok(vec![t.data[0]; count])
407            } else {
408                Err(mex("ShapeMismatch", "shape mismatch for slice assign"))
409            }
410        }
411        Value::LogicalArray(la) => {
412            if la.data.len() == count {
413                Ok(la
414                    .data
415                    .into_iter()
416                    .map(|b| if b != 0 { 1.0 } else { 0.0 })
417                    .collect())
418            } else if la.data.len() == 1 {
419                let val = if la.data[0] != 0 { 1.0 } else { 0.0 };
420                Ok(vec![val; count])
421            } else {
422                Err(mex("ShapeMismatch", "shape mismatch for slice assign"))
423            }
424        }
425        other => Err(mex(
426            "InvalidSliceAssignmentRhs",
427            &format!("slice assign: unsupported RHS type {:?}", other),
428        )),
429    }
430}
431
/// Materialize `rhs` as a dense column-major vector of f64 values whose
/// length is the product of `selection_lengths`, broadcasting singleton
/// RHS dimensions across the selection.
///
/// GPU values are gathered to the host first. Scalars (`Num`/`Int`/`Bool`)
/// broadcast everywhere; `Tensor` / `LogicalArray` shapes are normalized
/// to the selection's dimensionality and validated. Other values yield
/// `InvalidSliceAssignmentRhs`; shape disagreements yield `ShapeMismatch`.
pub async fn materialize_rhs_nd_real(
    rhs: &Value,
    selection_lengths: &[usize],
) -> Result<Vec<f64>, RuntimeError> {
    let rhs_host = runmat_runtime::dispatcher::gather_if_needed_async(rhs).await?;
    // Local view: either one broadcast scalar, or tensor data with a shape
    // normalized to `selection_lengths.len()` dims and column-major strides.
    enum RhsView {
        Scalar(f64),
        Tensor {
            data: Vec<f64>,
            shape: Vec<usize>,
            strides: Vec<usize>,
        },
    }
    let view = match rhs_host {
        Value::Num(n) => RhsView::Scalar(n),
        Value::Int(iv) => RhsView::Scalar(iv.to_f64()),
        Value::Bool(b) => RhsView::Scalar(if b { 1.0 } else { 0.0 }),
        Value::Tensor(t) => {
            // Pad missing trailing dims with 1; drop extra trailing dims
            // only when they are all singletons.
            let mut shape = t.shape.clone();
            if shape.len() < selection_lengths.len() {
                shape.resize(selection_lengths.len(), 1);
            }
            if shape.len() > selection_lengths.len() {
                if shape.iter().skip(selection_lengths.len()).any(|&s| s != 1) {
                    return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
                }
                shape.truncate(selection_lengths.len());
            }
            // Each RHS dim must broadcast (1) or match the selection.
            for (dim_len, &sel_len) in shape.iter().zip(selection_lengths.iter()) {
                if *dim_len != 1 && *dim_len != sel_len {
                    return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
                }
            }
            // Column-major strides over the normalized shape.
            let mut strides = vec![1usize; selection_lengths.len()];
            for d in 1..selection_lengths.len() {
                strides[d] = strides[d - 1] * shape[d - 1].max(1);
            }
            // The element count must agree with the normalized shape.
            if t.data.len()
                != shape
                    .iter()
                    .copied()
                    .fold(1usize, |acc, len| acc.saturating_mul(len.max(1)))
            {
                return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
            }
            RhsView::Tensor {
                data: t.data,
                shape,
                strides,
            }
        }
        Value::LogicalArray(la) => {
            // Same normalization as the Tensor branch, with the trailing
            // singleton check done up front.
            if la.shape.len() > selection_lengths.len()
                && la
                    .shape
                    .iter()
                    .skip(selection_lengths.len())
                    .any(|&s| s != 1)
            {
                return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
            }
            let mut shape = la.shape.clone();
            if shape.len() < selection_lengths.len() {
                shape.resize(selection_lengths.len(), 1);
            } else {
                shape.truncate(selection_lengths.len());
            }
            for (dim_len, &sel_len) in shape.iter().zip(selection_lengths.iter()) {
                if *dim_len != 1 && *dim_len != sel_len {
                    return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
                }
            }
            let mut strides = vec![1usize; selection_lengths.len()];
            for d in 1..selection_lengths.len() {
                strides[d] = strides[d - 1] * shape[d - 1].max(1);
            }
            if la.data.len()
                != shape
                    .iter()
                    .copied()
                    .fold(1usize, |acc, len| acc.saturating_mul(len.max(1)))
            {
                return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
            }
            // Logical values become 0.0 / 1.0.
            let data: Vec<f64> = la
                .data
                .into_iter()
                .map(|b| if b != 0 { 1.0 } else { 0.0 })
                .collect();
            RhsView::Tensor {
                data,
                shape,
                strides,
            }
        }
        other => {
            return Err(mex(
                "InvalidSliceAssignmentRhs",
                &format!("slice assign: unsupported RHS type {:?}", other),
            ))
        }
    };

    // Walk the selection in column-major order, reading the RHS with
    // broadcast semantics (dims of length 1 pinned to offset 0).
    let total = selection_lengths
        .iter()
        .copied()
        .fold(1usize, |acc, len| acc.saturating_mul(len.max(1)));
    let mut out = Vec::with_capacity(total);
    let mut idx = vec![0usize; selection_lengths.len()];
    if selection_lengths.is_empty() {
        return Ok(out);
    }
    loop {
        match &view {
            RhsView::Scalar(val) => out.push(*val),
            RhsView::Tensor {
                data,
                shape,
                strides,
            } => {
                let mut rlin = 0usize;
                for d in 0..idx.len() {
                    let rhs_len = shape[d];
                    let pos = if rhs_len == 1 { 0 } else { idx[d] };
                    rlin += pos * strides[d];
                }
                // NOTE(review): an out-of-range rlin silently yields 0.0
                // instead of erroring — presumably unreachable after the
                // shape checks above; confirm.
                out.push(data.get(rlin).copied().unwrap_or(0.0));
            }
        }
        // Odometer-style increment of the multi-dimensional index.
        let mut d = 0usize;
        while d < idx.len() {
            idx[d] += 1;
            if idx[d] < selection_lengths[d].max(1) {
                break;
            }
            idx[d] = 0;
            d += 1;
        }
        if d == idx.len() {
            break;
        }
    }
    Ok(out)
}
576
577pub fn upload_tensor_to_gpu(t: &Tensor) -> Result<Value, RuntimeError> {
578    let provider = runmat_accelerate_api::provider().ok_or_else(|| {
579        mex(
580            "AccelerationProviderUnavailable",
581            "No acceleration provider registered",
582        )
583    })?;
584    let view = runmat_accelerate_api::HostTensorView {
585        data: &t.data,
586        shape: &t.shape,
587    };
588    let new_h = provider
589        .upload(&view)
590        .map_err(|e| format!("reupload after slice assign: {e}"))?;
591    Ok(Value::GpuTensor(new_h))
592}
593
/// Inputs needed by [`build_expr_selectors`] to turn one bytecode-level
/// indexing expression into concrete `SliceSelector`s.
pub struct ExprSelectorSpec<'a> {
    /// Number of subscript positions in the indexing expression.
    pub dims: usize,
    /// Bit `d` set means dimension `d` is a bare colon (`:`).
    pub colon_mask: u32,
    /// Bit `d` set means dimension `d` is a bare `end`.
    pub end_mask: u32,
    /// Dimensions addressed by range expressions (`a:b` / `a:s:b`).
    pub range_dims: &'a [usize],
    /// Literal (start, step) per range, parallel to `range_dims`.
    pub range_params: &'a [(f64, f64)],
    /// Optional `end`-relative start expressions, parallel to `range_dims`.
    pub range_start_exprs: &'a [Option<crate::bytecode::EndExpr>],
    /// Optional `end`-relative step expressions, parallel to `range_dims`.
    pub range_step_exprs: &'a [Option<crate::bytecode::EndExpr>],
    /// `end`-relative end expressions, parallel to `range_dims`.
    pub range_end_exprs: &'a [crate::bytecode::EndExpr],
    /// Plain subscript values (non-colon/end/range), in dimension order.
    pub numeric: &'a [Value],
    /// Shape of the base value being indexed (missing dims default to 1).
    pub shape: &'a [usize],
}
606
/// Resolve one indexing expression's per-dimension selectors (colon, bare
/// `end`, ranges, plain numeric subscripts) into concrete `SliceSelector`s.
///
/// `resolve_end` evaluates an `EndExpr` against a dimension length; it is
/// awaited for `end`-relative range start/step/end expressions.
pub async fn build_expr_selectors<ResolveEnd, Fut>(
    spec: ExprSelectorSpec<'_>,
    mut resolve_end: ResolveEnd,
) -> Result<Vec<SliceSelector>, RuntimeError>
where
    ResolveEnd: FnMut(usize, &crate::bytecode::EndExpr) -> Fut,
    Fut: std::future::Future<Output = Result<i64, RuntimeError>>,
{
    let mut selectors: Vec<SliceSelector> = Vec::with_capacity(spec.dims);
    // Cursors into `spec.numeric` and the range-parallel arrays.
    let mut num_iter = 0usize;
    let mut rp_iter = 0usize;
    for d in 0..spec.dims {
        if let Some(pos) = spec.range_dims.iter().position(|&rd| rd == d) {
            // Range selector (a:b or a:s:b) for this dimension.
            let (raw_st, raw_sp) = spec.range_params[rp_iter];
            let dim_len = *spec.shape.get(d).unwrap_or(&1);
            // Start/step may be `end`-relative; otherwise use the literals.
            let st = if let Some(expr) = &spec.range_start_exprs[rp_iter] {
                resolve_end(dim_len, expr).await? as f64
            } else {
                raw_st
            };
            let sp = if let Some(expr) = &spec.range_step_exprs[rp_iter] {
                resolve_end(dim_len, expr).await? as f64
            } else {
                raw_sp
            };
            rp_iter += 1;
            // NOTE(review): both arms truncate toward zero, so this is
            // equivalent to `sp as i64`; kept as-is.
            let step_i = if sp >= 0.0 {
                sp as i64
            } else {
                -(sp.abs() as i64)
            };
            // NOTE(review): end expressions are indexed by `pos` (position
            // within `range_dims`) while start/step use the sequential
            // cursor — these agree only if `range_dims` is in ascending
            // dimension order; confirm upstream invariant.
            let end_i = resolve_end(dim_len, &spec.range_end_exprs[pos]).await?;
            if step_i == 0 {
                return Err(mex("IndexStepZero", "Index step cannot be zero"));
            }
            // Enumerate 1-based indices, stopping at the bound or when the
            // index leaves [1, dim_len].
            let mut vals = Vec::new();
            let mut cur = st as i64;
            if step_i > 0 {
                while cur <= end_i {
                    if cur < 1 || cur > dim_len as i64 {
                        break;
                    }
                    vals.push(cur as usize);
                    cur += step_i;
                }
            } else {
                while cur >= end_i {
                    if cur < 1 || cur > dim_len as i64 {
                        break;
                    }
                    vals.push(cur as usize);
                    cur += step_i;
                }
            }
            selectors.push(SliceSelector::Indices(vals));
            continue;
        }
        let is_colon = (spec.colon_mask & (1u32 << d)) != 0;
        let is_end = (spec.end_mask & (1u32 << d)) != 0;
        if is_colon {
            selectors.push(SliceSelector::Colon);
        } else if is_end {
            // Bare `end`: select the last element of this dimension.
            selectors.push(SliceSelector::Scalar(*spec.shape.get(d).unwrap_or(&1)));
        } else {
            // Plain subscript value consumed from `spec.numeric`.
            let v = spec
                .numeric
                .get(num_iter)
                .ok_or_else(|| mex("MissingNumericIndex", "missing numeric index"))?;
            num_iter += 1;
            let dim_len = *spec.shape.get(d).unwrap_or(&1);
            // Normalize linear-index selectors to plain index lists.
            selectors.push(
                match crate::indexing::selectors::selector_from_value_dim(v, dim_len).await? {
                    SliceSelector::LinearIndices { values, .. } => SliceSelector::Indices(values),
                    other => other,
                },
            );
        }
    }
    Ok(selectors)
}