scirs2_core/array_protocol/ml_ops.rs

// Copyright (c) 2025, `SciRS2` Team
//
// Licensed under either of
//
// * Apache License, Version 2.0
//   (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
// * MIT license
//   (LICENSE-MIT or http://opensource.org/licenses/MIT)
//
// at your option.
//

//! Machine learning operations using the array protocol.
//!
//! This module provides implementations of various machine learning operations
//! using the array protocol, such as activation functions, convolution, and
//! pooling.
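//!
//! A minimal usage sketch (illustrative only; the import paths and the exact
//! surface generated by `array_function_dispatch!` are assumptions, so the
//! example is not compiled as a doctest):
//!
//! ```ignore
//! use ndarray::array;
//! use scirs2_core::array_protocol::NdarrayWrapper;
//! use scirs2_core::array_protocol::ml_ops::{activation, ActivationFunc};
//!
//! // Wrap a plain ndarray so it participates in the array protocol.
//! let x = array![[-1.0_f64, 2.0], [3.0, -4.0]].into_dyn();
//! let wrapped = NdarrayWrapper::new(x);
//!
//! // Dispatches to a registered protocol implementation if one exists,
//! // otherwise falls back to the ndarray implementation in this module.
//! let relu = activation(&wrapped, ActivationFunc::ReLU)?;
//! ```
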
use std::any::{Any, TypeId};
use std::collections::HashMap;

use ndarray::{Array, Axis, Ix1, Ix2, Ix3, Ix4, IxDyn};
use rand::prelude::*;
use rand::Rng;

use crate::array_protocol::operations::OperationError;
use crate::array_protocol::{
    array_function_dispatch, get_implementing_args, ArrayProtocol, NdarrayWrapper,
};

/// Activation function types.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ActivationFunc {
    /// Rectified Linear Unit: f(x) = max(0, x)
    ReLU,

    /// Sigmoid function: f(x) = 1 / (1 + exp(-x))
    Sigmoid,

    /// Hyperbolic tangent: f(x) = tanh(x)
    Tanh,

    /// Softmax function (applied along the last dimension)
    Softmax,

    /// Leaky ReLU: f(x) = max(alpha * x, x)
    LeakyReLU(f64),
}

/// Apply an activation function to an array.
#[allow(dead_code)]
fn apply_activation(
    x: &ndarray::ArrayBase<ndarray::ViewRepr<&f64>, IxDyn>,
    func: ActivationFunc,
) -> Array<f64, IxDyn> {
    match func {
        ActivationFunc::ReLU => x.mapv(|v| v.max(0.0)),
        ActivationFunc::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
        ActivationFunc::Tanh => x.mapv(|v| v.tanh()),
        ActivationFunc::LeakyReLU(alpha) => x.mapv(|v| if v > 0.0 { v } else { alpha * v }),
        ActivationFunc::Softmax => {
            // Apply softmax along the last dimension: every lane along that
            // axis is normalized independently.
            let mut result = Array::zeros(x.raw_dim());
            let last_dim = x.ndim() - 1;

            for (x_lane, mut out_lane) in x
                .lanes(Axis(last_dim))
                .into_iter()
                .zip(result.lanes_mut(Axis(last_dim)))
            {
                // Subtract the lane maximum before exponentiating for
                // numerical stability.
                let max_val = x_lane.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let exp_x = x_lane.mapv(|v| (v - max_val).exp());
                let sum_exp = exp_x.sum();
                out_lane.assign(&(exp_x / sum_exp));
            }

            result
        }
    }
}

// Define machine learning operations using the array protocol
array_function_dispatch!(
    fn activation(
        x: &dyn ArrayProtocol,
        func: ActivationFunc,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Get implementing args
        let boxed_x = Box::new(x.box_clone());
        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_x];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            if let Some(x_array) = x.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
                let x_array = x_array.as_array();
                let result = apply_activation(&x_array.view(), func);
                return Ok(Box::new(NdarrayWrapper::new(result)));
            }
            return Err(OperationError::NotImplemented(
                "activation not implemented for this array type".to_string(),
            ));
        }

        // Delegate to the implementation
        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::activation",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(x.box_clone())],
            &HashMap::new(),
        )?;

        // Try to downcast the result
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::activation"
);

array_function_dispatch!(
    fn conv2d(
        input: &dyn ArrayProtocol,
        filters: &dyn ArrayProtocol,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Get implementing args
        let boxed_input = Box::new(input.box_clone());
        let boxed_filters = Box::new(filters.box_clone());
        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_input, boxed_filters];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            // This is a simplified reference implementation; practical convolutions are far more optimized
            if let (Some(input_array), Some(filters_array)) = (
                input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                filters.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
            ) {
                let input = input_array.as_array();
                let filters = filters_array.as_array();

                // Get dimensions
                let batch_size = input.shape()[0];
                let input_height = input.shape()[1];
                let input_width = input.shape()[2];
                let input_channels = input.shape()[3];

                let filter_height = filters.shape()[0];
                let filter_width = filters.shape()[1];
                let filter_in_channels = filters.shape()[2];
                let filter_out_channels = filters.shape()[3];

                // Check dimensions
                if input_channels != filter_in_channels {
                    return Err(OperationError::ShapeMismatch(format!(
                        "Input channels ({input_channels}) don't match filter input channels ({filter_in_channels})"
                    )));
                }

                // Calculate output dimensions
                let out_height = (input_height - filter_height + 2 * padding.0) / stride.0 + 1;
                let out_width = (input_width - filter_width + 2 * padding.1) / stride.1 + 1;
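                // For example, a 32x32 input with a 3x3 filter, stride (1, 1) and
                // padding (1, 1) yields (32 - 3 + 2 * 1) / 1 + 1 = 32, i.e. the
                // spatial size is preserved ("same" padding).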

                // Create output array
                let mut output: Array<f64, Ix4> =
                    Array::zeros((batch_size, out_height, out_width, filter_out_channels));

                // Perform convolution using a basic sliding-window approach
                for b in 0..batch_size {
                    for out_c in 0..filter_out_channels {
                        for out_h in 0..out_height {
                            for out_w in 0..out_width {
                                let mut sum = 0.0;

                                // Convolution over the filter window
                                for f_h in 0..filter_height {
                                    for f_w in 0..filter_width {
                                        for in_c in 0..input_channels {
                                            // Calculate input coordinates with padding
                                            let in_h = (out_h * stride.0) as i32 + f_h as i32
                                                - padding.0 as i32;
                                            let in_w = (out_w * stride.1) as i32 + f_w as i32
                                                - padding.1 as i32;

                                            // Check bounds (zero padding)
                                            if in_h >= 0
                                                && in_h < input_height as i32
                                                && in_w >= 0
                                                && in_w < input_width as i32
                                            {
                                                let input_val =
                                                    input[[b, in_h as usize, in_w as usize, in_c]];
                                                let filter_val = filters[[f_h, f_w, in_c, out_c]];
                                                sum += input_val * filter_val;
                                            }
                                        }
                                    }
                                }

                                output[[b, out_h, out_w, out_c]] = sum;
                            }
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "conv2d not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementation
        let mut kwargs = HashMap::new();
        kwargs.insert("stride".to_string(), Box::new(stride) as Box<dyn Any>);
        kwargs.insert("padding".to_string(), Box::new(padding) as Box<dyn Any>);

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new("scirs2::array_protocol::ml_ops::conv2d"),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(input.box_clone()), Box::new(filters.box_clone())],
            &kwargs,
        )?;

        // Try to downcast the result
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::conv2d"
);

array_function_dispatch!(
    fn max_pool2d(
        input: &dyn ArrayProtocol,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Get implementing args
        let boxed_input = Box::new(input.box_clone());
        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_input];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            if let Some(input_array) = input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>() {
                let input = input_array.as_array();

                // Get dimensions
                let batch_size = input.shape()[0];
                let input_height = input.shape()[1];
                let input_width = input.shape()[2];
                let channels = input.shape()[3];

                // Calculate output dimensions
                let out_height = (input_height - kernel_size.0 + 2 * padding.0) / stride.0 + 1;
                let out_width = (input_width - kernel_size.1 + 2 * padding.1) / stride.1 + 1;

                // Create output array
                let mut output: Array<f64, Ix4> =
                    Array::zeros((batch_size, out_height, out_width, channels));

                // Perform max pooling
                for b in 0..batch_size {
                    for c in 0..channels {
                        for out_h in 0..out_height {
                            for out_w in 0..out_width {
                                let mut max_val = f64::NEG_INFINITY;

                                // Pool over the kernel window
                                for k_h in 0..kernel_size.0 {
                                    for k_w in 0..kernel_size.1 {
                                        // Calculate input coordinates with padding
                                        let in_h = (out_h * stride.0) as i32 + k_h as i32
                                            - padding.0 as i32;
                                        let in_w = (out_w * stride.1) as i32 + k_w as i32
                                            - padding.1 as i32;

                                        // Check bounds
                                        if in_h >= 0
                                            && in_h < input_height as i32
                                            && in_w >= 0
                                            && in_w < input_width as i32
                                        {
                                            let val = input[[b, in_h as usize, in_w as usize, c]];
                                            if val > max_val {
                                                max_val = val;
                                            }
                                        }
                                    }
                                }

                                // Use 0.0 if the window contained no valid values (entirely in padding)
                                output[[b, out_h, out_w, c]] = if max_val == f64::NEG_INFINITY {
                                    0.0
                                } else {
                                    max_val
                                };
                            }
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "max_pool2d not implemented for this array type".to_string(),
            ));
        }

        // Delegate to the implementation
        let mut kwargs = HashMap::new();
        kwargs.insert(
            "kernel_size".to_string(),
            Box::new(kernel_size) as Box<dyn Any>,
        );
        kwargs.insert("stride".to_string(), Box::new(stride) as Box<dyn Any>);
        kwargs.insert("padding".to_string(), Box::new(padding) as Box<dyn Any>);

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::max_pool2d",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(input.box_clone())],
            &kwargs,
        )?;

        // Try to downcast the result
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::max_pool2d"
);

array_function_dispatch!(
    fn batch_norm(
        input: &dyn ArrayProtocol,
        scale: &dyn ArrayProtocol,
        offset: &dyn ArrayProtocol,
        mean: &dyn ArrayProtocol,
        variance: &dyn ArrayProtocol,
        epsilon: f64,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Get implementing args - convert to Box<dyn Any>
        let boxed_args: Vec<Box<dyn Any>> = vec![
            Box::new(input.box_clone()),
            Box::new(scale.box_clone()),
            Box::new(offset.box_clone()),
            Box::new(mean.box_clone()),
            Box::new(variance.box_clone()),
        ];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            if let (
                Some(input_array),
                Some(scale_array),
                Some(offset_array),
                Some(mean_array),
                Some(variance_array),
            ) = (
                input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                scale.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                offset.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                mean.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                variance
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
            ) {
                let input = input_array.as_array();
                let scale = scale_array.as_array();
                let offset = offset_array.as_array();
                let mean = mean_array.as_array();
                let variance = variance_array.as_array();

                // Get dimensions
                let batch_size = input.shape()[0];
                let height = input.shape()[1];
                let width = input.shape()[2];
                let channels = input.shape()[3];

                // Check dimensions
                if scale.shape()[0] != channels
                    || offset.shape()[0] != channels
                    || mean.shape()[0] != channels
                    || variance.shape()[0] != channels
                {
                    return Err(OperationError::ShapeMismatch(
                        "Scale, offset, mean, and variance must match the number of channels"
                            .to_string(),
                    ));
                }

                // Create output array with same shape as input
                let mut output: Array<f64, Ix4> = Array::zeros(input.raw_dim());

                // Perform batch normalization
                // For each channel, normalize using the formula:
                // y = scale * (x - mean) / sqrt(variance + epsilon) + offset
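                // For example, with x = 2.0, mean = 1.0, variance = 4.0,
                // epsilon = 1e-5, scale = 0.5 and offset = 0.1:
                // y = 0.5 * (2.0 - 1.0) / sqrt(4.0 + 1e-5) + 0.1 ≈ 0.35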

                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            for c in 0..channels {
                                let x = input[[b, h, w, c]];
                                let m = mean[[c]];
                                let v = variance[[c]];
                                let s = scale[[c]];
                                let o = offset[[c]];

                                // Normalize: (x - mean) / sqrt(variance + epsilon)
                                let normalized = (x - m) / (v + epsilon).sqrt();

                                // Scale and shift: scale * normalized + offset
                                let result = s * normalized + o;

                                output[[b, h, w, c]] = result;
                            }
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "batch_norm not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementation
        let mut kwargs = HashMap::new();
        kwargs.insert("epsilon".to_string(), Box::new(epsilon) as Box<dyn Any>);

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::batch_norm",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[
                Box::new(input.box_clone()),
                Box::new(scale.box_clone()),
                Box::new(offset.box_clone()),
                Box::new(mean.box_clone()),
                Box::new(variance.box_clone()),
            ],
            &kwargs,
        )?;

        // Try to downcast the result
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::batch_norm"
);

array_function_dispatch!(
    fn cross_entropy(
        logits: &dyn ArrayProtocol,
        labels: &dyn ArrayProtocol,
        reduction: &str,
    ) -> Result<Box<dyn Any>, OperationError> {
        // Get implementing args
        let boxed_args: Vec<Box<dyn Any>> =
            vec![Box::new(logits.box_clone()), Box::new(labels.box_clone())];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            if let (Some(logits_array), Some(labels_array)) = (
                logits.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
                labels.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
            ) {
                let logits = logits_array.as_array();
                let labels = labels_array.as_array();

                // Check shapes
                if logits.shape() != labels.shape() {
                    return Err(OperationError::ShapeMismatch(format!(
                        "Logits shape {logits_shape:?} doesn't match labels shape {labels_shape:?}",
                        logits_shape = logits.shape(),
                        labels_shape = labels.shape()
                    )));
                }

                // Apply softmax to logits
                let mut softmax = Array::zeros(logits.raw_dim());

                // For each sample in the batch
                for (i, sample) in logits.outer_iter().enumerate() {
                    let max_val = sample.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                    let exp_x = sample.mapv(|v| (v - max_val).exp());
                    let sum_exp = exp_x.sum();

                    for (j, val) in exp_x.iter().enumerate() {
                        softmax[[i, j]] = val / sum_exp;
                    }
                }

                // Compute cross-entropy loss
                // loss = -sum(labels * log(softmax))
                let mut sample_losses = Array::zeros(logits.shape()[0]);

                for (i, (s, l)) in softmax.outer_iter().zip(labels.outer_iter()).enumerate() {
                    let mut loss = 0.0;
                    for (s_val, l_val) in s.iter().zip(l.iter()) {
                        // Add small epsilon to avoid log(0)
                        loss -= l_val * (s_val + 1e-10).ln();
                    }
                    sample_losses[i] = loss;
                }

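                // For one-hot labels the per-sample loss reduces to -ln(p_correct):
                // e.g. a predicted probability of 0.7 for the true class gives a
                // loss of -ln(0.7) ≈ 0.357.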
                // Apply reduction
                let loss = match reduction {
                    "none" => sample_losses,
                    "mean" => {
                        let mean = sample_losses.sum() / sample_losses.len() as f64;
                        // Use Array1 instead of Array0 to keep the type consistent
                        Array::from_elem(Ix1(1), mean)
                    }
                    "sum" => {
                        let sum = sample_losses.sum();
                        // Use Array1 instead of Array0 to keep the type consistent
                        Array::from_elem(Ix1(1), sum)
                    }
                    _ => {
                        return Err(OperationError::Other(format!(
                            "Unknown reduction method: {reduction}"
                        )))
                    }
                };

                return Ok(Box::new(loss) as Box<dyn Any>);
            }
            return Err(OperationError::NotImplemented(
                "cross_entropy not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementation
        let mut kwargs = HashMap::new();
        kwargs.insert(
            "reduction".to_string(),
            Box::new(reduction.to_string()) as Box<dyn Any>,
        );

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::cross_entropy",
            ),
            &[TypeId::of::<Box<dyn Any>>()],
            &[Box::new(logits.box_clone()), Box::new(labels.box_clone())],
            &kwargs,
        )?;

        Ok(result)
    },
    "scirs2::array_protocol::ml_ops::cross_entropy"
);

array_function_dispatch!(
    fn dropout(
        input: &dyn ArrayProtocol,
        rate: f64,
        training: bool,
        seed: Option<u64>,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Get implementing args
        let boxed_args: Vec<Box<dyn Any>> = vec![Box::new(input.box_clone())];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            if let Some(input_array) = input.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
                let input = input_array.as_array();

                if !training {
                    // During inference, return the input unchanged; the scaling
                    // below is only applied during training (inverted dropout).
                    return Ok(Box::new(NdarrayWrapper::new(input.clone())));
                }

                // Create a binary mask where each element is kept with probability (1 - rate)
                let mut rng = match seed {
                    Some(s) => rand::rngs::StdRng::seed_from_u64(s),
                    None => {
                        let mut rng = rand::rng();
                        // Get a random seed from rng and create a new StdRng
                        let random_seed: u64 = rng.random();
                        rand::rngs::StdRng::seed_from_u64(random_seed)
                    }
                };

                let mask = Array::from_shape_fn(input.raw_dim(), |_| {
                    if rng.random::<f64>() >= rate {
                        1.0
                    } else {
                        0.0
                    }
                });

                // Scale by 1/(1-rate) to maintain the expected value during training
                let scale = 1.0 / (1.0 - rate);
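                // For example, rate = 0.25 keeps each unit with probability 0.75
                // and scales survivors by 1 / 0.75 ≈ 1.33, so the expected value
                // of each activation matches the (unscaled) inference path.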
                let result = input.clone() * &mask * scale;

                return Ok(Box::new(NdarrayWrapper::new(result)));
            }
            return Err(OperationError::NotImplemented(
                "dropout not implemented for this array type".to_string(),
            ));
        }

        // Delegate to the implementation
        let mut kwargs = HashMap::new();
        kwargs.insert("rate".to_string(), Box::new(rate) as Box<dyn Any>);
        kwargs.insert("training".to_string(), Box::new(training) as Box<dyn Any>);
        if let Some(s) = seed {
            kwargs.insert("seed".to_string(), Box::new(s) as Box<dyn Any>);
        }

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new("scirs2::array_protocol::ml_ops::dropout"),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(input.box_clone())],
            &kwargs,
        )?;

        // Try to downcast the result
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::dropout"
);

array_function_dispatch!(
    fn self_attention(
        queries: &dyn ArrayProtocol,
        keys: &dyn ArrayProtocol,
        values: &dyn ArrayProtocol,
        mask: Option<&dyn ArrayProtocol>,
        scale: Option<f64>,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Get implementing args
        let mut boxed_args: Vec<Box<dyn Any>> = vec![
            Box::new(queries.box_clone()),
            Box::new(keys.box_clone()),
            Box::new(values.box_clone()),
        ];
        if let Some(m) = mask {
            boxed_args.push(Box::new(m.box_clone()));
        }

        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            if let (Some(q_array), Some(k_array), Some(v_array)) = (
                queries.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                keys.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                values.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
            ) {
                let q = q_array.as_array();
                let k = k_array.as_array();
                let v = v_array.as_array();

                // Get dimensions
                // q, k, v should have shape [batch_size, seq_len, num_heads, d_k]
                let batch_size = q.shape()[0];
                let q_len = q.shape()[1];
                let num_heads = q.shape()[2];
                let d_k = q.shape()[3];

                let k_len = k.shape()[1];

                // Check dimensions
                if k.shape()[0] != batch_size
                    || k.shape()[2] != num_heads
                    || k.shape()[3] != d_k
                    || v.shape()[0] != batch_size
                    || v.shape()[1] != k_len
                    || v.shape()[2] != num_heads
                {
                    return Err(OperationError::ShapeMismatch(
                        "Incompatible shapes for self-attention".to_string(),
                    ));
                }

                // Implement self-attention:
                // 1. scores = matmul(q, k.transpose) * scale_factor (default: 1/sqrt(d_k))
                // 2. if mask: scores = scores.masked_fill(mask == 0, -inf)
                // 3. attention = softmax(scores)
                // 4. output = matmul(attention, v)
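                // In matrix form this is standard scaled dot-product attention,
                // Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, computed here
                // per head with explicit loops and then averaged across heads.
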
                // Scale factor defaults to 1/sqrt(d_k)
                let scale_factor = scale.unwrap_or(1.0 / (d_k as f64).sqrt());
                let mut output: Array<f64, Ix3> = Array::zeros((batch_size, q_len, d_k));

                for b in 0..batch_size {
                    // Extract batch slices
                    let q_batch = q.slice(ndarray::s![b, .., .., ..]);
                    let k_batch = k.slice(ndarray::s![b, .., .., ..]);
                    let v_batch = v.slice(ndarray::s![b, .., .., ..]);

                    // Compute attention scores: Q * K^T for each head
                    let mut head_outputs = Array::zeros((q_len, num_heads, d_k));

                    for h in 0..num_heads {
                        let mut scores = Array::zeros((q_len, k_len));
                        for i in 0..q_len {
                            for j in 0..k_len {
                                let mut dot_product = 0.0;
                                for d in 0..d_k {
                                    dot_product += q_batch[[i, h, d]] * k_batch[[j, h, d]];
                                }
                                scores[[i, j]] = dot_product * scale_factor;
                            }
                        }

                        // Apply mask if provided
                        if let Some(mask_array) = mask {
                            if let Some(mask_wrapper) = mask_array
                                .as_any()
                                .downcast_ref::<NdarrayWrapper<f64, Ix3>>()
                            {
                                let mask_batch =
                                    mask_wrapper.as_array().slice(ndarray::s![b, .., ..]);
                                for i in 0..q_len {
                                    for j in 0..k_len {
                                        if mask_batch[[i, j]] == 0.0 {
                                            scores[[i, j]] = f64::NEG_INFINITY;
                                        }
                                    }
                                }
                            }
                        }

                        // Apply softmax to get attention weights
                        let mut attention = Array::zeros((q_len, k_len));
                        for i in 0..q_len {
                            // Find max for numerical stability
                            let mut max_score = f64::NEG_INFINITY;
                            for j in 0..k_len {
                                if scores[[i, j]] > max_score {
                                    max_score = scores[[i, j]];
                                }
                            }

                            // Compute exp and sum
                            let mut exp_sum = 0.0;
                            for j in 0..k_len {
                                let exp_val = (scores[[i, j]] - max_score).exp();
                                attention[[i, j]] = exp_val;
                                exp_sum += exp_val;
                            }

                            // Normalize
                            for j in 0..k_len {
                                attention[[i, j]] /= exp_sum;
                            }
                        }

                        // Compute output: attention * V
                        for i in 0..q_len {
                            for d in 0..d_k {
                                let mut weighted_sum = 0.0;
                                for j in 0..k_len {
                                    weighted_sum += attention[[i, j]] * v_batch[[j, h, d]];
                                }
                                head_outputs[[i, h, d]] = weighted_sum;
                            }
                        }
                    }

                    // Aggregate outputs from all heads (simple average)
                    for i in 0..q_len {
                        for d in 0..d_k {
                            let mut sum = 0.0;
                            for h in 0..num_heads {
                                sum += head_outputs[[i, h, d]];
                            }
                            output[[b, i, d]] = sum / num_heads as f64;
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "self_attention not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementation
        let mut kwargs = HashMap::new();
        if let Some(s) = scale {
            kwargs.insert("scale".to_string(), Box::new(s) as Box<dyn Any>);
        }
        if let Some(m) = mask {
            kwargs.insert("mask".to_string(), Box::new(m.box_clone()) as Box<dyn Any>);
        }

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::self_attention",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[
                Box::new(queries.box_clone()),
                Box::new(keys.box_clone()),
                Box::new(values.box_clone()),
            ],
            &kwargs,
        )?;

        // Try to downcast the result
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::self_attention"
);
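
// A small, self-contained test sketch for the pure-ndarray helper above. It only
// exercises `apply_activation`, since the functions generated by
// `array_function_dispatch!` depend on the dispatch machinery; treat these tests
// as illustrative rather than exhaustive coverage.
#[cfg(test)]
mod tests {
    use super::{apply_activation, ActivationFunc};
    use ndarray::{array, Axis};

    #[test]
    fn relu_clamps_negative_values() {
        let x = array![[-1.0_f64, 2.0], [0.5, -3.0]].into_dyn();
        let y = apply_activation(&x.view(), ActivationFunc::ReLU);
        assert_eq!(y, array![[0.0, 2.0], [0.5, 0.0]].into_dyn());
    }

    #[test]
    fn softmax_lanes_sum_to_one() {
        let x = array![[1.0_f64, 2.0, 3.0], [0.0, 0.0, 0.0]].into_dyn();
        let y = apply_activation(&x.view(), ActivationFunc::Softmax);
        // Each lane along the last axis should be a probability distribution.
        for lane in y.lanes(Axis(y.ndim() - 1)) {
            let sum: f64 = lane.sum();
            assert!((sum - 1.0).abs() < 1e-12);
        }
    }
}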