//! Machine-learning array operations (activation, convolution, pooling,
//! normalization, loss, dropout, and attention) dispatched through the
//! array protocol, with naive ndarray-backed fallback implementations.

use std::any::{Any, TypeId};
use std::collections::HashMap;

use ::ndarray::{Array, Axis, Ix1, Ix2, Ix3, Ix4, IxDyn};
use rand::{Rng, SeedableRng};

use crate::array_protocol::operations::OperationError;
use crate::array_protocol::{
    array_function_dispatch, get_implementing_args, ArrayProtocol, NdarrayWrapper,
};

/// Activation functions supported by the `activation` operation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ActivationFunc {
    /// Rectified linear unit: `max(0, x)`.
    ReLU,

    /// Logistic sigmoid: `1 / (1 + e^-x)`.
    Sigmoid,

    /// Hyperbolic tangent.
    Tanh,

    /// Softmax over the last axis.
    Softmax,

    /// Leaky ReLU with the given negative slope.
    LeakyReLU(f64),
}

/// Apply an activation function element-wise to a dynamic-dimensional array.
///
/// Softmax is computed along the last axis, using the usual max-subtraction
/// trick for numerical stability.
#[allow(dead_code)]
fn apply_activation(
    x: &crate::ndarray::ArrayBase<crate::ndarray::ViewRepr<&f64>, IxDyn>,
    func: ActivationFunc,
) -> Array<f64, IxDyn> {
    match func {
        ActivationFunc::ReLU => x.mapv(|v| v.max(0.0)),
        ActivationFunc::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
        ActivationFunc::Tanh => x.mapv(|v| v.tanh()),
        ActivationFunc::LeakyReLU(alpha) => x.mapv(|v| if v > 0.0 { v } else { alpha * v }),
        ActivationFunc::Softmax => {
            let mut result = Array::zeros(x.raw_dim());
            let last_dim = x.ndim() - 1;

            // Normalize each lane along the last axis independently. This also
            // covers the 1-D case, which has a single lane.
            for (x_lane, mut out_lane) in x
                .lanes(Axis(last_dim))
                .into_iter()
                .zip(result.lanes_mut(Axis(last_dim)))
            {
                let max_val = x_lane.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let exp_x = x_lane.mapv(|v| (v - max_val).exp());
                let sum_exp = exp_x.sum();
                out_lane.assign(&(exp_x / sum_exp));
            }

            result
        }
    }
}

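// Apply an activation function to an array, dispatching through the array
// protocol: delegate to the first implementing argument if any, otherwise
// fall back to the local ndarray implementation above.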
array_function_dispatch!(
    fn activation(
        x: &dyn ArrayProtocol,
        func: ActivationFunc,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let boxed_x = Box::new(x.box_clone());
        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_x];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // No protocol override: handle the common ndarray-backed case here.
            if let Some(x_array) = x.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
                let x_array = x_array.as_array();
                let result = apply_activation(&x_array.view(), func);
                return Ok(Box::new(NdarrayWrapper::new(result)));
            }
            return Err(OperationError::NotImplemented(
                "activation not implemented for this array type".to_string(),
            ));
        }

        // Delegate to the first argument that implements the protocol.
        let array_ref = implementing_args[0].1;
        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::activation",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(x.box_clone())],
            &HashMap::new(),
        )?;

        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::activation"
);

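// 2-D convolution for NHWC input and (height, width, in_channels, out_channels)
// filters, with a naive direct-convolution fallback for ndarray-backed arrays.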
array_function_dispatch!(
    fn conv2d(
        input: &dyn ArrayProtocol,
        filters: &dyn ArrayProtocol,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let boxed_input = Box::new(input.box_clone());
        let boxed_filters = Box::new(filters.box_clone());
        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_input, boxed_filters];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            if let (Some(input_array), Some(filters_array)) = (
                input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                filters.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
            ) {
                let input = input_array.as_array();
                let filters = filters_array.as_array();

                // Input is NHWC; filters are (height, width, in_channels, out_channels).
                let batch_size = input.shape()[0];
                let input_height = input.shape()[1];
                let input_width = input.shape()[2];
                let input_channels = input.shape()[3];

                let filter_height = filters.shape()[0];
                let filter_width = filters.shape()[1];
                let filter_in_channels = filters.shape()[2];
                let filter_out_channels = filters.shape()[3];

                if input_channels != filter_in_channels {
                    return Err(OperationError::ShapeMismatch(format!(
                        "Input channels ({input_channels}) don't match filter input channels ({filter_in_channels})"
                    )));
                }

                // Standard output-size formula for strided, padded convolution.
                let out_height = (input_height - filter_height + 2 * padding.0) / stride.0 + 1;
                let out_width = (input_width - filter_width + 2 * padding.1) / stride.1 + 1;

                let mut output: Array<f64, Ix4> =
                    Array::zeros((batch_size, out_height, out_width, filter_out_channels));

                // Direct (naive) convolution; positions outside the input act
                // as zero padding and are simply skipped.
                for b in 0..batch_size {
                    for out_c in 0..filter_out_channels {
                        for out_h in 0..out_height {
                            for out_w in 0..out_width {
                                let mut sum = 0.0;

                                for f_h in 0..filter_height {
                                    for f_w in 0..filter_width {
                                        for in_c in 0..input_channels {
                                            let in_h = (out_h * stride.0) as i32 + f_h as i32
                                                - padding.0 as i32;
                                            let in_w = (out_w * stride.1) as i32 + f_w as i32
                                                - padding.1 as i32;

                                            if in_h >= 0
                                                && in_h < input_height as i32
                                                && in_w >= 0
                                                && in_w < input_width as i32
                                            {
                                                let input_val =
                                                    input[[b, in_h as usize, in_w as usize, in_c]];
                                                let filter_val = filters[[f_h, f_w, in_c, out_c]];
                                                sum += input_val * filter_val;
                                            }
                                        }
                                    }
                                }

                                output[[b, out_h, out_w, out_c]] = sum;
                            }
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "conv2d not implemented for these array types".to_string(),
            ));
        }

        let mut kwargs = HashMap::new();
        kwargs.insert("stride".to_string(), Box::new(stride) as Box<dyn Any>);
        kwargs.insert("padding".to_string(), Box::new(padding) as Box<dyn Any>);

        let array_ref = implementing_args[0].1;
        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new("scirs2::array_protocol::ml_ops::conv2d"),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(input.box_clone()), Box::new(filters.box_clone())],
            &kwargs,
        )?;

        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::conv2d"
);

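// 2-D max pooling for NHWC input; the fallback ignores padded positions when
// taking the window maximum.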
array_function_dispatch!(
    fn max_pool2d(
        input: &dyn ArrayProtocol,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let boxed_input = Box::new(input.box_clone());
        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_input];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            if let Some(input_array) = input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>() {
                let input = input_array.as_array();

                let batch_size = input.shape()[0];
                let input_height = input.shape()[1];
                let input_width = input.shape()[2];
                let channels = input.shape()[3];

                let out_height = (input_height - kernel_size.0 + 2 * padding.0) / stride.0 + 1;
                let out_width = (input_width - kernel_size.1 + 2 * padding.1) / stride.1 + 1;

                let mut output: Array<f64, Ix4> =
                    Array::zeros((batch_size, out_height, out_width, channels));

                for b in 0..batch_size {
                    for c in 0..channels {
                        for out_h in 0..out_height {
                            for out_w in 0..out_width {
                                let mut max_val = f64::NEG_INFINITY;

                                for k_h in 0..kernel_size.0 {
                                    for k_w in 0..kernel_size.1 {
                                        let in_h = (out_h * stride.0) as i32 + k_h as i32
                                            - padding.0 as i32;
                                        let in_w = (out_w * stride.1) as i32 + k_w as i32
                                            - padding.1 as i32;

                                        if in_h >= 0
                                            && in_h < input_height as i32
                                            && in_w >= 0
                                            && in_w < input_width as i32
                                        {
                                            let val = input[[b, in_h as usize, in_w as usize, c]];
                                            if val > max_val {
                                                max_val = val;
                                            }
                                        }
                                    }
                                }

                                // A window that covers only padding yields 0.0.
                                output[[b, out_h, out_w, c]] = if max_val == f64::NEG_INFINITY {
                                    0.0
                                } else {
                                    max_val
                                };
                            }
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "max_pool2d not implemented for this array type".to_string(),
            ));
        }

        let mut kwargs = HashMap::new();
        kwargs.insert(
            "kernel_size".to_string(),
            Box::new(kernel_size) as Box<dyn Any>,
        );
        kwargs.insert("stride".to_string(), Box::new(stride) as Box<dyn Any>);
        kwargs.insert("padding".to_string(), Box::new(padding) as Box<dyn Any>);

        let array_ref = implementing_args[0].1;
        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::max_pool2d",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(input.box_clone())],
            &kwargs,
        )?;

        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::max_pool2d"
);

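// Batch normalization for NHWC input with per-channel scale, offset, mean,
// and variance parameters.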
array_function_dispatch!(
    fn batch_norm(
        input: &dyn ArrayProtocol,
        scale: &dyn ArrayProtocol,
        offset: &dyn ArrayProtocol,
        mean: &dyn ArrayProtocol,
        variance: &dyn ArrayProtocol,
        epsilon: f64,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let boxed_args: Vec<Box<dyn Any>> = vec![
            Box::new(input.box_clone()),
            Box::new(scale.box_clone()),
            Box::new(offset.box_clone()),
            Box::new(mean.box_clone()),
            Box::new(variance.box_clone()),
        ];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            if let (
                Some(input_array),
                Some(scale_array),
                Some(offset_array),
                Some(mean_array),
                Some(variance_array),
            ) = (
                input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                scale.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                offset.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                mean.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                variance
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
            ) {
                let input = input_array.as_array();
                let scale = scale_array.as_array();
                let offset = offset_array.as_array();
                let mean = mean_array.as_array();
                let variance = variance_array.as_array();

                let batch_size = input.shape()[0];
                let height = input.shape()[1];
                let width = input.shape()[2];
                let channels = input.shape()[3];

                // All per-channel parameters must have one entry per channel.
                if scale.shape()[0] != channels
                    || offset.shape()[0] != channels
                    || mean.shape()[0] != channels
                    || variance.shape()[0] != channels
                {
                    return Err(OperationError::ShapeMismatch(
                        "Scale, offset, mean, and variance must match the number of channels"
                            .to_string(),
                    ));
                }

                let mut output: Array<f64, Ix4> = Array::zeros(input.raw_dim());

                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            for c in 0..channels {
                                let x = input[[b, h, w, c]];
                                let m = mean[[c]];
                                let v = variance[[c]];
                                let s = scale[[c]];
                                let o = offset[[c]];

                                // y = scale * (x - mean) / sqrt(var + eps) + offset
                                let normalized = (x - m) / (v + epsilon).sqrt();
                                output[[b, h, w, c]] = s * normalized + o;
                            }
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "batch_norm not implemented for these array types".to_string(),
            ));
        }

        let mut kwargs = HashMap::new();
        kwargs.insert("epsilon".to_string(), Box::new(epsilon) as Box<dyn Any>);

        let array_ref = implementing_args[0].1;
        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::batch_norm",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[
                Box::new(input.box_clone()),
                Box::new(scale.box_clone()),
                Box::new(offset.box_clone()),
                Box::new(mean.box_clone()),
                Box::new(variance.box_clone()),
            ],
            &kwargs,
        )?;

        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::batch_norm"
);

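// Softmax cross-entropy over (batch, classes) logits against one-hot (or
// soft) labels. Note this returns the loss as a raw ndarray inside a
// `Box<dyn Any>` rather than as an array-protocol object.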
array_function_dispatch!(
    fn cross_entropy(
        logits: &dyn ArrayProtocol,
        labels: &dyn ArrayProtocol,
        reduction: &str,
    ) -> Result<Box<dyn Any>, OperationError> {
        let boxed_args: Vec<Box<dyn Any>> =
            vec![Box::new(logits.box_clone()), Box::new(labels.box_clone())];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            if let (Some(logits_array), Some(labels_array)) = (
                logits.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
                labels.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
            ) {
                let logits = logits_array.as_array();
                let labels = labels_array.as_array();

                if logits.shape() != labels.shape() {
                    return Err(OperationError::ShapeMismatch(format!(
                        "Logits shape {logits_shape:?} doesn't match labels shape {labels_shape:?}",
                        logits_shape = logits.shape(),
                        labels_shape = labels.shape()
                    )));
                }

                // Row-wise softmax with max subtraction for numerical stability.
                let mut softmax = Array::zeros(logits.raw_dim());
                for (i, sample) in logits.outer_iter().enumerate() {
                    let max_val = sample.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                    let exp_x = sample.mapv(|v| (v - max_val).exp());
                    let sum_exp = exp_x.sum();

                    for (j, val) in exp_x.iter().enumerate() {
                        softmax[[i, j]] = val / sum_exp;
                    }
                }

                // Per-sample loss: -sum(labels * ln(softmax)), with a small
                // constant to avoid ln(0).
                let mut sample_losses = Array::zeros(logits.shape()[0]);
                for (i, (s, l)) in softmax.outer_iter().zip(labels.outer_iter()).enumerate() {
                    let mut loss = 0.0;
                    for (s_val, l_val) in s.iter().zip(l.iter()) {
                        loss -= l_val * (s_val + 1e-10).ln();
                    }
                    sample_losses[i] = loss;
                }

                let loss = match reduction {
                    "none" => sample_losses,
                    "mean" => {
                        let mean = sample_losses.sum() / sample_losses.len() as f64;
                        Array::from_elem(Ix1(1), mean)
                    }
                    "sum" => {
                        let sum = sample_losses.sum();
                        Array::from_elem(Ix1(1), sum)
                    }
                    _ => {
                        return Err(OperationError::Other(format!(
                            "Unknown reduction method: {reduction}"
                        )))
                    }
                };

                return Ok(Box::new(loss) as Box<dyn Any>);
            }
            return Err(OperationError::NotImplemented(
                "cross_entropy not implemented for these array types".to_string(),
            ));
        }

        let mut kwargs = HashMap::new();
        kwargs.insert(
            "reduction".to_string(),
            Box::new(reduction.to_string()) as Box<dyn Any>,
        );

        let array_ref = implementing_args[0].1;
        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::cross_entropy",
            ),
            &[TypeId::of::<Box<dyn Any>>()],
            &[Box::new(logits.box_clone()), Box::new(labels.box_clone())],
            &kwargs,
        )?;

        Ok(result)
    },
    "scirs2::array_protocol::ml_ops::cross_entropy"
);

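// Inverted dropout: during training, zero each element with probability
// `rate` and rescale survivors by 1 / (1 - rate); identity at inference.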
array_function_dispatch!(
    fn dropout(
        input: &dyn ArrayProtocol,
        rate: f64,
        training: bool,
        seed: Option<u64>,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let boxed_args: Vec<Box<dyn Any>> = vec![Box::new(input.box_clone())];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            if let Some(input_array) =
                input.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
            {
                let input = input_array.as_array();

                // Dropout is the identity at inference time.
                if !training {
                    return Ok(Box::new(NdarrayWrapper::new(input.clone())));
                }

                // Use the provided seed for reproducibility, otherwise seed
                // from the thread-local RNG.
                let mut rng = match seed {
                    Some(s) => rand::rngs::StdRng::seed_from_u64(s),
                    None => {
                        let mut rng = rand::rng();
                        let random_seed: u64 = rng.random();
                        rand::rngs::StdRng::seed_from_u64(random_seed)
                    }
                };

                // Bernoulli keep-mask: each element survives with probability 1 - rate.
                let mask = Array::from_shape_fn(input.raw_dim(), |_| {
                    if rng.random::<f64>() >= rate {
                        1.0
                    } else {
                        0.0
                    }
                });

                // Inverted-dropout scaling keeps the expected activation
                // unchanged (assumes rate < 1.0).
                let scale = 1.0 / (1.0 - rate);
                let result = input.clone() * &mask * scale;

                return Ok(Box::new(NdarrayWrapper::new(result)));
            }
            return Err(OperationError::NotImplemented(
                "dropout not implemented for this array type".to_string(),
            ));
        }

        let mut kwargs = HashMap::new();
        kwargs.insert("rate".to_string(), Box::new(rate) as Box<dyn Any>);
        kwargs.insert("training".to_string(), Box::new(training) as Box<dyn Any>);
        if let Some(s) = seed {
            kwargs.insert("seed".to_string(), Box::new(s) as Box<dyn Any>);
        }

        let array_ref = implementing_args[0].1;
        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new("scirs2::array_protocol::ml_ops::dropout"),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(input.box_clone())],
            &kwargs,
        )?;

        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::dropout"
);

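// Scaled dot-product self-attention over (batch, seq_len, num_heads, d_k)
// tensors with an optional (batch, q_len, k_len) binary mask. Head outputs
// are averaged (not concatenated) into a (batch, q_len, d_k) result.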
array_function_dispatch!(
    fn self_attention(
        queries: &dyn ArrayProtocol,
        keys: &dyn ArrayProtocol,
        values: &dyn ArrayProtocol,
        mask: Option<&dyn ArrayProtocol>,
        scale: Option<f64>,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let mut boxed_args: Vec<Box<dyn Any>> = vec![
            Box::new(queries.box_clone()),
            Box::new(keys.box_clone()),
            Box::new(values.box_clone()),
        ];
        if let Some(m) = mask {
            boxed_args.push(Box::new(m.box_clone()));
        }

        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            if let (Some(q_array), Some(k_array), Some(v_array)) = (
                queries.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                keys.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                values.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
            ) {
                let q = q_array.as_array();
                let k = k_array.as_array();
                let v = v_array.as_array();

                // Expected layout: (batch, seq_len, num_heads, d_k).
                let batch_size = q.shape()[0];
                let q_len = q.shape()[1];
                let num_heads = q.shape()[2];
                let d_k = q.shape()[3];
                let k_len = k.shape()[1];

                if k.shape()[0] != batch_size
                    || k.shape()[2] != num_heads
                    || k.shape()[3] != d_k
                    || v.shape()[0] != batch_size
                    || v.shape()[1] != k_len
                    || v.shape()[2] != num_heads
                    || v.shape()[3] != d_k
                {
                    return Err(OperationError::ShapeMismatch(
                        "Incompatible shapes for self-attention".to_string(),
                    ));
                }

                // Default to the usual 1/sqrt(d_k) scaling.
                let scale_factor = scale.unwrap_or(1.0 / (d_k as f64).sqrt());
                let mut output: Array<f64, Ix3> = Array::zeros((batch_size, q_len, d_k));

                for b in 0..batch_size {
                    let q_batch = q.slice(crate::s![b, .., .., ..]);
                    let k_batch = k.slice(crate::s![b, .., .., ..]);
                    let v_batch = v.slice(crate::s![b, .., .., ..]);

                    let mut head_outputs = Array::zeros((q_len, num_heads, d_k));

                    for h in 0..num_heads {
                        // Scaled dot-product scores for this head.
                        let mut scores = Array::zeros((q_len, k_len));
                        for i in 0..q_len {
                            for j in 0..k_len {
                                let mut dot_product = 0.0;
                                for d in 0..d_k {
                                    dot_product += q_batch[[i, h, d]] * k_batch[[j, h, d]];
                                }
                                scores[[i, j]] = dot_product * scale_factor;
                            }
                        }

                        // Masked positions (mask == 0) are excluded from the softmax.
                        if let Some(mask_array) = mask {
                            if let Some(mask_wrapper) = mask_array
                                .as_any()
                                .downcast_ref::<NdarrayWrapper<f64, Ix3>>()
                            {
                                let mask_batch =
                                    mask_wrapper.as_array().slice(crate::s![b, .., ..]);
                                for i in 0..q_len {
                                    for j in 0..k_len {
                                        if mask_batch[[i, j]] == 0.0 {
                                            scores[[i, j]] = f64::NEG_INFINITY;
                                        }
                                    }
                                }
                            }
                        }

                        // Row-wise softmax with max subtraction for stability.
                        let mut attention = Array::zeros((q_len, k_len));
                        for i in 0..q_len {
                            let mut max_score = f64::NEG_INFINITY;
                            for j in 0..k_len {
                                if scores[[i, j]] > max_score {
                                    max_score = scores[[i, j]];
                                }
                            }

                            let mut exp_sum = 0.0;
                            for j in 0..k_len {
                                let exp_val = (scores[[i, j]] - max_score).exp();
                                attention[[i, j]] = exp_val;
                                exp_sum += exp_val;
                            }

                            for j in 0..k_len {
                                attention[[i, j]] /= exp_sum;
                            }
                        }

                        // Attention-weighted sum of the values for this head.
                        for i in 0..q_len {
                            for d in 0..d_k {
                                let mut weighted_sum = 0.0;
                                for j in 0..k_len {
                                    weighted_sum += attention[[i, j]] * v_batch[[j, h, d]];
                                }
                                head_outputs[[i, h, d]] = weighted_sum;
                            }
                        }
                    }

                    // Combine heads by averaging rather than concatenation.
                    for i in 0..q_len {
                        for d in 0..d_k {
                            let mut sum = 0.0;
                            for h in 0..num_heads {
                                sum += head_outputs[[i, h, d]];
                            }
                            output[[b, i, d]] = sum / num_heads as f64;
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "self_attention not implemented for these array types".to_string(),
            ));
        }

        let mut kwargs = HashMap::new();
        if let Some(s) = scale {
            kwargs.insert("scale".to_string(), Box::new(s) as Box<dyn Any>);
        }
        if let Some(m) = mask {
            kwargs.insert("mask".to_string(), Box::new(m.box_clone()) as Box<dyn Any>);
        }

        let array_ref = implementing_args[0].1;
        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::self_attention",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[
                Box::new(queries.box_clone()),
                Box::new(keys.box_clone()),
                Box::new(values.box_clone()),
            ],
            &kwargs,
        )?;

        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::self_attention"
);
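
// A minimal usage sketch rather than an exhaustive test suite: it exercises
// the local ndarray fallback helper directly. The shapes and values here are
// illustrative assumptions, not fixtures from the original code.
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::array;

    #[test]
    fn apply_activation_relu_and_softmax() {
        let x = array![[-1.0, 2.0], [3.0, -4.0]].into_dyn();

        // ReLU clamps negatives to zero and leaves positives unchanged.
        let relu = apply_activation(&x.view(), ActivationFunc::ReLU);
        assert_eq!(relu[[0, 0]], 0.0);
        assert_eq!(relu[[0, 1]], 2.0);

        // Softmax rows along the last axis sum to one.
        let softmax = apply_activation(&x.view(), ActivationFunc::Softmax);
        let row_sum = softmax[[0, 0]] + softmax[[0, 1]];
        assert!((row_sum - 1.0).abs() < 1e-12);
    }
}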