scirs2_core/array_protocol/ml_ops.rs

// Copyright (c) 2025, `SciRS2` Team
//
// Licensed under either of
//
// * Apache License, Version 2.0
//   (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
// * MIT license
//   (LICENSE-MIT or http://opensource.org/licenses/MIT)
//
// at your option.
//

//! Machine learning operations using the array protocol.
//!
//! This module provides implementations of various machine learning operations
//! using the array protocol, such as activation functions, convolution, and
//! pooling.
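//!
//! # Example
//!
//! A minimal usage sketch (kept out of doctests with `ignore`), assuming this
//! module is reachable as `scirs2_core::array_protocol::ml_ops` and that
//! `NdarrayWrapper` is the adapter used for plain `ndarray` arrays; the exact
//! paths and setup may differ in your build:
//!
//! ```rust,ignore
//! use ndarray::array;
//! use scirs2_core::array_protocol::NdarrayWrapper;
//! use scirs2_core::array_protocol::ml_ops::{activation, ActivationFunc};
//!
//! let x = NdarrayWrapper::new(array![-1.0, 0.0, 2.0].into_dyn());
//! let y = activation(&x, ActivationFunc::ReLU).expect("activation failed");
//! ```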

use std::any::{Any, TypeId};
use std::collections::HashMap;

use ::ndarray::{Array, Axis, Ix1, Ix2, Ix3, Ix4, IxDyn};
use rand::{Rng, SeedableRng};

use crate::array_protocol::operations::OperationError;
use crate::array_protocol::{
    array_function_dispatch, get_implementing_args, ArrayProtocol, NdarrayWrapper,
};

/// Activation function types.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ActivationFunc {
    /// Rectified Linear Unit: f(x) = max(0, x)
    ReLU,

    /// Sigmoid function: f(x) = 1 / (1 + exp(-x))
    Sigmoid,

    /// Hyperbolic tangent: f(x) = tanh(x)
    Tanh,

    /// Softmax function (applied along the last dimension)
    Softmax,

    /// Leaky ReLU: f(x) = max(alpha * x, x)
    LeakyReLU(f64),
}

/// Apply an activation function to an array.
#[allow(dead_code)]
fn apply_activation(
    x: &crate::ndarray::ArrayBase<crate::ndarray::ViewRepr<&f64>, IxDyn>,
    func: ActivationFunc,
) -> Array<f64, IxDyn> {
    match func {
        ActivationFunc::ReLU => x.mapv(|v| v.max(0.0)),
        ActivationFunc::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
        ActivationFunc::Tanh => x.mapv(|v| v.tanh()),
        ActivationFunc::LeakyReLU(alpha) => x.mapv(|v| if v > 0.0 { v } else { alpha * v }),
        ActivationFunc::Softmax => {
            // Apply softmax along the last dimension
            let mut result = Array::zeros(x.raw_dim());

            let last_dim = x.ndim() - 1;

            if x.ndim() == 1 {
                // Simple 1D case
                let max_val = x.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let exp_x = x.mapv(|v| (v - max_val).exp());
                let sum_exp = exp_x.sum();
                result.assign(&(exp_x / sum_exp));
            } else {
                // Multi-dimensional case: pair each 1-D lane of the input with the
                // corresponding lane of the output and normalize along the last axis
                for (x_lane, mut out_lane) in x
                    .lanes(Axis(last_dim))
                    .into_iter()
                    .zip(result.lanes_mut(Axis(last_dim)))
                {
                    let max_val = x_lane.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                    let exp_x = x_lane.mapv(|v| (v - max_val).exp());
                    let sum_exp = exp_x.sum();
                    out_lane.assign(&(exp_x / sum_exp));
                }
            }

            result
        }
    }
}

// Define machine learning operations using the array protocol

array_function_dispatch!(
    fn activation(
        x: &dyn ArrayProtocol,
        func: ActivationFunc,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Get implementing args
        let boxed_x = Box::new(x.box_clone());
        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_x];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            if let Some(x_array) = x.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
                let x_array = x_array.as_array();
                let result = apply_activation(&x_array.view(), func);
                return Ok(Box::new(NdarrayWrapper::new(result)));
            }
            return Err(OperationError::NotImplemented(
                "activation not implemented for this array type".to_string(),
            ));
        }

        // Delegate to the implementation
        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::activation",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(x.box_clone())],
            &HashMap::new(),
        )?;

        // Try to downcast the result
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::activation"
);

array_function_dispatch!(
    fn conv2d(
        input: &dyn ArrayProtocol,
        filters: &dyn ArrayProtocol,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Get implementing args
        let boxed_input = Box::new(input.box_clone());
        let boxed_filters = Box::new(filters.box_clone());
        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_input, boxed_filters];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            // This is a simplified implementation - in practice, convolution is much more complex
            if let (Some(inputarray), Some(filters_array)) = (
                input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                filters.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
            ) {
                let input = inputarray.as_array();
                let filters = filters_array.as_array();

                // Get dimensions
                let batch_size = input.shape()[0];
                let input_height = input.shape()[1];
                let input_width = input.shape()[2];
                let input_channels = input.shape()[3];

                let filter_height = filters.shape()[0];
                let filter_width = filters.shape()[1];
                let filter_in_channels = filters.shape()[2];
                let filter_out_channels = filters.shape()[3];

                // Check dimensions
                if input_channels != filter_in_channels {
                    return Err(OperationError::ShapeMismatch(format!(
                        "Input channels ({input_channels}) doesn't match filter input channels ({filter_in_channels})"
                    )));
                }

                // Calculate output dimensions
                let out_height = (input_height - filter_height + 2 * padding.0) / stride.0 + 1;
                let out_width = (input_width - filter_width + 2 * padding.1) / stride.1 + 1;
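                // Worked example with hypothetical sizes: a 32x32 input, a 3x3 filter,
                // stride (1, 1) and padding (1, 1) gives (32 - 3 + 2*1) / 1 + 1 = 32
                // in each spatial dimension, i.e. a "same"-sized output.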

                // Create output array
                let mut output: Array<f64, Ix4> =
                    Array::zeros((batch_size, out_height, out_width, filter_out_channels));

                // Perform convolution using basic sliding window approach
                for b in 0..batch_size {
                    for out_c in 0..filter_out_channels {
                        for out_h in 0..out_height {
                            for out_w in 0..out_width {
                                let mut sum = 0.0;

                                // Convolution over the filter window
                                for f_h in 0..filter_height {
                                    for f_w in 0..filter_width {
                                        for in_c in 0..input_channels {
                                            // Calculate input coordinates with padding
                                            let in_h = (out_h * stride.0) as i32 + f_h as i32
                                                - padding.0 as i32;
                                            let in_w = (out_w * stride.1) as i32 + f_w as i32
                                                - padding.1 as i32;

                                            // Check bounds (zero padding)
                                            if in_h >= 0
                                                && in_h < input_height as i32
                                                && in_w >= 0
                                                && in_w < input_width as i32
                                            {
                                                let input_val =
                                                    input[[b, in_h as usize, in_w as usize, in_c]];
                                                let filter_val = filters[[f_h, f_w, in_c, out_c]];
                                                sum += input_val * filter_val;
                                            }
                                        }
                                    }
                                }

                                output[[b, out_h, out_w, out_c]] = sum;
                            }
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "conv2d not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementation
        let mut kwargs = HashMap::new();
        kwargs.insert("stride".to_string(), Box::new(stride) as Box<dyn Any>);
        kwargs.insert("padding".to_string(), Box::new(padding) as Box<dyn Any>);

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new("scirs2::array_protocol::ml_ops::conv2d"),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(input.box_clone()), Box::new(filters.box_clone())],
            &kwargs,
        )?;

        // Try to downcast the result
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::conv2d"
);

array_function_dispatch!(
    fn max_pool2d(
        input: &dyn ArrayProtocol,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Get implementing args
        let boxed_input = Box::new(input.box_clone());
        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_input];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            if let Some(inputarray) = input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>() {
                let input = inputarray.as_array();

                // Get dimensions
                let batch_size = input.shape()[0];
                let input_height = input.shape()[1];
                let input_width = input.shape()[2];
                let channels = input.shape()[3];

                // Calculate output dimensions
                let out_height = (input_height - kernel_size.0 + 2 * padding.0) / stride.0 + 1;
                let out_width = (input_width - kernel_size.1 + 2 * padding.1) / stride.1 + 1;

                // Create output array
                let mut output: Array<f64, Ix4> =
                    Array::zeros((batch_size, out_height, out_width, channels));

                // Perform max pooling
                for b in 0..batch_size {
                    for c in 0..channels {
                        for out_h in 0..out_height {
                            for out_w in 0..out_width {
                                let mut max_val = f64::NEG_INFINITY;

                                // Pool over the kernel window
                                for k_h in 0..kernel_size.0 {
                                    for k_w in 0..kernel_size.1 {
                                        // Calculate input coordinates with padding
                                        let in_h = (out_h * stride.0) as i32 + k_h as i32
                                            - padding.0 as i32;
                                        let in_w = (out_w * stride.1) as i32 + k_w as i32
                                            - padding.1 as i32;

                                        // Check bounds
                                        if in_h >= 0
                                            && in_h < input_height as i32
                                            && in_w >= 0
                                            && in_w < input_width as i32
                                        {
                                            let val = input[[b, in_h as usize, in_w as usize, c]];
                                            if val > max_val {
                                                max_val = val;
                                            }
                                        }
                                    }
                                }

                                // Use 0.0 if no valid values found (due to padding)
                                output[[b, out_h, out_w, c]] = if max_val == f64::NEG_INFINITY {
                                    0.0
                                } else {
                                    max_val
                                };
                            }
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "max_pool2d not implemented for this array type".to_string(),
            ));
        }

        // Delegate to the implementation
        let mut kwargs = HashMap::new();
        kwargs.insert(
            "kernel_size".to_string(),
            Box::new(kernel_size) as Box<dyn Any>,
        );
        kwargs.insert("stride".to_string(), Box::new(stride) as Box<dyn Any>);
        kwargs.insert("padding".to_string(), Box::new(padding) as Box<dyn Any>);

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::max_pool2d",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(input.box_clone())],
            &kwargs,
        )?;

        // Try to downcast the result
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::max_pool2d"
);

array_function_dispatch!(
    fn batch_norm(
        input: &dyn ArrayProtocol,
        scale: &dyn ArrayProtocol,
        offset: &dyn ArrayProtocol,
        mean: &dyn ArrayProtocol,
        variance: &dyn ArrayProtocol,
        epsilon: f64,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Get implementing args - convert to Box<dyn Any>
        let boxed_args: Vec<Box<dyn Any>> = vec![
            Box::new(input.box_clone()),
            Box::new(scale.box_clone()),
            Box::new(offset.box_clone()),
            Box::new(mean.box_clone()),
            Box::new(variance.box_clone()),
        ];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            if let (
                Some(inputarray),
                Some(scale_array),
                Some(offset_array),
                Some(mean_array),
                Some(variance_array),
            ) = (
                input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                scale.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                offset.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                mean.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                variance
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
            ) {
                let input = inputarray.as_array();
                let scale = scale_array.as_array();
                let offset = offset_array.as_array();
                let mean = mean_array.as_array();
                let variance = variance_array.as_array();

                // Get dimensions
                let batch_size = input.shape()[0];
                let height = input.shape()[1];
                let width = input.shape()[2];
                let channels = input.shape()[3];

                // Check dimensions
                if scale.shape()[0] != channels
                    || offset.shape()[0] != channels
                    || mean.shape()[0] != channels
                    || variance.shape()[0] != channels
                {
                    return Err(OperationError::ShapeMismatch(
                        "Scale, offset, mean, and variance must match the number of channels"
                            .to_string(),
                    ));
                }

                // Create output array with same shape as input
                let mut output: Array<f64, Ix4> = Array::zeros(input.raw_dim());

                // Perform batch normalization
                // For each channel, normalize using the formula:
                // y = scale * (x - mean) / sqrt(variance + epsilon) + offset
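                // scale, offset, mean, and variance are expected to be per-channel
                // vectors of length `channels`, indexed as [[c]] below.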

                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            for c in 0..channels {
                                let x = input[[b, h, w, c]];
                                let m = mean[[c]];
                                let v = variance[[c]];
                                let s = scale[[c]];
                                let o = offset[[c]];

                                // Normalize: (x - mean) / sqrt(variance + epsilon)
                                let normalized = (x - m) / (v + epsilon).sqrt();

                                // Scale and shift: scale * normalized + offset
                                let result = s * normalized + o;

                                output[[b, h, w, c]] = result;
                            }
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "batch_norm not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementation
        let mut kwargs = HashMap::new();
        kwargs.insert("epsilon".to_string(), Box::new(epsilon) as Box<dyn Any>);

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::batch_norm",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[
                Box::new(input.box_clone()),
                Box::new(scale.box_clone()),
                Box::new(offset.box_clone()),
                Box::new(mean.box_clone()),
                Box::new(variance.box_clone()),
            ],
            &kwargs,
        )?;

        // Try to downcast the result
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::batch_norm"
);

array_function_dispatch!(
    fn cross_entropy(
        logits: &dyn ArrayProtocol,
        labels: &dyn ArrayProtocol,
        reduction: &str,
    ) -> Result<Box<dyn Any>, OperationError> {
        // Get implementing args
        let boxed_args: Vec<Box<dyn Any>> =
            vec![Box::new(logits.box_clone()), Box::new(labels.box_clone())];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            if let (Some(logits_array), Some(labels_array)) = (
                logits.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
                labels.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
            ) {
                let logits = logits_array.as_array();
                let labels = labels_array.as_array();

                // Check shapes
                if logits.shape() != labels.shape() {
                    return Err(OperationError::ShapeMismatch(format!(
                        "Logits shape {logits_shape:?} doesn't match labels shape {labels_shape:?}",
                        logits_shape = logits.shape(),
                        labels_shape = labels.shape()
                    )));
                }

                // Apply softmax to logits
                let mut softmax = Array::zeros(logits.raw_dim());

                // For each sample in the batch
                for (i, sample) in logits.outer_iter().enumerate() {
                    let max_val = sample.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                    let exp_x = sample.mapv(|v| (v - max_val).exp());
                    let sum_exp = exp_x.sum();

                    for (j, val) in exp_x.iter().enumerate() {
                        softmax[[i, j]] = val / sum_exp;
                    }
                }

                // Compute cross-entropy loss
                // loss = -sum(labels * log(softmax))
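                // Each row of `labels` is treated as a target distribution over the
                // classes (e.g. a one-hot vector).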
                let mut sample_losses = Array::zeros(logits.shape()[0]);

                for (i, (s, l)) in softmax.outer_iter().zip(labels.outer_iter()).enumerate() {
                    let mut loss = 0.0;
                    for (s_val, l_val) in s.iter().zip(l.iter()) {
                        // Add small epsilon to avoid log(0)
                        loss -= l_val * (s_val + 1e-10).ln();
                    }
                    sample_losses[i] = loss;
                }

                // Apply reduction
                let loss = match reduction {
                    "none" => sample_losses,
                    "mean" => {
                        let mean = sample_losses.sum() / sample_losses.len() as f64;
                        // Use Array1 instead of Array0 to keep the match arms' types consistent
                        Array::from_elem(Ix1(1), mean)
                    }
                    "sum" => {
                        let sum = sample_losses.sum();
                        // Use Array1 instead of Array0 to keep the match arms' types consistent
                        Array::from_elem(Ix1(1), sum)
                    }
                    _ => {
                        return Err(OperationError::Other(format!(
                            "Unknown reduction method: {reduction}"
                        )))
                    }
                };

                return Ok(Box::new(loss) as Box<dyn Any>);
            }
            return Err(OperationError::NotImplemented(
                "cross_entropy not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementation
        let mut kwargs = HashMap::new();
        kwargs.insert(
            "reduction".to_string(),
            Box::new(reduction.to_string()) as Box<dyn Any>,
        );

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::cross_entropy",
            ),
            &[TypeId::of::<Box<dyn Any>>()],
            &[Box::new(logits.box_clone()), Box::new(labels.box_clone())],
            &kwargs,
        )?;

        Ok(result)
    },
    "scirs2::array_protocol::ml_ops::cross_entropy"
);

array_function_dispatch!(
    fn dropout(
        input: &dyn ArrayProtocol,
        rate: f64,
        training: bool,
        seed: Option<u64>,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Get implementing args
        let boxed_args: Vec<Box<dyn Any>> = vec![Box::new(input.box_clone())];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            if let Some(inputarray) = input.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
                let input = inputarray.as_array();

                if !training {
                    // During inference, return the input unchanged (inverted dropout
                    // applies its scaling during training instead)
                    return Ok(Box::new(NdarrayWrapper::new(input.clone())));
                }

                // Create a binary mask where each element is kept with probability (1 - rate)
                let mut rng = match seed {
                    Some(s) => rand::rngs::StdRng::seed_from_u64(s),
                    None => {
                        let mut rng = rand::rng();
                        // Get a random seed from rng and create a new StdRng
                        let random_seed: u64 = rng.random();
                        rand::rngs::StdRng::seed_from_u64(random_seed)
                    }
                };

                let mask = Array::from_shape_fn(input.raw_dim(), |_| {
                    if rng.random::<f64>() >= rate {
                        1.0
                    } else {
                        0.0
                    }
                });

                // Scale by 1/(1-rate) to maintain expected value during training
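                // For example, rate = 0.25 keeps each element with probability 0.75
                // and multiplies the kept values by 1.0 / 0.75 (~1.33), so the
                // expected activation magnitude is unchanged.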
                let scale = 1.0 / (1.0 - rate);
                let result = input.clone() * &mask * scale;

                return Ok(Box::new(NdarrayWrapper::new(result)));
            }
            return Err(OperationError::NotImplemented(
                "dropout not implemented for this array type".to_string(),
            ));
        }

        // Delegate to the implementation
        let mut kwargs = HashMap::new();
        kwargs.insert("rate".to_string(), Box::new(rate) as Box<dyn Any>);
        kwargs.insert("training".to_string(), Box::new(training) as Box<dyn Any>);
        if let Some(s) = seed {
            kwargs.insert("seed".to_string(), Box::new(s) as Box<dyn Any>);
        }

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new("scirs2::array_protocol::ml_ops::dropout"),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(input.box_clone())],
            &kwargs,
        )?;

        // Try to downcast the result
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::dropout"
);

array_function_dispatch!(
    fn self_attention(
        queries: &dyn ArrayProtocol,
        keys: &dyn ArrayProtocol,
        values: &dyn ArrayProtocol,
        mask: Option<&dyn ArrayProtocol>,
        scale: Option<f64>,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Get implementing args
        let mut boxed_args: Vec<Box<dyn Any>> = vec![
            Box::new(queries.box_clone()),
            Box::new(keys.box_clone()),
            Box::new(values.box_clone()),
        ];
        if let Some(m) = mask {
            boxed_args.push(Box::new(m.box_clone()));
        }

        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fallback implementation for ndarray types
            if let (Some(q_array), Some(k_array), Some(v_array)) = (
                queries.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                keys.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                values.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
            ) {
                let q = q_array.as_array();
                let k = k_array.as_array();
                let v = v_array.as_array();

                // Get dimensions
                // q, k, v should have shape [batch_size, seq_len, num_heads, d_k]
                let batch_size = q.shape()[0];
                let q_len = q.shape()[1];
                let num_heads = q.shape()[2];
                let d_k = q.shape()[3];

                let k_len = k.shape()[1];

                // Check dimensions
                if k.shape()[0] != batch_size
                    || k.shape()[2] != num_heads
                    || k.shape()[3] != d_k
                    || v.shape()[0] != batch_size
                    || v.shape()[1] != k_len
                    || v.shape()[2] != num_heads
                {
                    return Err(OperationError::ShapeMismatch(
                        "Incompatible shapes for self-attention".to_string(),
                    ));
                }

                // Implement self-attention:
                // 1. scores = matmul(q, k.transpose) * scale_factor (default 1/sqrt(d_k))
                // 2. if mask: scores = scores.masked_fill(mask, -inf)
                // 3. attention = softmax(scores)
                // 4. output = matmul(attention, v)
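                // Shapes in the loop below: scores and attention are [q_len, k_len]
                // per batch and head; each head produces [q_len, d_k]; the heads are
                // then averaged into the final [batch_size, q_len, d_k] output.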

                // Default scale factor is 1/sqrt(d_k); guard against d_k == 0
                let scale_factor = scale.unwrap_or_else(|| {
                    let d_k_f64 = d_k as f64;
                    if d_k_f64 > 0.0 {
                        1.0 / d_k_f64.sqrt()
                    } else {
                        1.0
                    }
                });
                let mut output: Array<f64, Ix3> = Array::zeros((batch_size, q_len, d_k));

                for b in 0..batch_size {
                    // Extract batch slices
                    let q_batch = q.slice(crate::s![b, .., .., ..]);
                    let k_batch = k.slice(crate::s![b, .., .., ..]);
                    let v_batch = v.slice(crate::s![b, .., .., ..]);

                    // Compute attention scores: Q * K^T for each head
                    let mut head_outputs = Array::zeros((q_len, num_heads, d_k));

                    for h in 0..num_heads {
                        let mut scores = Array::zeros((q_len, k_len));
                        for i in 0..q_len {
                            for j in 0..k_len {
                                let mut dot_product = 0.0;
                                for d in 0..d_k {
                                    dot_product += q_batch[[i, h, d]] * k_batch[[j, h, d]];
                                }
                                scores[[i, j]] = dot_product * scale_factor;
                            }
                        }

                        // Apply mask if provided
                        if let Some(mask_array) = mask {
                            if let Some(mask_wrapper) = mask_array
                                .as_any()
                                .downcast_ref::<NdarrayWrapper<f64, Ix3>>()
                            {
                                let mask_batch =
                                    mask_wrapper.as_array().slice(crate::s![b, .., ..]);
                                for i in 0..q_len {
                                    for j in 0..k_len {
                                        if mask_batch[[i, j]] == 0.0 {
                                            scores[[i, j]] = f64::NEG_INFINITY;
                                        }
                                    }
                                }
                            }
                        }

                        // Apply softmax to get attention weights
                        let mut attention = Array::zeros((q_len, k_len));
                        for i in 0..q_len {
                            // Find max for numerical stability
                            let mut max_score = f64::NEG_INFINITY;
                            for j in 0..k_len {
                                if scores[[i, j]] > max_score {
                                    max_score = scores[[i, j]];
                                }
                            }

                            // Compute exp and sum
                            let mut exp_sum = 0.0;
                            for j in 0..k_len {
                                let exp_val = (scores[[i, j]] - max_score).exp();
                                attention[[i, j]] = exp_val;
                                exp_sum += exp_val;
                            }

                            // Normalize
                            for j in 0..k_len {
                                attention[[i, j]] /= exp_sum;
                            }
                        }

                        // Compute output: attention * V
                        for i in 0..q_len {
                            for d in 0..d_k {
                                let mut weighted_sum = 0.0;
                                for j in 0..k_len {
                                    weighted_sum += attention[[i, j]] * v_batch[[j, h, d]];
                                }
                                head_outputs[[i, h, d]] = weighted_sum;
                            }
                        }
                    }

                    // Aggregate outputs from all heads (simple average)
                    for i in 0..q_len {
                        for d in 0..d_k {
                            let mut sum = 0.0;
                            for h in 0..num_heads {
                                sum += head_outputs[[i, h, d]];
                            }
                            output[[b, i, d]] = sum / num_heads as f64;
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "self_attention not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementation
        let mut kwargs = HashMap::new();
        if let Some(s) = scale {
            kwargs.insert("scale".to_string(), Box::new(s) as Box<dyn Any>);
        }
        if let Some(m) = mask {
            kwargs.insert("mask".to_string(), Box::new(m.box_clone()) as Box<dyn Any>);
        }

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::self_attention",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[
                Box::new(queries.box_clone()),
                Box::new(keys.box_clone()),
                Box::new(values.box_clone()),
            ],
            &kwargs,
        )?;

        // Try to downcast the result
        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::self_attention"
);
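
// A minimal sanity-check sketch for the private `apply_activation` fallback above.
// It needs no backend registration and assumes only the `ndarray` types imported
// at the top of this file; the test name is illustrative.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn relu_and_softmax_fallback_helpers() {
        let x = Array::from_vec(vec![-1.0, 0.0, 2.0]).into_dyn();

        // ReLU clamps negative values to zero.
        let relu = apply_activation(&x.view(), ActivationFunc::ReLU);
        assert_eq!(relu.as_slice().unwrap(), &[0.0, 0.0, 2.0]);

        // Softmax along the last axis produces a probability distribution.
        let softmax = apply_activation(&x.view(), ActivationFunc::Softmax);
        assert!((softmax.sum() - 1.0).abs() < 1e-12);
    }
}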