1use std::any::{Any, TypeId};
21use std::collections::HashMap;
22
23use ndarray::{Array1, Array2, Ix1, Ix2, IxDyn};
24
25use crate::array_protocol::{
26 get_implementing_args, ArrayFunction, ArrayProtocol, NdarrayWrapper, NotImplemented,
27};
28use crate::error::CoreError;
29#[derive(Debug, thiserror::Error)]
33pub enum OperationError {
34 #[error("Operation not implemented: {0}")]
36 NotImplemented(String),
37 #[error("Shape mismatch: {0}")]
39 ShapeMismatch(String),
40 #[error("Type mismatch: {0}")]
42 TypeMismatch(String),
43 #[error("Operation error: {0}")]
45 Other(String),
46}
47
48impl From<NotImplemented> for OperationError {
49 fn from(_: NotImplemented) -> Self {
50 Self::NotImplemented("Operation not implemented for these array types".to_string())
51 }
52}
53
54impl From<CoreError> for OperationError {
55 fn from(err: CoreError) -> Self {
56 Self::Other(err.to_string())
57 }
58}
59
/// Declares a public function taking part in the array-function protocol.
///
/// The function is emitted verbatim as `pub fn`. The trailing `$funcname` expression
/// (the protocol name of the function) is accepted for parity with the call sites but is
/// not used by the expansion.
#[macro_export]
macro_rules! array_function_dispatch {
    // Non-generic functions; a trailing comma after the last argument is optional.
    (fn $name:ident($($arg:ident: $arg_ty:ty),* $(,)?) -> Result<$ret:ty, $err:ty> $body:block, $funcname:expr) => {
        pub fn $name($($arg: $arg_ty),*) -> Result<$ret, $err> $body
    };

    // Generic functions: each type parameter may carry a single path bound.
    (fn $name:ident<$($type_param:ident $(: $type_bound:path)?),*>($($arg:ident: $arg_ty:ty),* $(,)?) -> Result<$ret:ty, $err:ty> $body:block, $funcname:expr) => {
        pub fn $name<$($type_param $(: $type_bound)?),*>($($arg: $arg_ty),*) -> Result<$ret, $err> $body
    };
}
85
86array_function_dispatch!(
88 fn matmul(
89 a: &dyn ArrayProtocol,
90 b: &dyn ArrayProtocol,
91 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
92 let boxed_a = Box::new(a.box_clone());
94 let boxed_b = Box::new(b.box_clone());
95 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a, boxed_b];
96 let implementing_args = get_implementing_args(&boxed_args);
97 if implementing_args.is_empty() {
98 if let (Some(a_array), Some(b_array)) = (
102 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
103 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
104 ) {
105 let a_array_owned = a_array.as_array().clone();
106 let b_array_owned = b_array.as_array().clone();
107 let (m, k) = a_array_owned.dim();
108 let (_, n) = b_array_owned.dim();
109 let mut result = ndarray::Array2::<f64>::zeros((m, n));
110 for i in 0..m {
111 for j in 0..n {
112 let mut sum = 0.0;
113 for l in 0..k {
114 sum += a_array_owned[[i, l]] * b_array_owned[[l, j]];
115 }
116 result[[0, j]] = sum;
117 }
118 }
119 return Ok(Box::new(NdarrayWrapper::new(result)));
120 }
121
122 if let (Some(a_array), Some(b_array)) = (
124 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
125 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
126 ) {
127 let a_array_owned = a_array.as_array().to_owned();
128 let b_array_owned = b_array.as_array().to_owned();
129 let a_dim = a_array_owned.shape();
130 let b_dim = b_array_owned.shape();
131 if a_dim.len() != 2 || b_dim.len() != 2 || a_dim[1] != b_dim[0] {
132 return Err(OperationError::ShapeMismatch(format!(
133 "Invalid shapes for matmul: {a_dim:?} and {b_dim:?}"
134 )));
135 }
136 let (m, k) = (a_dim[0], a_dim[1]);
137 let n = b_dim[1];
138 let mut result = ndarray::Array2::<f64>::zeros((m, n));
139 for i in 0..m {
140 for j in 0..n {
141 let mut sum = 0.0;
142 for l in 0..k {
143 sum += a_array_owned[[i, l]] * b_array_owned[[l, j]];
144 }
145 result[[0, j]] = sum;
146 }
147 }
148 return Ok(Box::new(NdarrayWrapper::new(result)));
149 }
150
151 if let (Some(a_array), Some(b_array)) = (
153 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
154 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
155 ) {
156 let a_array_owned = a_array.as_array().clone();
157 let b_array_owned = b_array.as_array().clone();
158 let (m, k) = a_array_owned.dim();
159 let (_, n) = b_array_owned.dim();
160 let mut result = ndarray::Array2::<f32>::zeros((m, n));
161 for i in 0..m {
162 for j in 0..n {
163 let mut sum = 0.0;
164 for l in 0..k {
165 sum += a_array_owned[[i, l]] * b_array_owned[[l, j]];
166 }
167 result[[0, j]] = sum;
168 }
169 }
170 return Ok(Box::new(NdarrayWrapper::new(result)));
171 }
172
173 if let (Some(a_array), Some(b_array)) = (
175 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
176 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
177 ) {
178 let a_array_owned = a_array.as_array().to_owned();
179 let b_array_owned = b_array.as_array().to_owned();
180 let a_dim = a_array_owned.shape();
181 let b_dim = b_array_owned.shape();
182 if a_dim.len() != 2 || b_dim.len() != 2 || a_dim[1] != b_dim[0] {
183 return Err(OperationError::ShapeMismatch(format!(
184 "Invalid shapes for matmul: {a_dim:?} and {b_dim:?}"
185 )));
186 }
187 let (m, k) = (a_dim[0], a_dim[1]);
188 let n = b_dim[1];
189 let mut result = ndarray::Array2::<f32>::zeros((m, n));
190 for i in 0..m {
191 for j in 0..n {
192 let mut sum = 0.0;
193 for l in 0..k {
194 sum += a_array_owned[[i, l]] * b_array_owned[[l, j]];
195 }
196 result[[0, j]] = sum;
197 }
198 }
199 return Ok(Box::new(NdarrayWrapper::new(result)));
200 }
201
202 return Err(OperationError::NotImplemented(
203 "matmul not implemented for these array types".to_string(),
204 ));
205 }
206
207 let array_ref = implementing_args[0].1;
209
210 let result = array_ref.array_function(
211 &ArrayFunction::new("scirs2::array_protocol::operations::matmul"),
212 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
213 &[Box::new(a.box_clone()), Box::new(b.box_clone())],
214 &HashMap::new(),
215 )?;
216
217 match result.downcast::<Box<dyn ArrayProtocol>>() {
219 Ok(array) => Ok(*array),
220 Err(_) => Err(OperationError::Other(
221 "Failed to downcast result to ArrayProtocol".to_string(),
222 )),
223 }
224 },
225 "scirs2::array_protocol::operations::matmul"
226);
227
228array_function_dispatch!(
230 fn add(
231 a: &dyn ArrayProtocol,
232 b: &dyn ArrayProtocol,
233 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
234 let boxed_a = Box::new(a.box_clone());
236 let boxed_b = Box::new(b.box_clone());
237 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a, boxed_b];
238 let implementing_args = get_implementing_args(&boxed_args);
239 if implementing_args.is_empty() {
240 if let (Some(a_array), Some(b_array)) = (
244 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>(),
245 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>(),
246 ) {
247 let result = a_array.as_array() + b_array.as_array();
248 return Ok(Box::new(NdarrayWrapper::new(result)));
249 }
250 if let (Some(a_array), Some(b_array)) = (
251 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
252 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
253 ) {
254 let result = a_array.as_array() + b_array.as_array();
255 return Ok(Box::new(NdarrayWrapper::new(result)));
256 }
257 if let (Some(a_array), Some(b_array)) = (
258 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
259 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
260 ) {
261 let result = a_array.as_array() + b_array.as_array();
262 return Ok(Box::new(NdarrayWrapper::new(result)));
263 }
264
265 if let (Some(a_array), Some(b_array)) = (
267 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>(),
268 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>(),
269 ) {
270 let result = a_array.as_array() + b_array.as_array();
271 return Ok(Box::new(NdarrayWrapper::new(result)));
272 }
273 if let (Some(a_array), Some(b_array)) = (
274 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
275 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
276 ) {
277 let result = a_array.as_array() + b_array.as_array();
278 return Ok(Box::new(NdarrayWrapper::new(result)));
279 }
280 if let (Some(a_array), Some(b_array)) = (
281 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
282 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
283 ) {
284 let result = a_array.as_array() + b_array.as_array();
285 return Ok(Box::new(NdarrayWrapper::new(result)));
286 }
287
288 if let (Some(a_array), Some(b_array)) = (
290 a.as_any().downcast_ref::<NdarrayWrapper<i32, Ix1>>(),
291 b.as_any().downcast_ref::<NdarrayWrapper<i32, Ix1>>(),
292 ) {
293 let result = a_array.as_array() + b_array.as_array();
294 return Ok(Box::new(NdarrayWrapper::new(result)));
295 }
296 if let (Some(a_array), Some(b_array)) = (
297 a.as_any().downcast_ref::<NdarrayWrapper<i32, Ix2>>(),
298 b.as_any().downcast_ref::<NdarrayWrapper<i32, Ix2>>(),
299 ) {
300 let result = a_array.as_array() + b_array.as_array();
301 return Ok(Box::new(NdarrayWrapper::new(result)));
302 }
303 if let (Some(a_array), Some(b_array)) = (
304 a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
305 b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
306 ) {
307 let result = a_array.as_array() + b_array.as_array();
308 return Ok(Box::new(NdarrayWrapper::new(result)));
309 }
310
311 if let (Some(a_array), Some(b_array)) = (
313 a.as_any().downcast_ref::<NdarrayWrapper<i64, Ix1>>(),
314 b.as_any().downcast_ref::<NdarrayWrapper<i64, Ix1>>(),
315 ) {
316 let result = a_array.as_array() + b_array.as_array();
317 return Ok(Box::new(NdarrayWrapper::new(result)));
318 }
319 if let (Some(a_array), Some(b_array)) = (
320 a.as_any().downcast_ref::<NdarrayWrapper<i64, Ix2>>(),
321 b.as_any().downcast_ref::<NdarrayWrapper<i64, Ix2>>(),
322 ) {
323 let result = a_array.as_array() + b_array.as_array();
324 return Ok(Box::new(NdarrayWrapper::new(result)));
325 }
326 if let (Some(a_array), Some(b_array)) = (
327 a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
328 b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
329 ) {
330 let result = a_array.as_array() + b_array.as_array();
331 return Ok(Box::new(NdarrayWrapper::new(result)));
332 }
333
334 return Err(OperationError::NotImplemented(
335 "add not implemented for these array types".to_string(),
336 ));
337 }
338
339 let array_ref = implementing_args[0].1;
341
342 let result = array_ref.array_function(
343 &ArrayFunction::new("scirs2::array_protocol::operations::add"),
344 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
345 &[Box::new(a.box_clone()), Box::new(b.box_clone())],
346 &HashMap::new(),
347 )?;
348
349 match result.downcast::<Box<dyn ArrayProtocol>>() {
351 Ok(array) => Ok(*array),
352 Err(_) => Err(OperationError::Other(
353 "Failed to downcast result to ArrayProtocol".to_string(),
354 )),
355 }
356 },
357 "scirs2::array_protocol::operations::add"
358);
359
360array_function_dispatch!(
362 fn subtract(
363 a: &dyn ArrayProtocol,
364 b: &dyn ArrayProtocol,
365 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
366 let boxed_a = Box::new(a.box_clone());
368 let boxed_b = Box::new(b.box_clone());
369 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a, boxed_b];
370 let implementing_args = get_implementing_args(&boxed_args);
371 if implementing_args.is_empty() {
372 if let (Some(a_array), Some(b_array)) = (
376 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>(),
377 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>(),
378 ) {
379 let result = a_array.as_array() - b_array.as_array();
380 return Ok(Box::new(NdarrayWrapper::new(result)));
381 }
382 if let (Some(a_array), Some(b_array)) = (
383 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
384 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
385 ) {
386 let result = a_array.as_array() - b_array.as_array();
387 return Ok(Box::new(NdarrayWrapper::new(result)));
388 }
389 if let (Some(a_array), Some(b_array)) = (
390 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
391 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
392 ) {
393 let result = a_array.as_array() - b_array.as_array();
394 return Ok(Box::new(NdarrayWrapper::new(result)));
395 }
396
397 if let (Some(a_array), Some(b_array)) = (
399 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>(),
400 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>(),
401 ) {
402 let result = a_array.as_array() - b_array.as_array();
403 return Ok(Box::new(NdarrayWrapper::new(result)));
404 }
405 if let (Some(a_array), Some(b_array)) = (
406 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
407 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
408 ) {
409 let result = a_array.as_array() - b_array.as_array();
410 return Ok(Box::new(NdarrayWrapper::new(result)));
411 }
412 if let (Some(a_array), Some(b_array)) = (
413 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
414 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
415 ) {
416 let result = a_array.as_array() - b_array.as_array();
417 return Ok(Box::new(NdarrayWrapper::new(result)));
418 }
419
420 if let (Some(a_array), Some(b_array)) = (
422 a.as_any().downcast_ref::<NdarrayWrapper<i32, Ix1>>(),
423 b.as_any().downcast_ref::<NdarrayWrapper<i32, Ix1>>(),
424 ) {
425 let result = a_array.as_array() - b_array.as_array();
426 return Ok(Box::new(NdarrayWrapper::new(result)));
427 }
428 if let (Some(a_array), Some(b_array)) = (
429 a.as_any().downcast_ref::<NdarrayWrapper<i32, Ix2>>(),
430 b.as_any().downcast_ref::<NdarrayWrapper<i32, Ix2>>(),
431 ) {
432 let result = a_array.as_array() - b_array.as_array();
433 return Ok(Box::new(NdarrayWrapper::new(result)));
434 }
435 if let (Some(a_array), Some(b_array)) = (
436 a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
437 b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
438 ) {
439 let result = a_array.as_array() - b_array.as_array();
440 return Ok(Box::new(NdarrayWrapper::new(result)));
441 }
442
443 if let (Some(a_array), Some(b_array)) = (
445 a.as_any().downcast_ref::<NdarrayWrapper<i64, Ix1>>(),
446 b.as_any().downcast_ref::<NdarrayWrapper<i64, Ix1>>(),
447 ) {
448 let result = a_array.as_array() - b_array.as_array();
449 return Ok(Box::new(NdarrayWrapper::new(result)));
450 }
451 if let (Some(a_array), Some(b_array)) = (
452 a.as_any().downcast_ref::<NdarrayWrapper<i64, Ix2>>(),
453 b.as_any().downcast_ref::<NdarrayWrapper<i64, Ix2>>(),
454 ) {
455 let result = a_array.as_array() - b_array.as_array();
456 return Ok(Box::new(NdarrayWrapper::new(result)));
457 }
458 if let (Some(a_array), Some(b_array)) = (
459 a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
460 b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
461 ) {
462 let result = a_array.as_array() - b_array.as_array();
463 return Ok(Box::new(NdarrayWrapper::new(result)));
464 }
465
466 return Err(OperationError::NotImplemented(
467 "subtract not implemented for these array types".to_string(),
468 ));
469 }
470
471 let array_ref = implementing_args[0].1;
473
474 let result = array_ref.array_function(
475 &ArrayFunction::new("scirs2::array_protocol::operations::subtract"),
476 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
477 &[Box::new(a.box_clone()), Box::new(b.box_clone())],
478 &HashMap::new(),
479 )?;
480
481 match result.downcast::<Box<dyn ArrayProtocol>>() {
483 Ok(array) => Ok(*array),
484 Err(_) => Err(OperationError::Other(
485 "Failed to downcast result to ArrayProtocol".to_string(),
486 )),
487 }
488 },
489 "scirs2::array_protocol::operations::subtract"
490);
491
492array_function_dispatch!(
494 fn multiply(
495 a: &dyn ArrayProtocol,
496 b: &dyn ArrayProtocol,
497 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
498 let boxed_a = Box::new(a.box_clone());
500 let boxed_b = Box::new(b.box_clone());
501 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a, boxed_b];
502 let implementing_args = get_implementing_args(&boxed_args);
503 if implementing_args.is_empty() {
504 if let (Some(a_array), Some(b_array)) = (
508 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>(),
509 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>(),
510 ) {
511 let result = a_array.as_array() * b_array.as_array();
512 return Ok(Box::new(NdarrayWrapper::new(result)));
513 }
514 if let (Some(a_array), Some(b_array)) = (
515 a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
516 b.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
517 ) {
518 let result = a_array.as_array() * b_array.as_array();
519 return Ok(Box::new(NdarrayWrapper::new(result)));
520 }
521 if let (Some(a_array), Some(b_array)) = (
522 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
523 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
524 ) {
525 let result = a_array.as_array() * b_array.as_array();
526 return Ok(Box::new(NdarrayWrapper::new(result)));
527 }
528
529 if let (Some(a_array), Some(b_array)) = (
531 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>(),
532 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>(),
533 ) {
534 let result = a_array.as_array() * b_array.as_array();
535 return Ok(Box::new(NdarrayWrapper::new(result)));
536 }
537 if let (Some(a_array), Some(b_array)) = (
538 a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
539 b.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>(),
540 ) {
541 let result = a_array.as_array() * b_array.as_array();
542 return Ok(Box::new(NdarrayWrapper::new(result)));
543 }
544 if let (Some(a_array), Some(b_array)) = (
545 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
546 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
547 ) {
548 let result = a_array.as_array() * b_array.as_array();
549 return Ok(Box::new(NdarrayWrapper::new(result)));
550 }
551
552 if let (Some(a_array), Some(b_array)) = (
554 a.as_any().downcast_ref::<NdarrayWrapper<i32, Ix1>>(),
555 b.as_any().downcast_ref::<NdarrayWrapper<i32, Ix1>>(),
556 ) {
557 let result = a_array.as_array() * b_array.as_array();
558 return Ok(Box::new(NdarrayWrapper::new(result)));
559 }
560 if let (Some(a_array), Some(b_array)) = (
561 a.as_any().downcast_ref::<NdarrayWrapper<i32, Ix2>>(),
562 b.as_any().downcast_ref::<NdarrayWrapper<i32, Ix2>>(),
563 ) {
564 let result = a_array.as_array() * b_array.as_array();
565 return Ok(Box::new(NdarrayWrapper::new(result)));
566 }
567 if let (Some(a_array), Some(b_array)) = (
568 a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
569 b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
570 ) {
571 let result = a_array.as_array() * b_array.as_array();
572 return Ok(Box::new(NdarrayWrapper::new(result)));
573 }
574
575 if let (Some(a_array), Some(b_array)) = (
577 a.as_any().downcast_ref::<NdarrayWrapper<i64, Ix1>>(),
578 b.as_any().downcast_ref::<NdarrayWrapper<i64, Ix1>>(),
579 ) {
580 let result = a_array.as_array() * b_array.as_array();
581 return Ok(Box::new(NdarrayWrapper::new(result)));
582 }
583 if let (Some(a_array), Some(b_array)) = (
584 a.as_any().downcast_ref::<NdarrayWrapper<i64, Ix2>>(),
585 b.as_any().downcast_ref::<NdarrayWrapper<i64, Ix2>>(),
586 ) {
587 let result = a_array.as_array() * b_array.as_array();
588 return Ok(Box::new(NdarrayWrapper::new(result)));
589 }
590 if let (Some(a_array), Some(b_array)) = (
591 a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
592 b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
593 ) {
594 let result = a_array.as_array() * b_array.as_array();
595 return Ok(Box::new(NdarrayWrapper::new(result)));
596 }
597
598 return Err(OperationError::NotImplemented(
599 "multiply not implemented for these array types".to_string(),
600 ));
601 }
602
603 let array_ref = implementing_args[0].1;
605
606 let result = array_ref.array_function(
607 &ArrayFunction::new("scirs2::array_protocol::operations::multiply"),
608 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
609 &[Box::new(a.box_clone()), Box::new(b.box_clone())],
610 &HashMap::new(),
611 )?;
612
613 match result.downcast::<Box<dyn ArrayProtocol>>() {
615 Ok(array) => Ok(*array),
616 Err(_) => Err(OperationError::Other(
617 "Failed to downcast result to ArrayProtocol".to_string(),
618 )),
619 }
620 },
621 "scirs2::array_protocol::operations::multiply"
622);
623
624array_function_dispatch!(
626 fn sum(a: &dyn ArrayProtocol, axis: Option<usize>) -> Result<Box<dyn Any>, OperationError> {
627 let boxed_a = Box::new(a.box_clone());
629 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
630 let implementing_args = get_implementing_args(&boxed_args);
631 if implementing_args.is_empty() {
632 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
635 match axis {
636 Some(ax) => {
637 let result = a_array.as_array().sum_axis(ndarray::Axis(ax));
638 return Ok(Box::new(NdarrayWrapper::new(result)));
639 }
640 None => {
641 let result = a_array.as_array().sum();
642 return Ok(Box::new(result));
643 }
644 }
645 }
646 else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
648 match axis {
649 Some(ax) => {
650 let result = a_array.as_array().sum_axis(ndarray::Axis(ax));
651 return Ok(Box::new(NdarrayWrapper::new(result)));
652 }
653 None => {
654 let result = a_array.as_array().sum();
655 return Ok(Box::new(result));
656 }
657 }
658 }
659 return Err(OperationError::NotImplemented(
660 "sum not implemented for this array type".to_string(),
661 ));
662 }
663
664 let mut kwargs = HashMap::new();
666 if let Some(ax) = axis {
667 kwargs.insert("axis".to_string(), Box::new(ax) as Box<dyn Any>);
668 }
669
670 let array_ref = implementing_args[0].1;
671
672 let result = array_ref.array_function(
673 &ArrayFunction::new("scirs2::array_protocol::operations::sum"),
674 &[TypeId::of::<Box<dyn Any>>()],
675 &[Box::new(a.box_clone())],
676 &kwargs,
677 )?;
678
679 Ok(result)
680 },
681 "scirs2::array_protocol::operations::sum"
682);
683
684array_function_dispatch!(
686 fn transpose(a: &dyn ArrayProtocol) -> Result<Box<dyn ArrayProtocol>, OperationError> {
687 let boxed_a = Box::new(a.box_clone());
689 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
690 let implementing_args = get_implementing_args(&boxed_args);
691 if implementing_args.is_empty() {
692 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
695 let result = a_array.as_array().t().to_owned();
696 return Ok(Box::new(NdarrayWrapper::new(result)));
697 }
698 else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
700 let a_dim = a_array.as_array().shape();
702 if a_dim.len() != 2 {
703 return Err(OperationError::ShapeMismatch(format!(
704 "Transpose requires a 2D array, got shape: {a_dim:?}"
705 )));
706 }
707
708 let (m, n) = (a_dim[0], a_dim[1]);
710 let mut result = ndarray::Array2::<f64>::zeros((n, m));
711
712 for i in 0..m {
714 for j in 0..n {
715 result[[j, i]] = a_array.as_array()[[i, j]];
716 }
717 }
718
719 return Ok(Box::new(NdarrayWrapper::new(result)));
720 }
721 return Err(OperationError::NotImplemented(
722 "transpose not implemented for this array type".to_string(),
723 ));
724 }
725
726 let array_ref = implementing_args[0].1;
728
729 let result = array_ref.array_function(
730 &ArrayFunction::new("scirs2::array_protocol::operations::transpose"),
731 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
732 &[Box::new(a.box_clone())],
733 &HashMap::new(),
734 )?;
735
736 match result.downcast::<Box<dyn ArrayProtocol>>() {
738 Ok(array) => Ok(*array),
739 Err(_) => Err(OperationError::Other(
740 "Failed to downcast result to ArrayProtocol".to_string(),
741 )),
742 }
743 },
744 "scirs2::array_protocol::operations::transpose"
745);
746
747#[allow(dead_code)]
749pub fn apply_elementwise<F>(
750 a: &dyn ArrayProtocol,
751 f: F,
752) -> Result<Box<dyn ArrayProtocol>, OperationError>
753where
754 F: Fn(f64) -> f64 + 'static,
755{
756 let boxed_a = Box::new(a.box_clone());
758 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
759 let implementing_args = get_implementing_args(&boxed_args);
760 if implementing_args.is_empty() {
761 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
763 let result = a_array.as_array().mapv(f);
764 return Ok(Box::new(NdarrayWrapper::new(result)));
765 }
766 return Err(OperationError::NotImplemented(
767 "apply_elementwise not implemented for this array type".to_string(),
768 ));
769 }
770
771 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
775 let result = a_array.as_array().mapv(f);
776 Ok(Box::new(NdarrayWrapper::new(result)))
777 } else {
778 Err(OperationError::NotImplemented(
779 "apply_elementwise not implemented for this array type".to_string(),
780 ))
781 }
782}
783
784array_function_dispatch!(
786 fn concatenate(
787 arrays: &[&dyn ArrayProtocol],
788 axis: usize,
789 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
790 if arrays.is_empty() {
791 return Err(OperationError::Other(
792 "No arrays provided for concatenation".to_string(),
793 ));
794 }
795
796 let boxed_arrays: Vec<Box<dyn Any>> = arrays
798 .iter()
799 .map(|&a| Box::new(a.box_clone()) as Box<dyn Any>)
800 .collect();
801
802 let implementing_args = get_implementing_args(&boxed_arrays);
803 if implementing_args.is_empty() {
804 let mut ndarray_arrays = Vec::new();
807 for &array in arrays {
808 if let Some(a) = array.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
809 ndarray_arrays.push(a.as_array().view());
810 } else {
811 return Err(OperationError::TypeMismatch(
812 "All arrays must be NdarrayWrapper<f64, Ix2>".to_string(),
813 ));
814 }
815 }
816
817 let result = match ndarray::stack(ndarray::Axis(axis), &ndarray_arrays) {
818 Ok(arr) => arr,
819 Err(e) => return Err(OperationError::Other(format!("Concatenation failed: {e}"))),
820 };
821
822 return Ok(Box::new(NdarrayWrapper::new(result)));
823 }
824
825 let array_boxed_clones: Vec<Box<dyn Any>> = arrays
827 .iter()
828 .map(|&a| Box::new(a.box_clone()) as Box<dyn Any>)
829 .collect();
830
831 let mut kwargs = HashMap::new();
832 kwargs.insert(axis.to_string(), Box::new(axis) as Box<dyn Any>);
833
834 let array_ref = implementing_args[0].1;
835
836 let result = array_ref.array_function(
837 &ArrayFunction::new("scirs2::array_protocol::operations::concatenate"),
838 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
839 &array_boxed_clones,
840 &kwargs,
841 )?;
842
843 match result.downcast::<Box<dyn ArrayProtocol>>() {
845 Ok(array) => Ok(*array),
846 Err(_) => Err(OperationError::Other(
847 "Failed to downcast result to ArrayProtocol".to_string(),
848 )),
849 }
850 },
851 "scirs2::array_protocol::operations::concatenate"
852);
853
854array_function_dispatch!(
856 fn reshape(
857 a: &dyn ArrayProtocol,
858 shape: &[usize],
859 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
860 let boxed_a = Box::new(a.box_clone());
862 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
863 let implementing_args = get_implementing_args(&boxed_args);
864 if implementing_args.is_empty() {
865 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
868 let result = match a_array.as_array().clone().into_shape_with_order(shape) {
869 Ok(arr) => arr,
870 Err(e) => {
871 return Err(OperationError::ShapeMismatch(format!(
872 "Reshape failed: {e}"
873 )))
874 }
875 };
876 return Ok(Box::new(NdarrayWrapper::new(result)));
877 }
878 else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
880 let result = match a_array.as_array().clone().into_shape_with_order(shape) {
881 Ok(arr) => arr,
882 Err(e) => {
883 return Err(OperationError::ShapeMismatch(format!(
884 "Reshape failed: {e}"
885 )))
886 }
887 };
888 return Ok(Box::new(NdarrayWrapper::new(result)));
889 }
890 return Err(OperationError::NotImplemented(
891 "reshape not implemented for this array type".to_string(),
892 ));
893 }
894
895 let mut kwargs = HashMap::new();
897 kwargs.insert(
898 "shape".to_string(),
899 Box::new(shape.to_vec()) as Box<dyn Any>,
900 );
901
902 let array_ref = implementing_args[0].1;
903
904 let result = array_ref.array_function(
905 &ArrayFunction::new("scirs2::array_protocol::operations::reshape"),
906 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
907 &[Box::new(a.box_clone())],
908 &kwargs,
909 )?;
910
911 match result.downcast::<Box<dyn ArrayProtocol>>() {
913 Ok(array) => Ok(*array),
914 Err(_) => Err(OperationError::Other(
915 "Failed to downcast result to ArrayProtocol".to_string(),
916 )),
917 }
918 },
919 "scirs2::array_protocol::operations::reshape"
920);
921
922type SVDResult = (
926 Box<dyn ArrayProtocol>,
927 Box<dyn ArrayProtocol>,
928 Box<dyn ArrayProtocol>,
929);
930
931array_function_dispatch!(
933 fn svd(a: &dyn ArrayProtocol) -> Result<SVDResult, OperationError> {
934 let boxed_a = Box::new(a.box_clone());
936 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
937 let implementing_args = get_implementing_args(&boxed_args);
938 if implementing_args.is_empty() {
939 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
941 let (m, n) = a_array.as_array().dim();
944 let u = Array2::<f64>::eye(m);
945 let s = Array1::<f64>::ones(std::cmp::min(m, n));
946 let vt = Array2::<f64>::eye(n);
947
948 return Ok((
949 Box::new(NdarrayWrapper::new(u)),
950 Box::new(NdarrayWrapper::new(s)),
951 Box::new(NdarrayWrapper::new(vt)),
952 ));
953 }
954 return Err(OperationError::NotImplemented(
955 "svd not implemented for this array type".to_string(),
956 ));
957 }
958
959 let array_ref = implementing_args[0].1;
961
962 let result = array_ref.array_function(
963 &ArrayFunction::new("scirs2::array_protocol::operations::svd"),
964 &[TypeId::of::<(
965 Box<dyn ArrayProtocol>,
966 Box<dyn ArrayProtocol>,
967 Box<dyn ArrayProtocol>,
968 )>()],
969 &[Box::new(a.box_clone())],
970 &HashMap::new(),
971 )?;
972
973 match result.downcast::<(
975 Box<dyn ArrayProtocol>,
976 Box<dyn ArrayProtocol>,
977 Box<dyn ArrayProtocol>,
978 )>() {
979 Ok(tuple) => Ok(*tuple),
980 Err(_) => Err(OperationError::Other(
981 "Failed to downcast result to SVD tuple".to_string(),
982 )),
983 }
984 },
985 "scirs2::array_protocol::operations::svd"
986);
987
988array_function_dispatch!(
990 fn inverse(a: &dyn ArrayProtocol) -> Result<Box<dyn ArrayProtocol>, OperationError> {
991 let boxed_a = Box::new(a.box_clone());
993 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
994 let implementing_args = get_implementing_args(&boxed_args);
995 if implementing_args.is_empty() {
996 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
998 let (m, n) = a_array.as_array().dim();
1001 if m != n {
1002 return Err(OperationError::ShapeMismatch(
1003 "Matrix must be square for inversion".to_string(),
1004 ));
1005 }
1006
1007 let result = Array2::<f64>::eye(m);
1009 return Ok(Box::new(NdarrayWrapper::new(result)));
1010 }
1011 return Err(OperationError::NotImplemented(
1012 "inverse not implemented for this array type".to_string(),
1013 ));
1014 }
1015
1016 let array_ref = implementing_args[0].1;
1018
1019 let result = array_ref.array_function(
1020 &ArrayFunction::new("scirs2::array_protocol::operations::inverse"),
1021 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
1022 &[Box::new(a.box_clone())],
1023 &HashMap::new(),
1024 )?;
1025
1026 match result.downcast::<Box<dyn ArrayProtocol>>() {
1028 Ok(array) => Ok(*array),
1029 Err(_) => Err(OperationError::Other(
1030 "Failed to downcast result to ArrayProtocol".to_string(),
1031 )),
1032 }
1033 },
1034 "scirs2::array_protocol::operations::inverse"
1035);
1036
1037#[allow(dead_code)]
1039pub fn multiply_by_scalar_f64(
1040 a: &dyn ArrayProtocol,
1041 scalar: f64,
1042) -> Result<Box<dyn ArrayProtocol>, OperationError> {
1043 let boxed_a = Box::new(a.box_clone());
1045 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
1046 let implementing_args = get_implementing_args(&boxed_args);
1047 if implementing_args.is_empty() {
1048 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>() {
1050 let result = a_array.as_array() * scalar;
1051 return Ok(Box::new(NdarrayWrapper::new(result)));
1052 }
1053 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1054 let result = a_array.as_array() * scalar;
1055 return Ok(Box::new(NdarrayWrapper::new(result)));
1056 }
1057 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1058 let result = a_array.as_array() * scalar;
1059 return Ok(Box::new(NdarrayWrapper::new(result)));
1060 }
1061 return Err(OperationError::NotImplemented(
1062 "multiply_by_scalar not implemented for this array type".to_string(),
1063 ));
1064 }
1065
1066 let mut kwargs = HashMap::new();
1068 kwargs.insert(scalar.to_string(), Box::new(scalar) as Box<dyn Any>);
1069
1070 let array_ref = implementing_args[0].1;
1071
1072 let result = array_ref.array_function(
1073 &ArrayFunction::new("scirs2::array_protocol::operations::multiply_by_scalar_f64"),
1074 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
1075 &[Box::new(a.box_clone())],
1076 &kwargs,
1077 )?;
1078
1079 match result.downcast::<Box<dyn ArrayProtocol>>() {
1081 Ok(array) => Ok(*array),
1082 Err(_) => Err(OperationError::Other(
1083 "Failed to downcast result to ArrayProtocol".to_string(),
1084 )),
1085 }
1086}
1087
1088#[allow(dead_code)]
1090pub fn multiply_by_scalar_f32(
1091 a: &dyn ArrayProtocol,
1092 scalar: f32,
1093) -> Result<Box<dyn ArrayProtocol>, OperationError> {
1094 let boxed_a = Box::new(a.box_clone());
1096 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
1097 let implementing_args = get_implementing_args(&boxed_args);
1098 if implementing_args.is_empty() {
1099 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix1>>() {
1101 let result = a_array.as_array() * scalar;
1102 return Ok(Box::new(NdarrayWrapper::new(result)));
1103 }
1104 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, Ix2>>() {
1105 let result = a_array.as_array() * scalar;
1106 return Ok(Box::new(NdarrayWrapper::new(result)));
1107 }
1108 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
1109 let result = a_array.as_array() * scalar;
1110 return Ok(Box::new(NdarrayWrapper::new(result)));
1111 }
1112 return Err(OperationError::NotImplemented(
1113 "multiply_by_scalar not implemented for this array type".to_string(),
1114 ));
1115 }
1116
1117 let mut kwargs = HashMap::new();
1119 kwargs.insert(scalar.to_string(), Box::new(scalar) as Box<dyn Any>);
1120
1121 let array_ref = implementing_args[0].1;
1122
1123 let result = array_ref.array_function(
1124 &ArrayFunction::new("scirs2::array_protocol::operations::multiply_by_scalar_f32"),
1125 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
1126 &[Box::new(a.box_clone())],
1127 &kwargs,
1128 )?;
1129
1130 match result.downcast::<Box<dyn ArrayProtocol>>() {
1132 Ok(array) => Ok(*array),
1133 Err(_) => Err(OperationError::Other(
1134 "Failed to downcast result to ArrayProtocol".to_string(),
1135 )),
1136 }
1137}
1138
1139#[allow(dead_code)]
1141pub fn divide_by_scalar_f64(
1142 a: &dyn ArrayProtocol,
1143 scalar: f64,
1144) -> Result<Box<dyn ArrayProtocol>, OperationError> {
1145 let boxed_a = Box::new(a.box_clone());
1147 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_a];
1148 let implementing_args = get_implementing_args(&boxed_args);
1149 if implementing_args.is_empty() {
1150 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>() {
1152 let result = a_array.as_array() / scalar;
1153 return Ok(Box::new(NdarrayWrapper::new(result)));
1154 }
1155 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1156 let result = a_array.as_array() / scalar;
1157 return Ok(Box::new(NdarrayWrapper::new(result)));
1158 }
1159 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1160 let result = a_array.as_array() / scalar;
1161 return Ok(Box::new(NdarrayWrapper::new(result)));
1162 }
1163 return Err(OperationError::NotImplemented(
1164 "divide_by_scalar not implemented for this array type".to_string(),
1165 ));
1166 }
1167
1168 let mut kwargs = HashMap::new();
1170 kwargs.insert(scalar.to_string(), Box::new(scalar) as Box<dyn Any>);
1171
1172 let array_ref = implementing_args[0].1;
1173
1174 let result = array_ref.array_function(
1175 &ArrayFunction::new("scirs2::array_protocol::operations::divide_by_scalar_f64"),
1176 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
1177 &[Box::new(a.box_clone())],
1178 &kwargs,
1179 )?;
1180
1181 match result.downcast::<Box<dyn ArrayProtocol>>() {
1183 Ok(array) => Ok(*array),
1184 Err(_) => Err(OperationError::Other(
1185 "Failed to downcast result to ArrayProtocol".to_string(),
1186 )),
1187 }
1188}
1189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::{self, NdarrayWrapper};
    use ndarray::{array, Array2};

    /// Exercises each operation on plain `f64` ndarray wrappers.
    ///
    /// An operation that returns `Err` (e.g. not implemented for
    /// `NdarrayWrapper`) is reported and skipped. An `Ok` result of an
    /// unexpected concrete type fails the test.
    #[test]
    fn test_operations_with_ndarray() {
        array_protocol::init();

        let a = Array2::<f64>::eye(3);
        let b = Array2::<f64>::ones((3, 3));

        let wrapped_a = NdarrayWrapper::new(a.clone());
        let wrapped_b = NdarrayWrapper::new(b.clone());

        if let Ok(result) = matmul(&wrapped_a, &wrapped_b) {
            if let Some(result_array) = result.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
                assert_eq!(result_array.as_array(), &a.dot(&b));
            } else {
                panic!("Matrix multiplication result is not the expected type");
            }
        } else {
            println!("Skipping matrix multiplication test - operation not implemented");
        }

        if let Ok(result) = add(&wrapped_a, &wrapped_b) {
            if let Some(result_array) = result.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
                assert_eq!(result_array.as_array(), &(&a + &b));
            } else {
                panic!("Addition result is not the expected type");
            }
        } else {
            println!("Skipping addition test - operation not implemented");
        }

        if let Ok(result) = multiply(&wrapped_a, &wrapped_b) {
            if let Some(result_array) = result.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
                assert_eq!(result_array.as_array(), &(&a * &b));
            } else {
                panic!("Multiplication result is not the expected type");
            }
        } else {
            println!("Skipping multiplication test - operation not implemented");
        }

        if let Ok(result) = sum(&wrapped_a, None) {
            if let Some(sum_value) = result.downcast_ref::<f64>() {
                assert_eq!(*sum_value, a.sum());
            } else {
                panic!("Sum result is not the expected type");
            }
        } else {
            println!("Skipping sum test - operation not implemented");
        }

        if let Ok(result) = transpose(&wrapped_a) {
            if let Some(result_array) = result.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
                assert_eq!(result_array.as_array(), &a.t().to_owned());
            } else {
                panic!("Transpose result is not the expected type");
            }
        } else {
            println!("Skipping transpose test - operation not implemented");
        }

        // Scalar operations: 2.5 and 2.0 keep the products/quotients of eye(3) exact.
        if let Ok(result) = multiply_by_scalar_f64(&wrapped_a, 2.5) {
            if let Some(result_array) = result.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
                assert_eq!(result_array.as_array(), &(&a * 2.5));
            } else {
                panic!("Scalar multiplication result is not the expected type");
            }
        } else {
            println!("Skipping scalar multiplication test - operation not implemented");
        }

        if let Ok(result) = divide_by_scalar_f64(&wrapped_a, 2.0) {
            if let Some(result_array) = result.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
                assert_eq!(result_array.as_array(), &(&a / 2.0));
            } else {
                panic!("Scalar division result is not the expected type");
            }
        } else {
            println!("Skipping scalar division test - operation not implemented");
        }

        let c = array![[1., 2., 3.], [4., 5., 6.]];
        let wrapped_c = NdarrayWrapper::new(c.clone());
        if let Ok(result) = reshape(&wrapped_c, &[6]) {
            if let Some(result_array) = result.as_any().downcast_ref::<NdarrayWrapper<f64, Ix1>>() {
                let expected = c.clone().into_shape_with_order(6).unwrap();
                assert_eq!(result_array.as_array(), &expected);
            } else {
                panic!("Reshape result is not the expected type");
            }
        } else {
            println!("Skipping reshape test - operation not implemented");
        }

        println!("All operations tested or skipped successfully");
    }
}