Skip to main content

runmat_vm/indexing/
write_linear.rs

1use crate::interpreter::errors::mex;
2use runmat_builtins::{ComplexTensor, Tensor, Value};
3use runmat_runtime::RuntimeError;
4
5pub async fn rhs_to_real_scalar(rhs: &Value) -> Result<f64, RuntimeError> {
6    match rhs {
7        Value::Num(x) => Ok(*x),
8        Value::Tensor(t2) => {
9            if t2.data.len() == 1 {
10                Ok(t2.data[0])
11            } else {
12                Err(mex("ScalarRequired", "RHS must be scalar"))
13            }
14        }
15        Value::GpuTensor(h2) => {
16            let total = h2.shape.iter().copied().product::<usize>();
17            if total != 1 {
18                return Err(mex("ScalarRequired", "RHS must be scalar"));
19            }
20            let provider = runmat_accelerate_api::provider().ok_or_else(|| {
21                mex(
22                    "AccelerationProviderUnavailable",
23                    "No acceleration provider registered",
24                )
25            })?;
26            let host = provider
27                .download(h2)
28                .await
29                .map_err(|e| format!("gather rhs: {e}"))?;
30            Ok(host.data[0])
31        }
32        _ => rhs
33            .try_into()
34            .map_err(|_| mex("NumericRequired", "RHS must be numeric")),
35    }
36}
37
38pub async fn rhs_to_complex_scalar(rhs: &Value) -> Result<(f64, f64), RuntimeError> {
39    match rhs {
40        Value::Complex(re, im) => Ok((*re, *im)),
41        Value::Num(n) => Ok((*n, 0.0)),
42        Value::Int(i) => Ok((i.to_f64(), 0.0)),
43        Value::Bool(b) => Ok((if *b { 1.0 } else { 0.0 }, 0.0)),
44        Value::Tensor(t) if t.data.len() == 1 => Ok((t.data[0], 0.0)),
45        Value::ComplexTensor(t) if t.data.len() == 1 => Ok(t.data[0]),
46        Value::GpuTensor(h) => {
47            let total = h.shape.iter().copied().product::<usize>();
48            if total != 1 {
49                return Err(mex("ScalarRequired", "RHS must be scalar"));
50            }
51            let provider = runmat_accelerate_api::provider().ok_or_else(|| {
52                mex(
53                    "AccelerationProviderUnavailable",
54                    "No acceleration provider registered",
55                )
56            })?;
57            let host = provider
58                .download(h)
59                .await
60                .map_err(|e| format!("gather rhs: {e}"))?;
61            Ok((host.data[0], 0.0))
62        }
63        _ => Err(mex("NumericRequired", "RHS must be numeric")),
64    }
65}
66
67pub async fn assign_tensor_scalar(
68    mut t: Tensor,
69    indices: &[usize],
70    rhs: &Value,
71) -> Result<Value, RuntimeError> {
72    if indices.len() == 1 {
73        let total = t.rows * t.cols;
74        let idx = indices[0];
75        if idx == 0 || idx > total {
76            return Err(mex("IndexOutOfBounds", "Index out of bounds"));
77        }
78        let val = rhs_to_real_scalar(rhs).await?;
79        t.data[idx - 1] = val;
80        Ok(Value::Tensor(t))
81    } else if indices.len() == 2 {
82        let i = indices[0];
83        let mut j = indices[1];
84        let rows = t.rows;
85        let cols = t.cols;
86        if j == 0 {
87            j = 1;
88        }
89        if j > cols {
90            j = cols;
91        }
92        if i == 0 || i > rows {
93            return Err(mex("SubscriptOutOfBounds", "Subscript out of bounds"));
94        }
95        let val = rhs_to_real_scalar(rhs).await?;
96        let idx = (i - 1) + (j - 1) * rows;
97        t.data[idx] = val;
98        Ok(Value::Tensor(t))
99    } else {
100        Err(mex(
101            "UnsupportedAssignmentRank",
102            "Only 1D/2D scalar assignment supported",
103        ))
104    }
105}
106
107pub async fn assign_complex_scalar(
108    mut t: ComplexTensor,
109    indices: &[usize],
110    rhs: &Value,
111) -> Result<Value, RuntimeError> {
112    if indices.len() == 1 {
113        let total = t.rows * t.cols;
114        let idx = indices[0];
115        if idx == 0 || idx > total {
116            return Err(mex("IndexOutOfBounds", "Index out of bounds"));
117        }
118        let val = rhs_to_complex_scalar(rhs).await?;
119        t.data[idx - 1] = val;
120        Ok(Value::ComplexTensor(t))
121    } else if indices.len() == 2 {
122        let i = indices[0];
123        let mut j = indices[1];
124        let rows = t.rows;
125        let cols = t.cols;
126        if j == 0 {
127            j = 1;
128        }
129        if j > cols {
130            j = cols;
131        }
132        if i == 0 || i > rows {
133            return Err(mex("SubscriptOutOfBounds", "Subscript out of bounds"));
134        }
135        let val = rhs_to_complex_scalar(rhs).await?;
136        let idx = (i - 1) + (j - 1) * rows;
137        t.data[idx] = val;
138        Ok(Value::ComplexTensor(t))
139    } else {
140        Err(mex(
141            "UnsupportedAssignmentRank",
142            "Only 1D/2D scalar assignment supported",
143        ))
144    }
145}
146
147pub async fn assign_gpu_scalar(
148    h: &runmat_accelerate_api::GpuTensorHandle,
149    indices: &[usize],
150    rhs: &Value,
151) -> Result<Value, RuntimeError> {
152    let provider = runmat_accelerate_api::provider().ok_or_else(|| {
153        mex(
154            "AccelerationProviderUnavailable",
155            "No acceleration provider registered",
156        )
157    })?;
158    let host = provider
159        .download(h)
160        .await
161        .map_err(|e| format!("gather for assignment: {e}"))?;
162    let t = Tensor::new(host.data, host.shape).map_err(|e| format!("assignment: {e}"))?;
163    let Value::Tensor(updated) = assign_tensor_scalar(t, indices, rhs).await? else {
164        unreachable!()
165    };
166    let view = runmat_accelerate_api::HostTensorView {
167        data: &updated.data,
168        shape: &updated.shape,
169    };
170    let new_h = provider
171        .upload(&view)
172        .map_err(|e| format!("reupload after assignment: {e}"))?;
173    Ok(Value::GpuTensor(new_h))
174}