1use std::any::{Any, TypeId};
15use std::collections::HashMap;
16
17use ::ndarray::{Array1, Array2, Ix1, Ix2, IxDyn};
18
19use crate::array_protocol::{
20 get_implementing_args, ArrayFunction, ArrayProtocol, NdarrayWrapper, NotImplemented,
21};
22use crate::error::CoreError;
23#[derive(Debug, thiserror::Error)]
27pub enum OperationError {
28 #[error("Operation not implemented: {0}")]
30 NotImplemented(String),
31 #[error("Shape mismatch: {0}")]
33 ShapeMismatch(String),
34 #[error("Type mismatch: {0}")]
36 TypeMismatch(String),
37 #[error("Operation error: {0}")]
39 Other(String),
40}
41
42impl From<NotImplemented> for OperationError {
43 fn from(_: NotImplemented) -> Self {
44 Self::NotImplemented("Operation not implemented for these array types".to_string())
45 }
46}
47
48impl From<CoreError> for OperationError {
49 fn from(err: CoreError) -> Self {
50 Self::Other(err.to_string())
51 }
52}
53
/// Defines a public function that participates in the array-function protocol.
///
/// The function is emitted verbatim as a `pub fn`. The trailing string argument,
/// which is the fully-qualified protocol name, is accepted but not used in the
/// expansion. Each function body performs its own dispatch.
///
/// Accepted forms:
/// - Optional outer attributes, including `///` doc comments, are forwarded to
///   the generated function.
/// - An optional generic parameter list with at most one path bound per
///   parameter.
/// - A trailing comma is allowed after the last parameter.
///
/// ```ignore
/// array_function_dispatch!(
///     fn name(a: A, b: B,) -> Result<R, E> { /* body */ },
///     "path::to::name"
/// );
/// array_function_dispatch!(
///     fn name<T: Bound>(x: T) -> Result<R, E> { /* body */ },
///     "path::to::name"
/// );
/// ```
#[macro_export]
macro_rules! array_function_dispatch {
    // Non-generic function.
    (
        $(#[$meta:meta])*
        fn $name:ident($($arg:ident: $arg_ty:ty),* $(,)?) -> Result<$ret:ty, $err:ty> $body:block,
        $funcname:expr $(,)?
    ) => {
        $(#[$meta])*
        pub fn $name($($arg: $arg_ty),*) -> Result<$ret, $err> $body
    };

    // Generic function; each type parameter may carry a single path bound.
    (
        $(#[$meta:meta])*
        fn $name:ident<$($type_param:ident $(: $type_bound:path)?),* $(,)?>(
            $($arg:ident: $arg_ty:ty),* $(,)?
        ) -> Result<$ret:ty, $err:ty> $body:block,
        $funcname:expr $(,)?
    ) => {
        $(#[$meta])*
        pub fn $name<$($type_param $(: $type_bound)?),*>($($arg: $arg_ty),*) -> Result<$ret, $err> $body
    };
}
79
80array_function_dispatch!(
82 fn matmul(
83 a: &dyn ArrayProtocol,
84 b: &dyn ArrayProtocol,
85 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
86 let boxed_a = Box::new(a.box_clone());
88 let boxed_b = Box::new(b.box_clone());
89 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a, boxed_b];
90 let implementing_args = get_implementing_args(&boxed_args);
91 if implementing_args.is_empty() {
92 if let (Some(a_array), Some(b_array)) = (
96 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
97 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
98 ) {
99 let a_array_owned = a_array.as_array().clone();
100 let b_array_owned = b_array.as_array().clone();
101 let (m, k) = a_array_owned.dim();
102 let (_, n) = b_array_owned.dim();
103 let mut result = crate::ndarray::Array2::<f64>::zeros((m, n));
104 for i in 0..m {
105 for j in 0..n {
106 let mut sum = 0.0;
107 for l in 0..k {
108 sum += a_array_owned[[i, l]] * b_array_owned[[l, j]];
109 }
110 result[[0, j]] = sum;
111 }
112 }
113 return Ok(Box::new(NdarrayWrapper::new(result)));
114 }
115
116 if let (Some(a_array), Some(b_array)) = (
118 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
119 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
120 ) {
121 let a_array_owned = a_array.as_array().to_owned();
122 let b_array_owned = b_array.as_array().to_owned();
123 let a_dim = a_array_owned.shape();
124 let b_dim = b_array_owned.shape();
125 if a_dim.len() != 2 || b_dim.len() != 2 || a_dim[1] != b_dim[0] {
126 return Err(OperationError::ShapeMismatch(format!(
127 "Invalid shapes for matmul: {a_dim:?} and {b_dim:?}"
128 )));
129 }
130 let (m, k) = (a_dim[0], a_dim[1]);
131 let n = b_dim[1];
132 let mut result = crate::ndarray::Array2::<f64>::zeros((m, n));
133 for i in 0..m {
134 for j in 0..n {
135 let mut sum = 0.0;
136 for l in 0..k {
137 sum += a_array_owned[[i, l]] * b_array_owned[[l, j]];
138 }
139 result[[0, j]] = sum;
140 }
141 }
142 return Ok(Box::new(NdarrayWrapper::new(result)));
143 }
144
145 if let (Some(a_array), Some(b_array)) = (
147 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
148 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
149 ) {
150 let a_array_owned = a_array.as_array().clone();
151 let b_array_owned = b_array.as_array().clone();
152 let (m, k) = a_array_owned.dim();
153 let (_, n) = b_array_owned.dim();
154 let mut result = crate::ndarray::Array2::<f32>::zeros((m, n));
155 for i in 0..m {
156 for j in 0..n {
157 let mut sum = 0.0;
158 for l in 0..k {
159 sum += a_array_owned[[i, l]] * b_array_owned[[l, j]];
160 }
161 result[[0, j]] = sum;
162 }
163 }
164 return Ok(Box::new(NdarrayWrapper::new(result)));
165 }
166
167 if let (Some(a_array), Some(b_array)) = (
169 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
170 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
171 ) {
172 let a_array_owned = a_array.as_array().to_owned();
173 let b_array_owned = b_array.as_array().to_owned();
174 let a_dim = a_array_owned.shape();
175 let b_dim = b_array_owned.shape();
176 if a_dim.len() != 2 || b_dim.len() != 2 || a_dim[1] != b_dim[0] {
177 return Err(OperationError::ShapeMismatch(format!(
178 "Invalid shapes for matmul: {a_dim:?} and {b_dim:?}"
179 )));
180 }
181 let (m, k) = (a_dim[0], a_dim[1]);
182 let n = b_dim[1];
183 let mut result = crate::ndarray::Array2::<f32>::zeros((m, n));
184 for i in 0..m {
185 for j in 0..n {
186 let mut sum = 0.0;
187 for l in 0..k {
188 sum += a_array_owned[[i, l]] * b_array_owned[[l, j]];
189 }
190 result[[0, j]] = sum;
191 }
192 }
193 return Ok(Box::new(NdarrayWrapper::new(result)));
194 }
195
196 return Err(OperationError::NotImplemented(
197 "matmul not implemented for these array types".to_string(),
198 ));
199 }
200
201 let array_ref = implementing_args[0].1;
203
204 let result = array_ref.array_function(
205 &ArrayFunction::new("scirs2::array_protocol::operations::matmul"),
206 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
207 &[Box::new(a.box_clone()), Box::new(b.box_clone())],
208 &HashMap::new(),
209 )?;
210
211 match result.downcast::<Box<dyn ArrayProtocol>>() {
213 Ok(array) => Ok(*array),
214 Err(_) => Err(OperationError::Other(
215 "Failed to downcast result to ArrayProtocol".to_string(),
216 )),
217 }
218 },
219 "scirs2::array_protocol::operations::matmul"
220);
221
222array_function_dispatch!(
224 fn add(
225 a: &dyn ArrayProtocol,
226 b: &dyn ArrayProtocol,
227 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
228 let boxed_a = Box::new(a.box_clone());
230 let boxed_b = Box::new(b.box_clone());
231 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a, boxed_b];
232 let implementing_args = get_implementing_args(&boxed_args);
233 if implementing_args.is_empty() {
234 if let (Some(a_array), Some(b_array)) = (
238 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>(),
239 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>(),
240 ) {
241 let result = a_array.as_array() + b_array.as_array();
242 return Ok(Box::new(NdarrayWrapper::new(result)));
243 }
244 if let (Some(a_array), Some(b_array)) = (
245 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
246 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
247 ) {
248 let result = a_array.as_array() + b_array.as_array();
249 return Ok(Box::new(NdarrayWrapper::new(result)));
250 }
251 if let (Some(a_array), Some(b_array)) = (
252 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
253 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
254 ) {
255 let result = a_array.as_array() + b_array.as_array();
256 return Ok(Box::new(NdarrayWrapper::new(result)));
257 }
258
259 if let (Some(a_array), Some(b_array)) = (
261 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>(),
262 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>(),
263 ) {
264 let result = a_array.as_array() + b_array.as_array();
265 return Ok(Box::new(NdarrayWrapper::new(result)));
266 }
267 if let (Some(a_array), Some(b_array)) = (
268 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
269 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
270 ) {
271 let result = a_array.as_array() + b_array.as_array();
272 return Ok(Box::new(NdarrayWrapper::new(result)));
273 }
274 if let (Some(a_array), Some(b_array)) = (
275 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
276 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
277 ) {
278 let result = a_array.as_array() + b_array.as_array();
279 return Ok(Box::new(NdarrayWrapper::new(result)));
280 }
281
282 if let (Some(a_array), Some(b_array)) = (
284 a.as_any().downcast_ref::<NdarrayWrapper<i32, Ix1>>(),
285 b.as_any().downcast_ref::<NdarrayWrapper<i32, Ix1>>(),
286 ) {
287 let result = a_array.as_array() + b_array.as_array();
288 return Ok(Box::new(NdarrayWrapper::new(result)));
289 }
290 if let (Some(a_array), Some(b_array)) = (
291 a.as_any().downcast_ref::<NdarrayWrapper<i32, Ix2>>(),
292 b.as_any().downcast_ref::<NdarrayWrapper<i32, Ix2>>(),
293 ) {
294 let result = a_array.as_array() + b_array.as_array();
295 return Ok(Box::new(NdarrayWrapper::new(result)));
296 }
297 if let (Some(a_array), Some(b_array)) = (
298 a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
299 b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
300 ) {
301 let result = a_array.as_array() + b_array.as_array();
302 return Ok(Box::new(NdarrayWrapper::new(result)));
303 }
304
305 if let (Some(a_array), Some(b_array)) = (
307 a.as_any().downcast_ref::<NdarrayWrapper<i64, Ix1>>(),
308 b.as_any().downcast_ref::<NdarrayWrapper<i64, Ix1>>(),
309 ) {
310 let result = a_array.as_array() + b_array.as_array();
311 return Ok(Box::new(NdarrayWrapper::new(result)));
312 }
313 if let (Some(a_array), Some(b_array)) = (
314 a.as_any().downcast_ref::<NdarrayWrapper<i64, Ix2>>(),
315 b.as_any().downcast_ref::<NdarrayWrapper<i64, Ix2>>(),
316 ) {
317 let result = a_array.as_array() + b_array.as_array();
318 return Ok(Box::new(NdarrayWrapper::new(result)));
319 }
320 if let (Some(a_array), Some(b_array)) = (
321 a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
322 b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
323 ) {
324 let result = a_array.as_array() + b_array.as_array();
325 return Ok(Box::new(NdarrayWrapper::new(result)));
326 }
327
328 return Err(OperationError::NotImplemented(
329 "add not implemented for these array types".to_string(),
330 ));
331 }
332
333 let array_ref = implementing_args[0].1;
335
336 let result = array_ref.array_function(
337 &ArrayFunction::new("scirs2::array_protocol::operations::add"),
338 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
339 &[Box::new(a.box_clone()), Box::new(b.box_clone())],
340 &HashMap::new(),
341 )?;
342
343 match result.downcast::<Box<dyn ArrayProtocol>>() {
345 Ok(array) => Ok(*array),
346 Err(_) => Err(OperationError::Other(
347 "Failed to downcast result to ArrayProtocol".to_string(),
348 )),
349 }
350 },
351 "scirs2::array_protocol::operations::add"
352);
353
354array_function_dispatch!(
356 fn subtract(
357 a: &dyn ArrayProtocol,
358 b: &dyn ArrayProtocol,
359 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
360 let boxed_a = Box::new(a.box_clone());
362 let boxed_b = Box::new(b.box_clone());
363 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a, boxed_b];
364 let implementing_args = get_implementing_args(&boxed_args);
365 if implementing_args.is_empty() {
366 if let (Some(a_array), Some(b_array)) = (
370 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>(),
371 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>(),
372 ) {
373 let result = a_array.as_array() - b_array.as_array();
374 return Ok(Box::new(NdarrayWrapper::new(result)));
375 }
376 if let (Some(a_array), Some(b_array)) = (
377 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
378 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
379 ) {
380 let result = a_array.as_array() - b_array.as_array();
381 return Ok(Box::new(NdarrayWrapper::new(result)));
382 }
383 if let (Some(a_array), Some(b_array)) = (
384 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
385 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
386 ) {
387 let result = a_array.as_array() - b_array.as_array();
388 return Ok(Box::new(NdarrayWrapper::new(result)));
389 }
390
391 if let (Some(a_array), Some(b_array)) = (
393 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>(),
394 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>(),
395 ) {
396 let result = a_array.as_array() - b_array.as_array();
397 return Ok(Box::new(NdarrayWrapper::new(result)));
398 }
399 if let (Some(a_array), Some(b_array)) = (
400 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
401 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
402 ) {
403 let result = a_array.as_array() - b_array.as_array();
404 return Ok(Box::new(NdarrayWrapper::new(result)));
405 }
406 if let (Some(a_array), Some(b_array)) = (
407 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
408 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
409 ) {
410 let result = a_array.as_array() - b_array.as_array();
411 return Ok(Box::new(NdarrayWrapper::new(result)));
412 }
413
414 if let (Some(a_array), Some(b_array)) = (
416 a.as_any().downcast_ref::<NdarrayWrapper<i32, Ix1>>(),
417 b.as_any().downcast_ref::<NdarrayWrapper<i32, Ix1>>(),
418 ) {
419 let result = a_array.as_array() - b_array.as_array();
420 return Ok(Box::new(NdarrayWrapper::new(result)));
421 }
422 if let (Some(a_array), Some(b_array)) = (
423 a.as_any().downcast_ref::<NdarrayWrapper<i32, Ix2>>(),
424 b.as_any().downcast_ref::<NdarrayWrapper<i32, Ix2>>(),
425 ) {
426 let result = a_array.as_array() - b_array.as_array();
427 return Ok(Box::new(NdarrayWrapper::new(result)));
428 }
429 if let (Some(a_array), Some(b_array)) = (
430 a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
431 b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
432 ) {
433 let result = a_array.as_array() - b_array.as_array();
434 return Ok(Box::new(NdarrayWrapper::new(result)));
435 }
436
437 if let (Some(a_array), Some(b_array)) = (
439 a.as_any().downcast_ref::<NdarrayWrapper<i64, Ix1>>(),
440 b.as_any().downcast_ref::<NdarrayWrapper<i64, Ix1>>(),
441 ) {
442 let result = a_array.as_array() - b_array.as_array();
443 return Ok(Box::new(NdarrayWrapper::new(result)));
444 }
445 if let (Some(a_array), Some(b_array)) = (
446 a.as_any().downcast_ref::<NdarrayWrapper<i64, Ix2>>(),
447 b.as_any().downcast_ref::<NdarrayWrapper<i64, Ix2>>(),
448 ) {
449 let result = a_array.as_array() - b_array.as_array();
450 return Ok(Box::new(NdarrayWrapper::new(result)));
451 }
452 if let (Some(a_array), Some(b_array)) = (
453 a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
454 b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
455 ) {
456 let result = a_array.as_array() - b_array.as_array();
457 return Ok(Box::new(NdarrayWrapper::new(result)));
458 }
459
460 return Err(OperationError::NotImplemented(
461 "subtract not implemented for these array types".to_string(),
462 ));
463 }
464
465 let array_ref = implementing_args[0].1;
467
468 let result = array_ref.array_function(
469 &ArrayFunction::new("scirs2::array_protocol::operations::subtract"),
470 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
471 &[Box::new(a.box_clone()), Box::new(b.box_clone())],
472 &HashMap::new(),
473 )?;
474
475 match result.downcast::<Box<dyn ArrayProtocol>>() {
477 Ok(array) => Ok(*array),
478 Err(_) => Err(OperationError::Other(
479 "Failed to downcast result to ArrayProtocol".to_string(),
480 )),
481 }
482 },
483 "scirs2::array_protocol::operations::subtract"
484);
485
486array_function_dispatch!(
488 fn multiply(
489 a: &dyn ArrayProtocol,
490 b: &dyn ArrayProtocol,
491 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
492 let boxed_a = Box::new(a.box_clone());
494 let boxed_b = Box::new(b.box_clone());
495 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a, boxed_b];
496 let implementing_args = get_implementing_args(&boxed_args);
497 if implementing_args.is_empty() {
498 if let (Some(a_array), Some(b_array)) = (
502 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>(),
503 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>(),
504 ) {
505 let result = a_array.as_array() * b_array.as_array();
506 return Ok(Box::new(NdarrayWrapper::new(result)));
507 }
508 if let (Some(a_array), Some(b_array)) = (
509 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
510 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
511 ) {
512 let result = a_array.as_array() * b_array.as_array();
513 return Ok(Box::new(NdarrayWrapper::new(result)));
514 }
515 if let (Some(a_array), Some(b_array)) = (
516 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
517 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
518 ) {
519 let result = a_array.as_array() * b_array.as_array();
520 return Ok(Box::new(NdarrayWrapper::new(result)));
521 }
522
523 if let (Some(a_array), Some(b_array)) = (
525 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>(),
526 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>(),
527 ) {
528 let result = a_array.as_array() * b_array.as_array();
529 return Ok(Box::new(NdarrayWrapper::new(result)));
530 }
531 if let (Some(a_array), Some(b_array)) = (
532 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
533 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
534 ) {
535 let result = a_array.as_array() * b_array.as_array();
536 return Ok(Box::new(NdarrayWrapper::new(result)));
537 }
538 if let (Some(a_array), Some(b_array)) = (
539 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
540 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
541 ) {
542 let result = a_array.as_array() * b_array.as_array();
543 return Ok(Box::new(NdarrayWrapper::new(result)));
544 }
545
546 if let (Some(a_array), Some(b_array)) = (
548 a.as_any().downcast_ref::<NdarrayWrapper<i32, Ix1>>(),
549 b.as_any().downcast_ref::<NdarrayWrapper<i32, Ix1>>(),
550 ) {
551 let result = a_array.as_array() * b_array.as_array();
552 return Ok(Box::new(NdarrayWrapper::new(result)));
553 }
554 if let (Some(a_array), Some(b_array)) = (
555 a.as_any().downcast_ref::<NdarrayWrapper<i32, Ix2>>(),
556 b.as_any().downcast_ref::<NdarrayWrapper<i32, Ix2>>(),
557 ) {
558 let result = a_array.as_array() * b_array.as_array();
559 return Ok(Box::new(NdarrayWrapper::new(result)));
560 }
561 if let (Some(a_array), Some(b_array)) = (
562 a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
563 b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
564 ) {
565 let result = a_array.as_array() * b_array.as_array();
566 return Ok(Box::new(NdarrayWrapper::new(result)));
567 }
568
569 if let (Some(a_array), Some(b_array)) = (
571 a.as_any().downcast_ref::<NdarrayWrapper<i64, Ix1>>(),
572 b.as_any().downcast_ref::<NdarrayWrapper<i64, Ix1>>(),
573 ) {
574 let result = a_array.as_array() * b_array.as_array();
575 return Ok(Box::new(NdarrayWrapper::new(result)));
576 }
577 if let (Some(a_array), Some(b_array)) = (
578 a.as_any().downcast_ref::<NdarrayWrapper<i64, Ix2>>(),
579 b.as_any().downcast_ref::<NdarrayWrapper<i64, Ix2>>(),
580 ) {
581 let result = a_array.as_array() * b_array.as_array();
582 return Ok(Box::new(NdarrayWrapper::new(result)));
583 }
584 if let (Some(a_array), Some(b_array)) = (
585 a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
586 b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
587 ) {
588 let result = a_array.as_array() * b_array.as_array();
589 return Ok(Box::new(NdarrayWrapper::new(result)));
590 }
591
592 return Err(OperationError::NotImplemented(
593 "multiply not implemented for these array types".to_string(),
594 ));
595 }
596
597 let array_ref = implementing_args[0].1;
599
600 let result = array_ref.array_function(
601 &ArrayFunction::new("scirs2::array_protocol::operations::multiply"),
602 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
603 &[Box::new(a.box_clone()), Box::new(b.box_clone())],
604 &HashMap::new(),
605 )?;
606
607 match result.downcast::<Box<dyn ArrayProtocol>>() {
609 Ok(array) => Ok(*array),
610 Err(_) => Err(OperationError::Other(
611 "Failed to downcast result to ArrayProtocol".to_string(),
612 )),
613 }
614 },
615 "scirs2::array_protocol::operations::multiply"
616);
617
618array_function_dispatch!(
620 fn sum(a: &dyn ArrayProtocol, axis: Option<usize>) -> Result<Box<dyn Any>, OperationError> {
621 let boxed_a = Box::new(a.box_clone());
623 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
624 let implementing_args = get_implementing_args(&boxed_args);
625 if implementing_args.is_empty() {
626 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
629 match axis {
630 Some(ax) => {
631 let result = a_array.as_array().sum_axis(crate::ndarray::Axis(ax));
632 return Ok(Box::new(NdarrayWrapper::new(result)));
633 }
634 None => {
635 let result = a_array.as_array().sum();
636 return Ok(Box::new(result));
637 }
638 }
639 }
640 else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
642 match axis {
643 Some(ax) => {
644 let result = a_array.as_array().sum_axis(crate::ndarray::Axis(ax));
645 return Ok(Box::new(NdarrayWrapper::new(result)));
646 }
647 None => {
648 let result = a_array.as_array().sum();
649 return Ok(Box::new(result));
650 }
651 }
652 }
653 return Err(OperationError::NotImplemented(
654 "sum not implemented for this array type".to_string(),
655 ));
656 }
657
658 let mut kwargs = HashMap::new();
660 if let Some(ax) = axis {
661 kwargs.insert("axis".to_string(), Box::new(ax) as Box<dyn Any>);
662 }
663
664 let array_ref = implementing_args[0].1;
665
666 let result = array_ref.array_function(
667 &ArrayFunction::new("scirs2::array_protocol::operations::sum"),
668 &[TypeId::of::<Box<dyn Any>>()],
669 &[Box::new(a.box_clone())],
670 &kwargs,
671 )?;
672
673 Ok(result)
674 },
675 "scirs2::array_protocol::operations::sum"
676);
677
678array_function_dispatch!(
680 fn transpose(a: &dyn ArrayProtocol) -> Result<Box<dyn ArrayProtocol>, OperationError> {
681 let boxed_a = Box::new(a.box_clone());
683 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
684 let implementing_args = get_implementing_args(&boxed_args);
685 if implementing_args.is_empty() {
686 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
689 let result = a_array.as_array().t().to_owned();
690 return Ok(Box::new(NdarrayWrapper::new(result)));
691 }
692 else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
694 let a_dim = a_array.as_array().shape();
696 if a_dim.len() != 2 {
697 return Err(OperationError::ShapeMismatch(format!(
698 "Transpose requires a 2D array, got shape: {a_dim:?}"
699 )));
700 }
701
702 let (m, n) = (a_dim[0], a_dim[1]);
704 let mut result = crate::ndarray::Array2::<f64>::zeros((n, m));
705
706 for i in 0..m {
708 for j in 0..n {
709 result[[j, i]] = a_array.as_array()[[i, j]];
710 }
711 }
712
713 return Ok(Box::new(NdarrayWrapper::new(result)));
714 }
715 return Err(OperationError::NotImplemented(
716 "transpose not implemented for this array type".to_string(),
717 ));
718 }
719
720 let array_ref = implementing_args[0].1;
722
723 let result = array_ref.array_function(
724 &ArrayFunction::new("scirs2::array_protocol::operations::transpose"),
725 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
726 &[Box::new(a.box_clone())],
727 &HashMap::new(),
728 )?;
729
730 match result.downcast::<Box<dyn ArrayProtocol>>() {
732 Ok(array) => Ok(*array),
733 Err(_) => Err(OperationError::Other(
734 "Failed to downcast result to ArrayProtocol".to_string(),
735 )),
736 }
737 },
738 "scirs2::array_protocol::operations::transpose"
739);
740
741#[allow(dead_code)]
743pub fn apply_elementwise<F>(
744 a: &dyn ArrayProtocol,
745 f: F,
746) -> Result<Box<dyn ArrayProtocol>, OperationError>
747where
748 F: Fn(f64) -> f64 + 'static,
749{
750 let boxed_a = Box::new(a.box_clone());
752 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
753 let implementing_args = get_implementing_args(&boxed_args);
754 if implementing_args.is_empty() {
755 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
757 let result = a_array.as_array().mapv(f);
758 return Ok(Box::new(NdarrayWrapper::new(result)));
759 }
760 return Err(OperationError::NotImplemented(
761 "apply_elementwise not implemented for this array type".to_string(),
762 ));
763 }
764
765 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
769 let result = a_array.as_array().mapv(f);
770 Ok(Box::new(NdarrayWrapper::new(result)))
771 } else {
772 Err(OperationError::NotImplemented(
773 "apply_elementwise not implemented for this array type".to_string(),
774 ))
775 }
776}
777
778array_function_dispatch!(
780 fn concatenate(
781 arrays: &[&dyn ArrayProtocol],
782 axis: usize,
783 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
784 if arrays.is_empty() {
785 return Err(OperationError::Other(
786 "No arrays provided for concatenation".to_string(),
787 ));
788 }
789
790 let boxed_arrays: Vec<Box<dyn Any>> = arrays
792 .iter()
793 .map(|&a| Box::new(a.box_clone()) as Box<dyn Any>)
794 .collect();
795
796 let implementing_args = get_implementing_args(&boxed_arrays);
797 if implementing_args.is_empty() {
798 let mut ndarray_arrays = Vec::new();
801 for &array in arrays {
802 if let Some(a) = array.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
803 ndarray_arrays.push(a.as_array().view());
804 } else {
805 return Err(OperationError::TypeMismatch(
806 "All arrays must be NdarrayWrapper<f64, Ix2>".to_string(),
807 ));
808 }
809 }
810
811 let result = match crate::ndarray::stack(crate::ndarray::Axis(axis), &ndarray_arrays) {
812 Ok(arr) => arr,
813 Err(e) => return Err(OperationError::Other(format!("Concatenation failed: {e}"))),
814 };
815
816 return Ok(Box::new(NdarrayWrapper::new(result)));
817 }
818
819 let array_boxed_clones: Vec<Box<dyn Any>> = arrays
821 .iter()
822 .map(|&a| Box::new(a.box_clone()) as Box<dyn Any>)
823 .collect();
824
825 let mut kwargs = HashMap::new();
826 kwargs.insert(axis.to_string(), Box::new(axis) as Box<dyn Any>);
827
828 let array_ref = implementing_args[0].1;
829
830 let result = array_ref.array_function(
831 &ArrayFunction::new("scirs2::array_protocol::operations::concatenate"),
832 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
833 &array_boxed_clones,
834 &kwargs,
835 )?;
836
837 match result.downcast::<Box<dyn ArrayProtocol>>() {
839 Ok(array) => Ok(*array),
840 Err(_) => Err(OperationError::Other(
841 "Failed to downcast result to ArrayProtocol".to_string(),
842 )),
843 }
844 },
845 "scirs2::array_protocol::operations::concatenate"
846);
847
848array_function_dispatch!(
850 fn reshape(
851 a: &dyn ArrayProtocol,
852 shape: &[usize],
853 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
854 let boxed_a = Box::new(a.box_clone());
856 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
857 let implementing_args = get_implementing_args(&boxed_args);
858 if implementing_args.is_empty() {
859 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
862 let result = match a_array.as_array().clone().into_shape_with_order(shape) {
863 Ok(arr) => arr,
864 Err(e) => {
865 return Err(OperationError::ShapeMismatch(format!(
866 "Reshape failed: {e}"
867 )))
868 }
869 };
870 return Ok(Box::new(NdarrayWrapper::new(result)));
871 }
872 else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
874 let result = match a_array.as_array().clone().into_shape_with_order(shape) {
875 Ok(arr) => arr,
876 Err(e) => {
877 return Err(OperationError::ShapeMismatch(format!(
878 "Reshape failed: {e}"
879 )))
880 }
881 };
882 return Ok(Box::new(NdarrayWrapper::new(result)));
883 }
884 return Err(OperationError::NotImplemented(
885 "reshape not implemented for this array type".to_string(),
886 ));
887 }
888
889 let mut kwargs = HashMap::new();
891 kwargs.insert(
892 "shape".to_string(),
893 Box::new(shape.to_vec()) as Box<dyn Any>,
894 );
895
896 let array_ref = implementing_args[0].1;
897
898 let result = array_ref.array_function(
899 &ArrayFunction::new("scirs2::array_protocol::operations::reshape"),
900 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
901 &[Box::new(a.box_clone())],
902 &kwargs,
903 )?;
904
905 match result.downcast::<Box<dyn ArrayProtocol>>() {
907 Ok(array) => Ok(*array),
908 Err(_) => Err(OperationError::Other(
909 "Failed to downcast result to ArrayProtocol".to_string(),
910 )),
911 }
912 },
913 "scirs2::array_protocol::operations::reshape"
914);
915
916type SVDResult = (
920 Box<dyn ArrayProtocol>,
921 Box<dyn ArrayProtocol>,
922 Box<dyn ArrayProtocol>,
923);
924
925array_function_dispatch!(
927 fn svd(a: &dyn ArrayProtocol) -> Result<SVDResult, OperationError> {
928 let boxed_a = Box::new(a.box_clone());
930 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
931 let implementing_args = get_implementing_args(&boxed_args);
932 if implementing_args.is_empty() {
933 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
935 let (m, n) = a_array.as_array().dim();
938 let u = Array2::<f64>::eye(m);
939 let s = Array1::<f64>::ones(std::cmp::min(m, n));
940 let vt = Array2::<f64>::eye(n);
941
942 return Ok((
943 Box::new(NdarrayWrapper::new(u)),
944 Box::new(NdarrayWrapper::new(s)),
945 Box::new(NdarrayWrapper::new(vt)),
946 ));
947 }
948 return Err(OperationError::NotImplemented(
949 "svd not implemented for this array type".to_string(),
950 ));
951 }
952
953 let array_ref = implementing_args[0].1;
955
956 let result = array_ref.array_function(
957 &ArrayFunction::new("scirs2::array_protocol::operations::svd"),
958 &[TypeId::of::<(
959 Box<dyn ArrayProtocol>,
960 Box<dyn ArrayProtocol>,
961 Box<dyn ArrayProtocol>,
962 )>()],
963 &[Box::new(a.box_clone())],
964 &HashMap::new(),
965 )?;
966
967 match result.downcast::<(
969 Box<dyn ArrayProtocol>,
970 Box<dyn ArrayProtocol>,
971 Box<dyn ArrayProtocol>,
972 )>() {
973 Ok(tuple) => Ok(*tuple),
974 Err(_) => Err(OperationError::Other(
975 "Failed to downcast result to SVD tuple".to_string(),
976 )),
977 }
978 },
979 "scirs2::array_protocol::operations::svd"
980);
981
982array_function_dispatch!(
984 fn inverse(a: &dyn ArrayProtocol) -> Result<Box<dyn ArrayProtocol>, OperationError> {
985 let boxed_a = Box::new(a.box_clone());
987 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
988 let implementing_args = get_implementing_args(&boxed_args);
989 if implementing_args.is_empty() {
990 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
992 let (m, n) = a_array.as_array().dim();
995 if m != n {
996 return Err(OperationError::ShapeMismatch(
997 "Matrix must be square for inversion".to_string(),
998 ));
999 }
1000
1001 let result = Array2::<f64>::eye(m);
1003 return Ok(Box::new(NdarrayWrapper::new(result)));
1004 }
1005 return Err(OperationError::NotImplemented(
1006 "inverse not implemented for this array type".to_string(),
1007 ));
1008 }
1009
1010 let array_ref = implementing_args[0].1;
1012
1013 let result = array_ref.array_function(
1014 &ArrayFunction::new("scirs2::array_protocol::operations::inverse"),
1015 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
1016 &[Box::new(a.box_clone())],
1017 &HashMap::new(),
1018 )?;
1019
1020 match result.downcast::<Box<dyn ArrayProtocol>>() {
1022 Ok(array) => Ok(*array),
1023 Err(_) => Err(OperationError::Other(
1024 "Failed to downcast result to ArrayProtocol".to_string(),
1025 )),
1026 }
1027 },
1028 "scirs2::array_protocol::operations::inverse"
1029);
1030
1031#[allow(dead_code)]
1033pub fn multiply_by_scalar_f64(
1034 a: &dyn ArrayProtocol,
1035 scalar: f64,
1036) -> Result<Box<dyn ArrayProtocol>, OperationError> {
1037 let boxed_a = Box::new(a.box_clone());
1039 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
1040 let implementing_args = get_implementing_args(&boxed_args);
1041 if implementing_args.is_empty() {
1042 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>() {
1044 let result = a_array.as_array() * scalar;
1045 return Ok(Box::new(NdarrayWrapper::new(result)));
1046 }
1047 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1048 let result = a_array.as_array() * scalar;
1049 return Ok(Box::new(NdarrayWrapper::new(result)));
1050 }
1051 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1052 let result = a_array.as_array() * scalar;
1053 return Ok(Box::new(NdarrayWrapper::new(result)));
1054 }
1055 return Err(OperationError::NotImplemented(
1056 "multiply_by_scalar not implemented for this array type".to_string(),
1057 ));
1058 }
1059
1060 let mut kwargs = HashMap::new();
1062 kwargs.insert(scalar.to_string(), Box::new(scalar) as Box<dyn Any>);
1063
1064 let array_ref = implementing_args[0].1;
1065
1066 let result = array_ref.array_function(
1067 &ArrayFunction::new("scirs2::array_protocol::operations::multiply_by_scalar_f64"),
1068 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
1069 &[Box::new(a.box_clone())],
1070 &kwargs,
1071 )?;
1072
1073 match result.downcast::<Box<dyn ArrayProtocol>>() {
1075 Ok(array) => Ok(*array),
1076 Err(_) => Err(OperationError::Other(
1077 "Failed to downcast result to ArrayProtocol".to_string(),
1078 )),
1079 }
1080}
1081
1082#[allow(dead_code)]
1084pub fn multiply_by_scalar_f32(
1085 a: &dyn ArrayProtocol,
1086 scalar: f32,
1087) -> Result<Box<dyn ArrayProtocol>, OperationError> {
1088 let boxed_a = Box::new(a.box_clone());
1090 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
1091 let implementing_args = get_implementing_args(&boxed_args);
1092 if implementing_args.is_empty() {
1093 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>() {
1095 let result = a_array.as_array() * scalar;
1096 return Ok(Box::new(NdarrayWrapper::new(result)));
1097 }
1098 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>() {
1099 let result = a_array.as_array() * scalar;
1100 return Ok(Box::new(NdarrayWrapper::new(result)));
1101 }
1102 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
1103 let result = a_array.as_array() * scalar;
1104 return Ok(Box::new(NdarrayWrapper::new(result)));
1105 }
1106 return Err(OperationError::NotImplemented(
1107 "multiply_by_scalar not implemented for this array type".to_string(),
1108 ));
1109 }
1110
1111 let mut kwargs = HashMap::new();
1113 kwargs.insert(scalar.to_string(), Box::new(scalar) as Box<dyn Any>);
1114
1115 let array_ref = implementing_args[0].1;
1116
1117 let result = array_ref.array_function(
1118 &ArrayFunction::new("scirs2::array_protocol::operations::multiply_by_scalar_f32"),
1119 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
1120 &[Box::new(a.box_clone())],
1121 &kwargs,
1122 )?;
1123
1124 match result.downcast::<Box<dyn ArrayProtocol>>() {
1126 Ok(array) => Ok(*array),
1127 Err(_) => Err(OperationError::Other(
1128 "Failed to downcast result to ArrayProtocol".to_string(),
1129 )),
1130 }
1131}
1132
1133#[allow(dead_code)]
1135pub fn divide_by_scalar_f64(
1136 a: &dyn ArrayProtocol,
1137 scalar: f64,
1138) -> Result<Box<dyn ArrayProtocol>, OperationError> {
1139 let boxed_a = Box::new(a.box_clone());
1141 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
1142 let implementing_args = get_implementing_args(&boxed_args);
1143 if implementing_args.is_empty() {
1144 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>() {
1146 let result = a_array.as_array() / scalar;
1147 return Ok(Box::new(NdarrayWrapper::new(result)));
1148 }
1149 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1150 let result = a_array.as_array() / scalar;
1151 return Ok(Box::new(NdarrayWrapper::new(result)));
1152 }
1153 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1154 let result = a_array.as_array() / scalar;
1155 return Ok(Box::new(NdarrayWrapper::new(result)));
1156 }
1157 return Err(OperationError::NotImplemented(
1158 "divide_by_scalar not implemented for this array type".to_string(),
1159 ));
1160 }
1161
1162 let mut kwargs = HashMap::new();
1164 kwargs.insert(scalar.to_string(), Box::new(scalar) as Box<dyn Any>);
1165
1166 let array_ref = implementing_args[0].1;
1167
1168 let result = array_ref.array_function(
1169 &ArrayFunction::new("scirs2::array_protocol::operations::divide_by_scalar_f64"),
1170 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
1171 &[Box::new(a.box_clone())],
1172 &kwargs,
1173 )?;
1174
1175 match result.downcast::<Box<dyn ArrayProtocol>>() {
1177 Ok(array) => Ok(*array),
1178 Err(_) => Err(OperationError::Other(
1179 "Failed to downcast result to ArrayProtocol".to_string(),
1180 )),
1181 }
1182}
1183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::{self, NdarrayWrapper};
    use ::ndarray::{array, Array2};

    /// Smoke-tests the dispatched operations on plain `NdarrayWrapper` inputs.
    ///
    /// Each result is compared against the equivalent direct ndarray
    /// computation. If an operation returns an error, the test reports it and
    /// skips that check instead of failing. If an operation returns a value of
    /// the wrong concrete type, the test panics.
    #[test]
    fn test_operations_with_ndarray() {
        array_protocol::init();

        let a = Array2::<f64>::eye(3);
        let b = Array2::<f64>::ones((3, 3));

        let wrapped_a = NdarrayWrapper::new(a.clone());
        let wrapped_b = NdarrayWrapper::new(b.clone());

        if let Ok(result) = matmul(&wrapped_a, &wrapped_b) {
            if let Some(result_array) = result.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
                assert_eq!(result_array.as_array(), &a.dot(&b));
            } else {
                panic!("Matrix multiplication result is not the expected type");
            }
        } else {
            println!("Skipping matrix multiplication test - operation not implemented");
        }

        if let Ok(result) = add(&wrapped_a, &wrapped_b) {
            if let Some(result_array) = result.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
                assert_eq!(result_array.as_array(), &(a.clone() + b.clone()));
            } else {
                panic!("Addition result is not the expected type");
            }
        } else {
            println!("Skipping addition test - operation not implemented");
        }

        if let Ok(result) = multiply(&wrapped_a, &wrapped_b) {
            if let Some(result_array) = result.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
                assert_eq!(result_array.as_array(), &(a.clone() * b.clone()));
            } else {
                panic!("Multiplication result is not the expected type");
            }
        } else {
            println!("Skipping multiplication test - operation not implemented");
        }

        if let Ok(result) = sum(&wrapped_a, None) {
            if let Some(sum_value) = result.downcast_ref::<f64>() {
                assert_eq!(*sum_value, a.sum());
            } else {
                panic!("Sum result is not the expected type");
            }
        } else {
            println!("Skipping sum test - operation not implemented");
        }

        if let Ok(result) = transpose(&wrapped_a) {
            if let Some(result_array) = result.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
                assert_eq!(result_array.as_array(), &a.t().to_owned());
            } else {
                panic!("Transpose result is not the expected type");
            }
        } else {
            println!("Skipping transpose test - operation not implemented");
        }

        let c = array![[1., 2., 3.], [4., 5., 6.]];
        let wrapped_c = NdarrayWrapper::new(c.clone());
        if let Ok(result) = reshape(&wrapped_c, &[6]) {
            if let Some(result_array) = result.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>() {
                let expected = c
                    .clone()
                    .into_shape_with_order(6)
                    .expect("Operation failed");
                assert_eq!(result_array.as_array(), &expected);
            } else {
                panic!("Reshape result is not the expected type");
            }
        } else {
            println!("Skipping reshape test - operation not implemented");
        }

        println!("All operations tested or skipped successfully");
    }
}