Skip to main content

scirs2_core/array_protocol/
ml_ops.rs

1// Copyright (c) 2025, `SciRS2` Team
2//
3// Licensed under the Apache License, Version 2.0
4// (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
5//
6
7//! Machine learning operations using the array protocol.
8//!
9//! This module provides implementations of various machine learning operations
10//! using the array protocol, such as activation functions, convolution, and
11//! pooling.
12
13use std::any::{Any, TypeId};
14use std::collections::HashMap;
15
16use ::ndarray::{Array, Axis, Ix1, Ix2, Ix3, Ix4, IxDyn};
17use rand::{Rng, RngExt, SeedableRng};
18
19use crate::array_protocol::operations::OperationError;
20use crate::array_protocol::{
21    array_function_dispatch, get_implementing_args, ArrayProtocol, NdarrayWrapper,
22};
23
/// Activation function types.
///
/// Selects the non-linearity applied by [`activation`] / `apply_activation`.
/// All variants are element-wise except `Softmax`, which is normalized along
/// the last dimension of the input.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ActivationFunc {
    /// Rectified Linear Unit: f(x) = max(0, x)
    ReLU,

    /// Sigmoid function: f(x) = 1 / (1 + exp(-x))
    Sigmoid,

    /// Hyperbolic tangent: f(x) = tanh(x)
    Tanh,

    /// Softmax function (applied along the last dimension)
    Softmax,

    /// Leaky ReLU: f(x) = max(alpha * x, x). The payload is the `alpha`
    /// slope applied to negative inputs.
    LeakyReLU(f64),
}
42
43/// Apply an activation function to an array.
44#[allow(dead_code)]
45fn apply_activation(
46    x: &crate::ndarray::ArrayBase<crate::ndarray::ViewRepr<&f64>, IxDyn>,
47    func: ActivationFunc,
48) -> Array<f64, IxDyn> {
49    match func {
50        ActivationFunc::ReLU => x.mapv(|v| v.max(0.0)),
51        ActivationFunc::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
52        ActivationFunc::Tanh => x.mapv(|v| v.tanh()),
53        ActivationFunc::LeakyReLU(alpha) => x.mapv(|v| if v > 0.0 { v } else { alpha * v }),
54        ActivationFunc::Softmax => {
55            // Apply softmax along the last dimension
56            let mut result = Array::zeros(x.raw_dim());
57
58            // Iterate over all but the last dimension
59            let last_dim = x.ndim() - 1;
60            let _last_dim_len = x.shape()[last_dim];
61
62            if x.ndim() == 1 {
63                // Simple 1D case
64                let max_val = x.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
65                let exp_x = x.mapv(|v| (v - max_val).exp());
66                let sum_exp = exp_x.sum();
67                result.assign(&(exp_x / sum_exp));
68            } else {
69                // Multi-dimensional case
70                for (i, mut slice) in result.lanes_mut(Axis(last_dim)).into_iter().enumerate() {
71                    // Use index_axis to get the ith slice along the last dimension
72                    let x_slice = x.index_axis(Axis(last_dim), i);
73                    let max_val = x_slice.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
74                    let exp_x = x_slice.mapv(|v| (v - max_val).exp());
75                    let sum_exp = exp_x.sum();
76                    slice.assign(&(exp_x / sum_exp));
77                }
78            }
79
80            result
81        }
82    }
83}
84
85// Define machine learning operations using the array protocol
86
87array_function_dispatch!(
88    fn activation(
89        x: &dyn ArrayProtocol,
90        func: ActivationFunc,
91    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
92        // Get implementing args
93        let boxed_x = Box::new(x.box_clone());
94        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_x];
95        let implementing_args = get_implementing_args(&boxed_args);
96        if implementing_args.is_empty() {
97            // Fallback implementation for ndarray types
98            if let Some(x_array) = x.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
99                let x_array = x_array.as_array();
100                let result = apply_activation(&x_array.view(), func);
101                return Ok(Box::new(NdarrayWrapper::new(result)));
102            }
103            return Err(OperationError::NotImplemented(
104                "activation not implemented for this array type".to_string(),
105            ));
106        }
107
108        // Delegate to the implementation
109        let array_ref = implementing_args[0].1;
110
111        let result = array_ref.array_function(
112            &crate::array_protocol::ArrayFunction::new(
113                "scirs2::array_protocol::ml_ops::activation",
114            ),
115            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
116            &[Box::new(x.box_clone())],
117            &HashMap::new(),
118        )?;
119
120        // Try to downcast the result
121        match result.downcast::<Box<dyn ArrayProtocol>>() {
122            Ok(array) => Ok(*array),
123            Err(_) => Err(OperationError::Other(
124                "Failed to downcast array_function result".to_string(),
125            )),
126        }
127    },
128    "scirs2::array_protocol::ml, ops: activation"
129);
130
131array_function_dispatch!(
132    fn conv2d(
133        input: &dyn ArrayProtocol,
134        filters: &dyn ArrayProtocol,
135        stride: (usize, usize),
136        padding: (usize, usize),
137    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
138        // Get implementing args
139        let boxed_input = Box::new(input.box_clone());
140        let boxed_filters = Box::new(filters.box_clone());
141        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_input, boxed_filters];
142        let implementing_args = get_implementing_args(&boxed_args);
143        if implementing_args.is_empty() {
144            // Fallback implementation for ndarray types
145            // This is a simplified implementation - in practice, convolution is much more complex
146            if let (Some(inputarray), Some(filters_array)) = (
147                input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
148                filters.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
149            ) {
150                let input = inputarray.as_array();
151                let filters = filters_array.as_array();
152
153                // Get dimensions
154                let batch_size = input.shape()[0];
155                let input_height = input.shape()[1];
156                let input_width = input.shape()[2];
157                let input_channels = input.shape()[3];
158
159                let filter_height = filters.shape()[0];
160                let filter_width = filters.shape()[1];
161                let filter_in_channels = filters.shape()[2];
162                let filter_out_channels = filters.shape()[3];
163
164                // Check dimensions
165                if input_channels != filter_in_channels {
166                    return Err(OperationError::ShapeMismatch(format!(
167                        "Input channels ({input_channels}) doesn't match filter input channels ({filter_in_channels})"
168                    )));
169                }
170
171                // Calculate output dimensions
172                let out_height = (input_height - filter_height + 2 * padding.0) / stride.0 + 1;
173                let out_width = (input_width - filter_width + 2 * padding.1) / stride.1 + 1;
174
175                // Create output array
176                let mut output: Array<f64, Ix4> =
177                    Array::zeros((batch_size, out_height, out_width, filter_out_channels));
178
179                // Perform convolution using basic sliding window approach
180                for b in 0..batch_size {
181                    for out_c in 0..filter_out_channels {
182                        for out_h in 0..out_height {
183                            for out_w in 0..out_width {
184                                let mut sum = 0.0;
185
186                                // Convolution over the filter window
187                                for f_h in 0..filter_height {
188                                    for f_w in 0..filter_width {
189                                        for in_c in 0..input_channels {
190                                            // Calculate input coordinates with padding
191                                            let in_h = (out_h * stride.0) as i32 + f_h as i32
192                                                - padding.0 as i32;
193                                            let in_w = (out_w * stride.1) as i32 + f_w as i32
194                                                - padding.1 as i32;
195
196                                            // Check bounds (zero padding)
197                                            if in_h >= 0
198                                                && in_h < input_height as i32
199                                                && in_w >= 0
200                                                && in_w < input_width as i32
201                                            {
202                                                let input_val =
203                                                    input[[b, in_h as usize, in_w as usize, in_c]];
204                                                let filter_val = filters[[f_h, f_w, in_c, out_c]];
205                                                sum += input_val * filter_val;
206                                            }
207                                        }
208                                    }
209                                }
210
211                                output[[b, out_h, out_w, out_c]] = sum;
212                            }
213                        }
214                    }
215                }
216
217                return Ok(Box::new(NdarrayWrapper::new(output)));
218            }
219            return Err(OperationError::NotImplemented(
220                "conv2d not implemented for these array types".to_string(),
221            ));
222        }
223
224        // Delegate to the implementation
225        let mut kwargs = HashMap::new();
226        kwargs.insert("stride".to_string(), Box::new(stride) as Box<dyn Any>);
227        kwargs.insert("padding".to_string(), Box::new(padding) as Box<dyn Any>);
228
229        let array_ref = implementing_args[0].1;
230
231        let result = array_ref.array_function(
232            &crate::array_protocol::ArrayFunction::new("scirs2::array_protocol::ml_ops::conv2d"),
233            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
234            &[Box::new(input.box_clone()), Box::new(filters.box_clone())],
235            &kwargs,
236        )?;
237
238        // Try to downcast the result
239        match result.downcast::<Box<dyn ArrayProtocol>>() {
240            Ok(array) => Ok(*array),
241            Err(_) => Err(OperationError::Other(
242                "Failed to downcast array_function result".to_string(),
243            )),
244        }
245    },
246    "scirs2::array_protocol::ml, ops: conv2d"
247);
248
249array_function_dispatch!(
250    fn max_pool2d(
251        input: &dyn ArrayProtocol,
252        kernel_size: (usize, usize),
253        stride: (usize, usize),
254        padding: (usize, usize),
255    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
256        // Get implementing args
257        let boxed_input = Box::new(input.box_clone());
258        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_input];
259        let implementing_args = get_implementing_args(&boxed_args);
260        if implementing_args.is_empty() {
261            // Fallback implementation for ndarray types
262            if let Some(inputarray) = input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>() {
263                let input = inputarray.as_array();
264
265                // Get dimensions
266                let batch_size = input.shape()[0];
267                let input_height = input.shape()[1];
268                let input_width = input.shape()[2];
269                let channels = input.shape()[3];
270
271                // Calculate output dimensions
272                let out_height = (input_height - kernel_size.0 + 2 * padding.0) / stride.0 + 1;
273                let out_width = (input_width - kernel_size.1 + 2 * padding.1) / stride.1 + 1;
274
275                // Create output array
276                let mut output: Array<f64, Ix4> =
277                    Array::zeros((batch_size, out_height, out_width, channels));
278
279                // Perform max pooling
280                for b in 0..batch_size {
281                    for c in 0..channels {
282                        for out_h in 0..out_height {
283                            for out_w in 0..out_width {
284                                let mut max_val = f64::NEG_INFINITY;
285
286                                // Pool over the kernel window
287                                for k_h in 0..kernel_size.0 {
288                                    for k_w in 0..kernel_size.1 {
289                                        // Calculate input coordinates with padding
290                                        let in_h = (out_h * stride.0) as i32 + k_h as i32
291                                            - padding.0 as i32;
292                                        let in_w = (out_w * stride.1) as i32 + k_w as i32
293                                            - padding.1 as i32;
294
295                                        // Check bounds
296                                        if in_h >= 0
297                                            && in_h < input_height as i32
298                                            && in_w >= 0
299                                            && in_w < input_width as i32
300                                        {
301                                            let val = input[[b, in_h as usize, in_w as usize, c]];
302                                            if val > max_val {
303                                                max_val = val;
304                                            }
305                                        }
306                                    }
307                                }
308
309                                // Use 0.0 if no valid values found (due to padding)
310                                output[[b, out_h, out_w, c]] = if max_val == f64::NEG_INFINITY {
311                                    0.0
312                                } else {
313                                    max_val
314                                };
315                            }
316                        }
317                    }
318                }
319
320                return Ok(Box::new(NdarrayWrapper::new(output)));
321            }
322            return Err(OperationError::NotImplemented(
323                "max_pool2d not implemented for this array type".to_string(),
324            ));
325        }
326
327        // Delegate to the implementation
328        let mut kwargs = HashMap::new();
329        kwargs.insert(
330            "kernel_size".to_string(),
331            Box::new(kernel_size) as Box<dyn Any>,
332        );
333        kwargs.insert("stride".to_string(), Box::new(stride) as Box<dyn Any>);
334        kwargs.insert("padding".to_string(), Box::new(padding) as Box<dyn Any>);
335
336        let array_ref = implementing_args[0].1;
337
338        let result = array_ref.array_function(
339            &crate::array_protocol::ArrayFunction::new(
340                "scirs2::array_protocol::ml_ops::max_pool2d",
341            ),
342            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
343            &[Box::new(input.box_clone())],
344            &kwargs,
345        )?;
346
347        // Try to downcast the result
348        match result.downcast::<Box<dyn ArrayProtocol>>() {
349            Ok(array) => Ok(*array),
350            Err(_) => Err(OperationError::Other(
351                "Failed to downcast array_function result".to_string(),
352            )),
353        }
354    },
355    "scirs2::array_protocol::ml, ops: max_pool2d"
356);
357
358array_function_dispatch!(
359    fn batch_norm(
360        input: &dyn ArrayProtocol,
361        scale: &dyn ArrayProtocol,
362        offset: &dyn ArrayProtocol,
363        mean: &dyn ArrayProtocol,
364        variance: &dyn ArrayProtocol,
365        epsilon: f64,
366    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
367        // Get implementing args - convert to Box<dyn Any>
368        let boxed_args: Vec<Box<dyn Any>> = vec![
369            Box::new(input.box_clone()),
370            Box::new(scale.box_clone()),
371            Box::new(offset.box_clone()),
372            Box::new(mean.box_clone()),
373            Box::new(variance.box_clone()),
374        ];
375        let implementing_args = get_implementing_args(&boxed_args);
376        if implementing_args.is_empty() {
377            // Fallback implementation for ndarray types
378            if let (
379                Some(inputarray),
380                Some(scale_array),
381                Some(offset_array),
382                Some(mean_array),
383                Some(variance_array),
384            ) = (
385                input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
386                scale.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
387                offset.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
388                mean.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
389                variance
390                    .as_any()
391                    .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
392            ) {
393                let input = inputarray.as_array();
394                let scale = scale_array.as_array();
395                let offset = offset_array.as_array();
396                let mean = mean_array.as_array();
397                let variance = variance_array.as_array();
398
399                // Get dimensions
400                let _batch_size = input.shape()[0];
401                let _height = input.shape()[1];
402                let _width = input.shape()[2];
403                let channels = input.shape()[3];
404
405                // Check dimensions
406                if scale.shape()[0] != channels
407                    || offset.shape()[0] != channels
408                    || mean.shape()[0] != channels
409                    || variance.shape()[0] != channels
410                {
411                    return Err(OperationError::ShapeMismatch(
412                        "Scale, offset, mean, and variance must match the number of channels"
413                            .to_string(),
414                    ));
415                }
416
417                // Create output array with same shape as input
418                let mut output: Array<f64, Ix4> = Array::zeros(input.raw_dim());
419
420                // Perform batch normalization
421                // For each channel, normalize using the formula:
422                // y = scale * (x - mean) / sqrt(variance + epsilon) + offset
423
424                let batch_size = input.shape()[0];
425                let _height = input.shape()[1];
426                let _width = input.shape()[2];
427
428                for b in 0..batch_size {
429                    for h in 0.._height {
430                        for w in 0.._width {
431                            for c in 0..channels {
432                                let x = input[[b, h, w, c]];
433                                let m = mean[[c]];
434                                let v = variance[[c]];
435                                let s = scale[[c]];
436                                let o = offset[[c]];
437
438                                // Normalize: (x - mean) / sqrt(variance + epsilon)
439                                let normalized = (x - m) / (v + epsilon).sqrt();
440
441                                // Scale and shift: scale * normalized + offset
442                                let result = s * normalized + o;
443
444                                output[[b, h, w, c]] = result;
445                            }
446                        }
447                    }
448                }
449
450                return Ok(Box::new(NdarrayWrapper::new(output)));
451            }
452            return Err(OperationError::NotImplemented(
453                "batch_norm not implemented for these array types".to_string(),
454            ));
455        }
456
457        // Delegate to the implementation
458        let mut kwargs = HashMap::new();
459        kwargs.insert("epsilon".to_string(), Box::new(epsilon) as Box<dyn Any>);
460
461        let array_ref = implementing_args[0].1;
462
463        let result = array_ref.array_function(
464            &crate::array_protocol::ArrayFunction::new(
465                "scirs2::array_protocol::ml_ops::batch_norm",
466            ),
467            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
468            &[
469                Box::new(input.box_clone()),
470                Box::new(scale.box_clone()),
471                Box::new(offset.box_clone()),
472                Box::new(mean.box_clone()),
473                Box::new(variance.box_clone()),
474            ],
475            &kwargs,
476        )?;
477
478        // Try to downcast the result
479        match result.downcast::<Box<dyn ArrayProtocol>>() {
480            Ok(array) => Ok(*array),
481            Err(_) => Err(OperationError::Other(
482                "Failed to downcast array_function result".to_string(),
483            )),
484        }
485    },
486    "scirs2::array_protocol::ml, ops: batch_norm"
487);
488
489array_function_dispatch!(
490    fn cross_entropy(
491        logits: &dyn ArrayProtocol,
492        labels: &dyn ArrayProtocol,
493        reduction: &str,
494    ) -> Result<Box<dyn Any>, OperationError> {
495        // Get implementing args
496        let boxed_args: Vec<Box<dyn Any>> =
497            vec![Box::new(logits.box_clone()), Box::new(labels.box_clone())];
498        let implementing_args = get_implementing_args(&boxed_args);
499        if implementing_args.is_empty() {
500            // Fallback implementation for ndarray types
501            if let (Some(logits_array), Some(labels_array)) = (
502                logits.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
503                labels.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
504            ) {
505                let logits = logits_array.as_array();
506                let labels = labels_array.as_array();
507
508                // Check shapes
509                if logits.shape() != labels.shape() {
510                    return Err(OperationError::ShapeMismatch(format!(
511                        "Logits shape {logitsshape:?} doesn't match labels shape {labelsshape:?}",
512                        logitsshape = logits.shape(),
513                        labelsshape = labels.shape()
514                    )));
515                }
516
517                // Apply softmax to logits
518                let mut softmax = Array::zeros(logits.raw_dim());
519
520                // For each sample in the batch
521                for (i, sample) in logits.outer_iter().enumerate() {
522                    let max_val = sample.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
523                    let exp_x = sample.mapv(|v| (v - max_val).exp());
524                    let sum_exp = exp_x.sum();
525
526                    for (j, val) in exp_x.iter().enumerate() {
527                        softmax[[i, j]] = val / sum_exp;
528                    }
529                }
530
531                // Compute cross-entropy loss
532                // loss = -sum(labels * log(softmax))
533                let mut sample_losses = Array::zeros(logits.shape()[0]);
534
535                for (i, (s, l)) in softmax.outer_iter().zip(labels.outer_iter()).enumerate() {
536                    let mut loss = 0.0;
537                    for (s_val, l_val) in s.iter().zip(l.iter()) {
538                        // Add small epsilon to avoid log(0)
539                        loss -= l_val * (s_val + 1e-10).ln();
540                    }
541                    sample_losses[i] = loss;
542                }
543
544                // Apply reduction
545                let loss = match reduction {
546                    "none" => sample_losses,
547                    "mean" => {
548                        let mean = sample_losses.sum() / sample_losses.len() as f64;
549                        // Use Array1 instead of Array0 to make type consistent
550                        Array::from_elem(Ix1(1), mean)
551                    }
552                    "sum" => {
553                        let sum = sample_losses.sum();
554                        // Use Array1 instead of Array0 to make type consistent
555                        Array::from_elem(Ix1(1), sum)
556                    }
557                    _ => {
558                        return Err(OperationError::ShapeMismatch(format!(
559                            "Unknown reduction method: {reduction}"
560                        )))
561                    }
562                };
563
564                return Ok(Box::new(loss) as Box<dyn Any>);
565            }
566            return Err(OperationError::NotImplemented(
567                "cross_entropy not implemented for these array types".to_string(),
568            ));
569        }
570
571        // Delegate to the implementation
572        let mut kwargs = HashMap::new();
573        kwargs.insert(
574            "reduction".to_string(),
575            Box::new(reduction.to_string()) as Box<dyn Any>,
576        );
577
578        let array_ref = implementing_args[0].1;
579
580        let result = array_ref.array_function(
581            &crate::array_protocol::ArrayFunction::new(
582                "scirs2::array_protocol::ml_ops::cross_entropy",
583            ),
584            &[TypeId::of::<Box<dyn Any>>()],
585            &[Box::new(logits.box_clone()), Box::new(labels.box_clone())],
586            &kwargs,
587        )?;
588
589        Ok(result)
590    },
591    "scirs2::array_protocol::ml, ops: cross_entropy"
592);
593
594array_function_dispatch!(
595    fn dropout(
596        input: &dyn ArrayProtocol,
597        rate: f64,
598        training: bool,
599        seed: Option<u64>,
600    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
601        // Get implementing args
602        let boxed_args: Vec<Box<dyn Any>> = vec![Box::new(input.box_clone())];
603        let implementing_args = get_implementing_args(&boxed_args);
604        if implementing_args.is_empty() {
605            // Fallback implementation for ndarray types
606            if let Some(inputarray) = input.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
607                let input = inputarray.as_array();
608
609                if !training {
610                    // During inference, just scale the input
611                    return Ok(Box::new(NdarrayWrapper::new(input.clone())));
612                }
613
614                // Create a binary mask with probabilities (1-rate)
615                let mut rng = match seed {
616                    Some(s) => rand::rngs::StdRng::seed_from_u64(s),
617                    None => {
618                        let mut rng = rand::rng();
619                        // Get a random seed from rng and create a new StdRng
620                        let random_seed: u64 = rng.random();
621                        rand::rngs::StdRng::seed_from_u64(random_seed)
622                    }
623                };
624
625                let mask = Array::from_shape_fn(input.raw_dim(), |_| {
626                    if rng.random::<f64>() >= rate {
627                        1.0
628                    } else {
629                        0.0
630                    }
631                });
632
633                // Scale by 1/(1-rate) to maintain expected value during training
634                let scale = 1.0 / (1.0 - rate);
635                let result = input.clone() * &mask * scale;
636
637                return Ok(Box::new(NdarrayWrapper::new(result)));
638            }
639            return Err(OperationError::NotImplemented(
640                "dropout not implemented for this array type".to_string(),
641            ));
642        }
643
644        // Delegate to the implementation
645        let mut kwargs = HashMap::new();
646        kwargs.insert("rate".to_string(), Box::new(rate) as Box<dyn Any>);
647        kwargs.insert("training".to_string(), Box::new(training) as Box<dyn Any>);
648        if let Some(s) = seed {
649            kwargs.insert("seed".to_string(), Box::new(s) as Box<dyn Any>);
650        }
651
652        let array_ref = implementing_args[0].1;
653
654        let result = array_ref.array_function(
655            &crate::array_protocol::ArrayFunction::new("scirs2::array_protocol::ml_ops::dropout"),
656            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
657            &[Box::new(input.box_clone())],
658            &kwargs,
659        )?;
660
661        // Try to downcast the result
662        match result.downcast::<Box<dyn ArrayProtocol>>() {
663            Ok(array) => Ok(*array),
664            Err(_) => Err(OperationError::Other(
665                "Failed to downcast array_function result".to_string(),
666            )),
667        }
668    },
669    "scirs2::array_protocol::ml, ops: dropout"
670);
671
672array_function_dispatch!(
673    fn self_attention(
674        queries: &dyn ArrayProtocol,
675        keys: &dyn ArrayProtocol,
676        values: &dyn ArrayProtocol,
677        mask: Option<&dyn ArrayProtocol>,
678        scale: Option<f64>,
679    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
680        // Get implementing args
681        let mut boxed_args: Vec<Box<dyn Any>> = vec![
682            Box::new(queries.box_clone()),
683            Box::new(keys.box_clone()),
684            Box::new(values.box_clone()),
685        ];
686        if let Some(m) = mask {
687            boxed_args.push(Box::new(m.box_clone()));
688        }
689
        // Dispatch: if no registered backend claims these arguments, fall back
        // to a reference ndarray implementation below.
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            // All three inputs must be f64 rank-4 ndarray wrappers; otherwise we
            // return NotImplemented at the bottom of this branch.
            if let (Some(q_array), Some(k_array), Some(v_array)) = (
                queries.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                keys.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                values.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
            ) {
                let q = q_array.as_array();
                let k = k_array.as_array();
                let v = v_array.as_array();

                // Get dimensions
                // q, k, v should have shape [batch_size, seq_len, num_heads, d_k]
                let batch_size = q.shape()[0];
                let q_len = q.shape()[1];
                let num_heads = q.shape()[2];
                let d_k = q.shape()[3];

                let k_len = k.shape()[1];

                // Check dimensions
                // NOTE(review): v.shape()[3] (d_v) is never validated, yet the
                // output loop below indexes v_batch[[j, h, k]] with k in 0..d_k.
                // If d_v < d_k this panics at runtime — confirm whether d_v == d_k
                // is an upstream invariant or add the check here.
                if k.shape()[0] != batch_size
                    || k.shape()[2] != num_heads
                    || k.shape()[3] != d_k
                    || v.shape()[0] != batch_size
                    || v.shape()[1] != k_len
                    || v.shape()[2] != num_heads
                {
                    return Err(OperationError::ShapeMismatch(
                        "Incompatible shapes for self-attention".to_string(),
                    ));
                }

                // Apply scaling
                // NOTE(review): this binding is dead (underscore-prefixed) and its
                // formula — sqrt(d_k) — contradicts the scale_factor actually used
                // below (1/sqrt(d_k)). Candidate for deletion.
                let _scale_factor = scale.unwrap_or_else(|| {
                    // Default scale factor is 1/sqrt(d_k)
                    let d_k_f64 = d_k as f64;
                    if d_k_f64 > 0.0 {
                        d_k_f64.sqrt()
                    } else {
                        1.0 // Fallback for edge case
                    }
                });

                // Implement self-attention:
                // 1. scores = matmul(q, k.transpose) / scale_factor
                // 2. if mask: scores = scores.masked_fill(mask, -inf)
                // 3. attention = softmax(scores)
                // 4. output = matmul(attention, v)

                // Effective multiplier on the raw dot products. A caller-supplied
                // `scale` is used verbatim; the default is the standard 1/sqrt(d_k).
                // (unwrap_or evaluates its argument eagerly, but sqrt is cheap.)
                let scale_factor = scale.unwrap_or(1.0 / (d_k as f64).sqrt());
                // NOTE(review): output is rank-3 (batch, q_len, d_k) — the head
                // dimension is collapsed by averaging below rather than by the
                // usual concatenation of heads; confirm this is intentional.
                let mut output: Array<f64, Ix3> = Array::zeros((batch_size, q_len, d_k));

                for b in 0..batch_size {
                    // Extract batch slices
                    let q_batch = q.slice(crate::s![b, .., .., ..]);
                    let k_batch = k.slice(crate::s![b, .., .., ..]);
                    let v_batch = v.slice(crate::s![b, .., .., ..]);

                    // Compute attention scores: Q * K^T for each head
                    // Per-head results staged here before the cross-head average.
                    let mut head_outputs = Array::zeros((q_len, num_heads, d_k));

                    for h in 0..num_heads {
                        // scores[i][j] = scaled dot product of query i and key j.
                        let mut scores = Array::zeros((q_len, k_len));
                        for i in 0..q_len {
                            for j in 0..k_len {
                                let mut dot_product = 0.0;
                                // NOTE(review): this `k` index shadows the `k` array
                                // view bound above — harmless (only k_batch is used
                                // here) but worth renaming for readability.
                                for k in 0..d_k {
                                    dot_product += q_batch[[i, h, k]] * k_batch[[j, h, k]];
                                }
                                scores[[i, j]] = dot_product * scale_factor;
                            }
                        }

                        // Apply mask if provided
                        // Expects a rank-3 f64 wrapper indexed [batch, q, k];
                        // a 0.0 entry blocks attention at that position.
                        // NOTE(review): a mask of any other wrapper type is
                        // silently ignored here — consider surfacing an error.
                        if let Some(mask_array) = mask {
                            if let Some(mask_wrapper) = mask_array
                                .as_any()
                                .downcast_ref::<NdarrayWrapper<f64, Ix3>>()
                            {
                                let mask_batch =
                                    mask_wrapper.as_array().slice(crate::s![b, .., ..]);
                                for i in 0..q_len {
                                    for j in 0..k_len {
                                        if mask_batch[[i, j]] == 0.0 {
                                            scores[[i, j]] = f64::NEG_INFINITY;
                                        }
                                    }
                                }
                            }
                        }

                        // Apply softmax to get attention weights
                        // NOTE(review): if an entire row is masked, every exp is 0,
                        // exp_sum is 0, and the normalize step produces NaN —
                        // confirm fully-masked rows cannot occur, or guard exp_sum.
                        let mut attention = Array::zeros((q_len, k_len));
                        for i in 0..q_len {
                            // Find max for numerical stability
                            let mut max_score = f64::NEG_INFINITY;
                            for j in 0..k_len {
                                if scores[[i, j]] > max_score {
                                    max_score = scores[[i, j]];
                                }
                            }

                            // Compute exp and sum
                            let mut exp_sum = 0.0;
                            for j in 0..k_len {
                                let exp_val = (scores[[i, j]] - max_score).exp();
                                attention[[i, j]] = exp_val;
                                exp_sum += exp_val;
                            }

                            // Normalize
                            for j in 0..k_len {
                                attention[[i, j]] /= exp_sum;
                            }
                        }

                        // Compute output: attention * V
                        // (indexes v_batch's last axis with 0..d_k — see the
                        // missing d_v shape check flagged above)
                        for i in 0..q_len {
                            for k in 0..d_k {
                                let mut weighted_sum = 0.0;
                                for j in 0..k_len {
                                    weighted_sum += attention[[i, j]] * v_batch[[j, h, k]];
                                }
                                head_outputs[[i, h, k]] = weighted_sum;
                            }
                        }
                    }

                    // Aggregate outputs from all heads (simple average)
                    for i in 0..q_len {
                        for k in 0..d_k {
                            let mut sum = 0.0;
                            for h in 0..num_heads {
                                sum += head_outputs[[i, h, k]];
                            }
                            output[[b, i, k]] = sum / num_heads as f64;
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            // Inputs were not the supported ndarray wrapper types.
            return Err(OperationError::NotImplemented(
                "self_attention not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementation
        // Optional arguments travel as type-erased kwargs; the mask is cloned
        // via box_clone since we only hold a borrow.
        let mut kwargs = HashMap::new();
        if let Some(s) = scale {
            kwargs.insert("scale".to_string(), Box::new(s) as Box<dyn Any>);
        }
        if let Some(m) = mask {
            kwargs.insert("mask".to_string(), Box::new(m.box_clone()) as Box<dyn Any>);
        }

        // First (highest-priority) implementing argument handles the call.
        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::self_attention",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[
                Box::new(queries.box_clone()),
                Box::new(keys.box_clone()),
                Box::new(values.box_clone()),
            ],
            &kwargs,
        )?;

        // Try to downcast the result
        // Implementations are expected to return Box<Box<dyn ArrayProtocol>>.
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    // NOTE(review): this dispatch-name string looks garbled — it disagrees with
    // the "scirs2::array_protocol::ml_ops::self_attention" path used in the
    // ArrayFunction above ("ml, ops:" vs "ml_ops::"). Confirm the intended
    // registry key; it is a runtime string, so not altered in this doc pass.
    "scirs2::array_protocol::ml, ops: self_attention"
);