//! Machine learning operations implemented on top of the array protocol.
//!
//! Each operation first tries the generic `array_function` dispatch path and
//! falls back to a reference `ndarray` implementation for `NdarrayWrapper`
//! arrays.

use std::any::{Any, TypeId};
use std::collections::HashMap;

use ndarray::{Array, Axis, Ix1, Ix2, Ix3, Ix4, IxDyn};
use rand::prelude::*;

use crate::array_protocol::operations::OperationError;
use crate::array_protocol::{
    array_function_dispatch, get_implementing_args, ArrayProtocol, NdarrayWrapper,
};

/// Activation functions supported by the `activation` operation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ActivationFunc {
    /// Rectified linear unit: `max(0, x)`.
    ReLU,

    /// Logistic sigmoid: `1 / (1 + exp(-x))`.
    Sigmoid,

    /// Hyperbolic tangent.
    Tanh,

    /// Softmax along the last axis.
    Softmax,

    /// Leaky ReLU with the given negative slope `alpha`.
    LeakyReLU(f64),
}

/// Apply an activation function element-wise (softmax is applied along the
/// last axis). Used as the fallback implementation for `NdarrayWrapper`.
#[allow(dead_code)]
fn apply_activation(
    x: &ndarray::ArrayBase<ndarray::ViewRepr<&f64>, IxDyn>,
    func: ActivationFunc,
) -> Array<f64, IxDyn> {
    match func {
        ActivationFunc::ReLU => x.mapv(|v| v.max(0.0)),
        ActivationFunc::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
        ActivationFunc::Tanh => x.mapv(|v| v.tanh()),
        ActivationFunc::LeakyReLU(alpha) => x.mapv(|v| if v > 0.0 { v } else { alpha * v }),
        ActivationFunc::Softmax => {
            let mut result = Array::zeros(x.raw_dim());
            let last_dim = x.ndim() - 1;

            // Apply a numerically stable softmax to each lane along the last
            // axis: subtract the lane maximum before exponentiating. A 1-D
            // array is a single lane, so no special case is needed.
            for (x_lane, mut out_lane) in x
                .lanes(Axis(last_dim))
                .into_iter()
                .zip(result.lanes_mut(Axis(last_dim)))
            {
                let max_val = x_lane.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let exp_x = x_lane.mapv(|v| (v - max_val).exp());
                let sum_exp = exp_x.sum();
                out_lane.assign(&(exp_x / sum_exp));
            }

            result
        }
    }
}

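// A minimal sanity check for `apply_activation`, added as an illustrative
// sketch (the test module name and tolerances are assumptions, not part of
// the original crate's test suite).
#[cfg(test)]
mod apply_activation_sanity {
    use super::*;

    #[test]
    fn softmax_rows_sum_to_one() {
        let x = ndarray::arr2(&[[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]).into_dyn();
        let y = apply_activation(&x.view(), ActivationFunc::Softmax);
        for row in y.axis_iter(Axis(0)) {
            assert!((row.sum() - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn relu_clamps_negatives() {
        let x = ndarray::arr1(&[-1.0, 2.0]).into_dyn();
        let y = apply_activation(&x.view(), ActivationFunc::ReLU);
        assert_eq!(y.as_slice().unwrap(), &[0.0, 2.0]);
    }
}
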
array_function_dispatch!(
    // Apply an activation function to an array, dispatching through the
    // array protocol.
    fn activation(
        x: &dyn ArrayProtocol,
        func: ActivationFunc,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let boxed_x = Box::new(x.box_clone());
        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_x];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // No protocol implementation found: fall back to the reference
            // ndarray implementation.
            if let Some(x_array) = x.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
                let x_array = x_array.as_array();
                let result = apply_activation(&x_array.view(), func);
                return Ok(Box::new(NdarrayWrapper::new(result)));
            }
            return Err(OperationError::NotImplemented(
                "activation not implemented for this array type".to_string(),
            ));
        }

        // Delegate to the first argument that implements the protocol.
        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::activation",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(x.box_clone())],
            &HashMap::new(),
        )?;

        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::activation"
);

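// Illustrative usage of the dispatched `activation` function; a hedged
// sketch assuming the macro exposes `activation` as a plain function and
// that the fallback path is taken for a bare `NdarrayWrapper`.
#[cfg(test)]
mod activation_dispatch_sanity {
    use super::*;

    #[test]
    fn relu_via_dispatch() {
        let x = ndarray::arr1(&[-1.0, 0.5]).into_dyn();
        let wrapped = NdarrayWrapper::new(x);
        let out = activation(&wrapped, ActivationFunc::ReLU).unwrap();
        let out = out
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
            .expect("fallback path returns an NdarrayWrapper");
        assert_eq!(out.as_array().as_slice().unwrap(), &[0.0, 0.5]);
    }
}
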
array_function_dispatch!(
    // 2-D convolution over NHWC input with HWIO filters.
    fn conv2d(
        input: &dyn ArrayProtocol,
        filters: &dyn ArrayProtocol,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let boxed_input = Box::new(input.box_clone());
        let boxed_filters = Box::new(filters.box_clone());
        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_input, boxed_filters];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fall back to a direct implementation for NHWC input and HWIO
            // filters (height, width, in-channels, out-channels).
            if let (Some(input_array), Some(filters_array)) = (
                input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                filters.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
            ) {
                let input = input_array.as_array();
                let filters = filters_array.as_array();

                let batch_size = input.shape()[0];
                let input_height = input.shape()[1];
                let input_width = input.shape()[2];
                let input_channels = input.shape()[3];

                let filter_height = filters.shape()[0];
                let filter_width = filters.shape()[1];
                let filter_in_channels = filters.shape()[2];
                let filter_out_channels = filters.shape()[3];

                if input_channels != filter_in_channels {
                    return Err(OperationError::ShapeMismatch(format!(
                        "Input channels ({input_channels}) doesn't match filter input channels ({filter_in_channels})"
                    )));
                }

                // Standard convolution output size:
                // out = (in - filter + 2 * padding) / stride + 1
                let out_height = (input_height - filter_height + 2 * padding.0) / stride.0 + 1;
                let out_width = (input_width - filter_width + 2 * padding.1) / stride.1 + 1;

                let mut output: Array<f64, Ix4> =
                    Array::zeros((batch_size, out_height, out_width, filter_out_channels));

                // Naive direct convolution; positions that fall into the
                // zero padding are simply skipped.
                for b in 0..batch_size {
                    for out_c in 0..filter_out_channels {
                        for out_h in 0..out_height {
                            for out_w in 0..out_width {
                                let mut sum = 0.0;

                                for f_h in 0..filter_height {
                                    for f_w in 0..filter_width {
                                        for in_c in 0..input_channels {
                                            let in_h = (out_h * stride.0) as i32 + f_h as i32
                                                - padding.0 as i32;
                                            let in_w = (out_w * stride.1) as i32 + f_w as i32
                                                - padding.1 as i32;

                                            if in_h >= 0
                                                && in_h < input_height as i32
                                                && in_w >= 0
                                                && in_w < input_width as i32
                                            {
                                                let input_val =
                                                    input[[b, in_h as usize, in_w as usize, in_c]];
                                                let filter_val = filters[[f_h, f_w, in_c, out_c]];
                                                sum += input_val * filter_val;
                                            }
                                        }
                                    }
                                }

                                output[[b, out_h, out_w, out_c]] = sum;
                            }
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "conv2d not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementing array, forwarding stride and padding
        // as keyword arguments.
        let mut kwargs = HashMap::new();
        kwargs.insert("stride".to_string(), Box::new(stride) as Box<dyn Any>);
        kwargs.insert("padding".to_string(), Box::new(padding) as Box<dyn Any>);

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new("scirs2::array_protocol::ml_ops::conv2d"),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(input.box_clone()), Box::new(filters.box_clone())],
            &kwargs,
        )?;

        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::conv2d"
);

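// A shape-level sketch of the conv2d fallback: a 1x4x4x1 input with a single
// 2x2 filter, stride 1, no padding, should produce a 1x3x3x1 output. The
// test module is illustrative and not part of the original suite.
#[cfg(test)]
mod conv2d_shape_sanity {
    use super::*;
    use ndarray::Array4;

    #[test]
    fn output_shape_matches_formula() {
        let input = NdarrayWrapper::new(Array4::<f64>::ones((1, 4, 4, 1)));
        let filters = NdarrayWrapper::new(Array4::<f64>::ones((2, 2, 1, 1)));
        let out = conv2d(&input, &filters, (1, 1), (0, 0)).unwrap();
        let out = out
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, Ix4>>()
            .expect("fallback path returns an NdarrayWrapper");
        assert_eq!(out.as_array().shape(), &[1, 3, 3, 1]);
        // Every 2x2 window of ones convolved with a 2x2 filter of ones is 4.
        assert!(out.as_array().iter().all(|&v| (v - 4.0).abs() < 1e-12));
    }
}
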
array_function_dispatch!(
    // 2-D max pooling over NHWC input.
    fn max_pool2d(
        input: &dyn ArrayProtocol,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let boxed_input = Box::new(input.box_clone());
        let boxed_args: Vec<Box<dyn Any>> = vec![boxed_input];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fall back to a direct implementation for NHWC input.
            if let Some(input_array) = input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>() {
                let input = input_array.as_array();

                let batch_size = input.shape()[0];
                let input_height = input.shape()[1];
                let input_width = input.shape()[2];
                let channels = input.shape()[3];

                // Same output-size formula as conv2d, with the kernel in
                // place of the filter.
                let out_height = (input_height - kernel_size.0 + 2 * padding.0) / stride.0 + 1;
                let out_width = (input_width - kernel_size.1 + 2 * padding.1) / stride.1 + 1;

                let mut output: Array<f64, Ix4> =
                    Array::zeros((batch_size, out_height, out_width, channels));

                for b in 0..batch_size {
                    for c in 0..channels {
                        for out_h in 0..out_height {
                            for out_w in 0..out_width {
                                let mut max_val = f64::NEG_INFINITY;

                                for k_h in 0..kernel_size.0 {
                                    for k_w in 0..kernel_size.1 {
                                        let in_h = (out_h * stride.0) as i32 + k_h as i32
                                            - padding.0 as i32;
                                        let in_w = (out_w * stride.1) as i32 + k_w as i32
                                            - padding.1 as i32;

                                        if in_h >= 0
                                            && in_h < input_height as i32
                                            && in_w >= 0
                                            && in_w < input_width as i32
                                        {
                                            let val = input[[b, in_h as usize, in_w as usize, c]];
                                            if val > max_val {
                                                max_val = val;
                                            }
                                        }
                                    }
                                }

                                // A window that lies entirely in the padding
                                // contributes zero rather than -inf.
                                output[[b, out_h, out_w, c]] = if max_val == f64::NEG_INFINITY {
                                    0.0
                                } else {
                                    max_val
                                };
                            }
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "max_pool2d not implemented for this array type".to_string(),
            ));
        }

        // Delegate to the implementing array, forwarding the pooling
        // parameters as keyword arguments.
        let mut kwargs = HashMap::new();
        kwargs.insert(
            "kernel_size".to_string(),
            Box::new(kernel_size) as Box<dyn Any>,
        );
        kwargs.insert("stride".to_string(), Box::new(stride) as Box<dyn Any>);
        kwargs.insert("padding".to_string(), Box::new(padding) as Box<dyn Any>);

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::max_pool2d",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(input.box_clone())],
            &kwargs,
        )?;

        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::max_pool2d"
);

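// Pooling counterpart of the conv2d sketch above: 2x2 max pooling with
// stride 2 over a 1x4x4x1 input halves the spatial dimensions. Illustrative
// only; the module name and values are assumptions.
#[cfg(test)]
mod max_pool2d_shape_sanity {
    use super::*;
    use ndarray::Array4;

    #[test]
    fn halves_spatial_dims() {
        let input = NdarrayWrapper::new(Array4::from_shape_fn((1, 4, 4, 1), |(_, h, w, _)| {
            (h * 4 + w) as f64
        }));
        let out = max_pool2d(&input, (2, 2), (2, 2), (0, 0)).unwrap();
        let out = out
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, Ix4>>()
            .expect("fallback path returns an NdarrayWrapper");
        assert_eq!(out.as_array().shape(), &[1, 2, 2, 1]);
        // The top-left 2x2 window covers values {0, 1, 4, 5}; its max is 5.
        assert_eq!(out.as_array()[[0, 0, 0, 0]], 5.0);
    }
}
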
array_function_dispatch!(
    // Batch normalization over NHWC input with per-channel parameters.
    fn batch_norm(
        input: &dyn ArrayProtocol,
        scale: &dyn ArrayProtocol,
        offset: &dyn ArrayProtocol,
        mean: &dyn ArrayProtocol,
        variance: &dyn ArrayProtocol,
        epsilon: f64,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let boxed_args: Vec<Box<dyn Any>> = vec![
            Box::new(input.box_clone()),
            Box::new(scale.box_clone()),
            Box::new(offset.box_clone()),
            Box::new(mean.box_clone()),
            Box::new(variance.box_clone()),
        ];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fall back to a direct implementation for NHWC input with
            // one-dimensional per-channel parameters.
            if let (
                Some(input_array),
                Some(scale_array),
                Some(offset_array),
                Some(mean_array),
                Some(variance_array),
            ) = (
                input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                scale.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                offset.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                mean.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                variance
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
            ) {
                let input = input_array.as_array();
                let scale = scale_array.as_array();
                let offset = offset_array.as_array();
                let mean = mean_array.as_array();
                let variance = variance_array.as_array();

                let batch_size = input.shape()[0];
                let height = input.shape()[1];
                let width = input.shape()[2];
                let channels = input.shape()[3];

                if scale.shape()[0] != channels
                    || offset.shape()[0] != channels
                    || mean.shape()[0] != channels
                    || variance.shape()[0] != channels
                {
                    return Err(OperationError::ShapeMismatch(
                        "Scale, offset, mean, and variance must match the number of channels"
                            .to_string(),
                    ));
                }

                let mut output: Array<f64, Ix4> = Array::zeros(input.raw_dim());

                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            for c in 0..channels {
                                let x = input[[b, h, w, c]];
                                let m = mean[[c]];
                                let v = variance[[c]];
                                let s = scale[[c]];
                                let o = offset[[c]];

                                // Normalize, then apply the affine transform:
                                // y = scale * (x - mean) / sqrt(var + eps) + offset
                                let normalized = (x - m) / (v + epsilon).sqrt();
                                output[[b, h, w, c]] = s * normalized + o;
                            }
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "batch_norm not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementing array, forwarding epsilon as a
        // keyword argument.
        let mut kwargs = HashMap::new();
        kwargs.insert("epsilon".to_string(), Box::new(epsilon) as Box<dyn Any>);

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::batch_norm",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[
                Box::new(input.box_clone()),
                Box::new(scale.box_clone()),
                Box::new(offset.box_clone()),
                Box::new(mean.box_clone()),
                Box::new(variance.box_clone()),
            ],
            &kwargs,
        )?;

        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::batch_norm"
);

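// Numerical sketch for the batch_norm fallback: with mean 0, variance 1,
// scale 1, offset 0 and a small epsilon, the op is close to the identity.
// Illustrative module; names and tolerances are assumptions.
#[cfg(test)]
mod batch_norm_sanity {
    use super::*;
    use ndarray::{Array1, Array4};

    #[test]
    fn identity_parameters_preserve_input() {
        let input = NdarrayWrapper::new(Array4::<f64>::from_elem((1, 2, 2, 3), 2.0));
        let scale = NdarrayWrapper::new(Array1::<f64>::ones(3).into_dyn());
        let offset = NdarrayWrapper::new(Array1::<f64>::zeros(3).into_dyn());
        let mean = NdarrayWrapper::new(Array1::<f64>::zeros(3).into_dyn());
        let variance = NdarrayWrapper::new(Array1::<f64>::ones(3).into_dyn());
        let out = batch_norm(&input, &scale, &offset, &mean, &variance, 1e-8).unwrap();
        let out = out
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, Ix4>>()
            .expect("fallback path returns an NdarrayWrapper");
        assert!(out.as_array().iter().all(|&v| (v - 2.0).abs() < 1e-6));
    }
}
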
array_function_dispatch!(
    // Softmax cross-entropy loss between logits and (one-hot or soft)
    // labels, with "none", "mean", or "sum" reduction.
    fn cross_entropy(
        logits: &dyn ArrayProtocol,
        labels: &dyn ArrayProtocol,
        reduction: &str,
    ) -> Result<Box<dyn Any>, OperationError> {
        let boxed_args: Vec<Box<dyn Any>> =
            vec![Box::new(logits.box_clone()), Box::new(labels.box_clone())];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fall back to a direct implementation for 2-D f64 arrays of
            // shape (batch, num_classes).
            if let (Some(logits_array), Some(labels_array)) = (
                logits.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
                labels.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
            ) {
                let logits = logits_array.as_array();
                let labels = labels_array.as_array();

                if logits.shape() != labels.shape() {
                    return Err(OperationError::ShapeMismatch(format!(
                        "Logits shape {:?} doesn't match labels shape {:?}",
                        logits.shape(),
                        labels.shape()
                    )));
                }

                // Row-wise softmax with max subtraction for stability.
                let mut softmax = Array::zeros(logits.raw_dim());
                for (i, sample) in logits.outer_iter().enumerate() {
                    let max_val = sample.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                    let exp_x = sample.mapv(|v| (v - max_val).exp());
                    let sum_exp = exp_x.sum();

                    for (j, val) in exp_x.iter().enumerate() {
                        softmax[[i, j]] = val / sum_exp;
                    }
                }

                // Per-sample negative log-likelihood; the small constant
                // guards against ln(0).
                let mut sample_losses = Array::zeros(logits.shape()[0]);
                for (i, (s, l)) in softmax.outer_iter().zip(labels.outer_iter()).enumerate() {
                    let mut loss = 0.0;
                    for (s_val, l_val) in s.iter().zip(l.iter()) {
                        loss -= l_val * (s_val + 1e-10).ln();
                    }
                    sample_losses[i] = loss;
                }

                let loss = match reduction {
                    "none" => sample_losses,
                    "mean" => {
                        let mean = sample_losses.sum() / sample_losses.len() as f64;
                        Array::from_elem(Ix1(1), mean)
                    }
                    "sum" => {
                        let sum = sample_losses.sum();
                        Array::from_elem(Ix1(1), sum)
                    }
                    _ => {
                        return Err(OperationError::Other(format!(
                            "Unknown reduction method: {reduction}"
                        )))
                    }
                };

                return Ok(Box::new(loss) as Box<dyn Any>);
            }
            return Err(OperationError::NotImplemented(
                "cross_entropy not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementing array, forwarding the reduction mode
        // as a keyword argument.
        let mut kwargs = HashMap::new();
        kwargs.insert(
            "reduction".to_string(),
            Box::new(reduction.to_string()) as Box<dyn Any>,
        );

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::cross_entropy",
            ),
            &[TypeId::of::<Box<dyn Any>>()],
            &[Box::new(logits.box_clone()), Box::new(labels.box_clone())],
            &kwargs,
        )?;

        Ok(result)
    },
    "scirs2::array_protocol::ml_ops::cross_entropy"
);

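// Cross-entropy sketch: for a uniform two-class logit row and a one-hot
// label, the loss is ln(2). The returned Box<dyn Any> holds an Array1<f64>,
// matching the fallback path above. Illustrative only.
#[cfg(test)]
mod cross_entropy_sanity {
    use super::*;
    use ndarray::{arr2, Array1};

    #[test]
    fn uniform_logits_give_ln2() {
        let logits = NdarrayWrapper::new(arr2(&[[0.0, 0.0]]));
        let labels = NdarrayWrapper::new(arr2(&[[1.0, 0.0]]));
        let loss = cross_entropy(&logits, &labels, "mean").unwrap();
        let loss = loss
            .downcast::<Array1<f64>>()
            .expect("fallback path returns an Array1<f64>");
        assert!((loss[0] - (2.0_f64).ln()).abs() < 1e-9);
    }
}
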
array_function_dispatch!(
    // Inverted dropout: in training mode, zero elements with probability
    // `rate` and rescale the survivors by 1 / (1 - rate) so the expected
    // value is unchanged (assumes rate < 1.0).
    fn dropout(
        input: &dyn ArrayProtocol,
        rate: f64,
        training: bool,
        seed: Option<u64>,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let boxed_args: Vec<Box<dyn Any>> = vec![Box::new(input.box_clone())];
        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            if let Some(input_array) = input.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
                let input = input_array.as_array();

                // Dropout is the identity at inference time.
                if !training {
                    return Ok(Box::new(NdarrayWrapper::new(input.clone())));
                }

                // Seeded RNG for reproducibility; otherwise derive a fresh
                // seed from the thread-local RNG.
                let mut rng = match seed {
                    Some(s) => rand::rngs::StdRng::seed_from_u64(s),
                    None => {
                        let mut rng = rand::rng();
                        let random_seed: u64 = rng.random();
                        rand::rngs::StdRng::seed_from_u64(random_seed)
                    }
                };

                // Binary keep/drop mask.
                let mask = Array::from_shape_fn(input.raw_dim(), |_| {
                    if rng.random::<f64>() >= rate {
                        1.0
                    } else {
                        0.0
                    }
                });

                let scale = 1.0 / (1.0 - rate);
                let result = input.clone() * &mask * scale;

                return Ok(Box::new(NdarrayWrapper::new(result)));
            }
            return Err(OperationError::NotImplemented(
                "dropout not implemented for this array type".to_string(),
            ));
        }

        // Delegate to the implementing array, forwarding the dropout
        // parameters as keyword arguments.
        let mut kwargs = HashMap::new();
        kwargs.insert("rate".to_string(), Box::new(rate) as Box<dyn Any>);
        kwargs.insert("training".to_string(), Box::new(training) as Box<dyn Any>);
        if let Some(s) = seed {
            kwargs.insert("seed".to_string(), Box::new(s) as Box<dyn Any>);
        }

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new("scirs2::array_protocol::ml_ops::dropout"),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[Box::new(input.box_clone())],
            &kwargs,
        )?;

        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::dropout"
);

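// Dropout sketch: in inference mode the input passes through unchanged.
// Illustrative module; the rate value is arbitrary.
#[cfg(test)]
mod dropout_sanity {
    use super::*;

    #[test]
    fn inference_mode_is_identity() {
        let x = ndarray::arr1(&[1.0, 2.0, 3.0]).into_dyn();
        let wrapped = NdarrayWrapper::new(x.clone());
        let out = dropout(&wrapped, 0.5, false, None).unwrap();
        let out = out
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
            .expect("fallback path returns an NdarrayWrapper");
        assert_eq!(out.as_array(), &x);
    }
}
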
array_function_dispatch!(
    // Scaled dot-product self-attention over (batch, seq_len, num_heads, d_k)
    // arrays, with an optional mask and an optional custom scale factor.
    fn self_attention(
        queries: &dyn ArrayProtocol,
        keys: &dyn ArrayProtocol,
        values: &dyn ArrayProtocol,
        mask: Option<&dyn ArrayProtocol>,
        scale: Option<f64>,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let mut boxed_args: Vec<Box<dyn Any>> = vec![
            Box::new(queries.box_clone()),
            Box::new(keys.box_clone()),
            Box::new(values.box_clone()),
        ];
        if let Some(m) = mask {
            boxed_args.push(Box::new(m.box_clone()));
        }

        let implementing_args = get_implementing_args(&boxed_args);
        if implementing_args.is_empty() {
            // Fall back to a direct implementation for rank-4 f64 arrays.
            if let (Some(q_array), Some(k_array), Some(v_array)) = (
                queries.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                keys.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
                values.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
            ) {
                let q = q_array.as_array();
                let k = k_array.as_array();
                let v = v_array.as_array();

                let batch_size = q.shape()[0];
                let q_len = q.shape()[1];
                let num_heads = q.shape()[2];
                let d_k = q.shape()[3];

                let k_len = k.shape()[1];

                if k.shape()[0] != batch_size
                    || k.shape()[2] != num_heads
                    || k.shape()[3] != d_k
                    || v.shape()[0] != batch_size
                    || v.shape()[1] != k_len
                    || v.shape()[2] != num_heads
                {
                    return Err(OperationError::ShapeMismatch(
                        "Incompatible shapes for self-attention".to_string(),
                    ));
                }

                // Default to the usual scaled dot-product factor 1 / sqrt(d_k).
                let scale_factor = scale.unwrap_or(1.0 / (d_k as f64).sqrt());

                let mut output: Array<f64, Ix3> = Array::zeros((batch_size, q_len, d_k));

                for b in 0..batch_size {
                    let q_batch = q.slice(ndarray::s![b, .., .., ..]);
                    let k_batch = k.slice(ndarray::s![b, .., .., ..]);
                    let v_batch = v.slice(ndarray::s![b, .., .., ..]);

                    let mut head_outputs = Array::zeros((q_len, num_heads, d_k));

                    for h in 0..num_heads {
                        // Scaled dot-product scores between queries and keys.
                        let mut scores = Array::zeros((q_len, k_len));
                        for i in 0..q_len {
                            for j in 0..k_len {
                                let mut dot_product = 0.0;
                                for d in 0..d_k {
                                    dot_product += q_batch[[i, h, d]] * k_batch[[j, h, d]];
                                }
                                scores[[i, j]] = dot_product * scale_factor;
                            }
                        }

                        // Positions where the mask is zero are excluded by
                        // setting their score to negative infinity.
                        if let Some(mask_array) = mask {
                            if let Some(mask_wrapper) = mask_array
                                .as_any()
                                .downcast_ref::<NdarrayWrapper<f64, Ix3>>()
                            {
                                let mask_batch =
                                    mask_wrapper.as_array().slice(ndarray::s![b, .., ..]);
                                for i in 0..q_len {
                                    for j in 0..k_len {
                                        if mask_batch[[i, j]] == 0.0 {
                                            scores[[i, j]] = f64::NEG_INFINITY;
                                        }
                                    }
                                }
                            }
                        }

                        // Row-wise softmax over the scores, with the usual
                        // max subtraction for numerical stability.
                        let mut attention = Array::zeros((q_len, k_len));
                        for i in 0..q_len {
                            let mut max_score = f64::NEG_INFINITY;
                            for j in 0..k_len {
                                if scores[[i, j]] > max_score {
                                    max_score = scores[[i, j]];
                                }
                            }

                            let mut exp_sum = 0.0;
                            for j in 0..k_len {
                                let exp_val = (scores[[i, j]] - max_score).exp();
                                attention[[i, j]] = exp_val;
                                exp_sum += exp_val;
                            }

                            for j in 0..k_len {
                                attention[[i, j]] /= exp_sum;
                            }
                        }

                        // Weight the values by the attention distribution.
                        for i in 0..q_len {
                            for d in 0..d_k {
                                let mut weighted_sum = 0.0;
                                for j in 0..k_len {
                                    weighted_sum += attention[[i, j]] * v_batch[[j, h, d]];
                                }
                                head_outputs[[i, h, d]] = weighted_sum;
                            }
                        }
                    }

                    // Combine heads by averaging them into the output.
                    for i in 0..q_len {
                        for d in 0..d_k {
                            let mut sum = 0.0;
                            for h in 0..num_heads {
                                sum += head_outputs[[i, h, d]];
                            }
                            output[[b, i, d]] = sum / num_heads as f64;
                        }
                    }
                }

                return Ok(Box::new(NdarrayWrapper::new(output)));
            }
            return Err(OperationError::NotImplemented(
                "self_attention not implemented for these array types".to_string(),
            ));
        }

        // Delegate to the implementing array; optional arguments travel as
        // keyword arguments.
        let mut kwargs = HashMap::new();
        if let Some(s) = scale {
            kwargs.insert("scale".to_string(), Box::new(s) as Box<dyn Any>);
        }
        if let Some(m) = mask {
            kwargs.insert("mask".to_string(), Box::new(m.box_clone()) as Box<dyn Any>);
        }

        let array_ref = implementing_args[0].1;

        let result = array_ref.array_function(
            &crate::array_protocol::ArrayFunction::new(
                "scirs2::array_protocol::ml_ops::self_attention",
            ),
            &[TypeId::of::<Box<dyn ArrayProtocol>>()],
            &[
                Box::new(queries.box_clone()),
                Box::new(keys.box_clone()),
                Box::new(values.box_clone()),
            ],
            &kwargs,
        )?;

        match result.downcast::<Box<dyn ArrayProtocol>>() {
            Ok(array) => Ok(*array),
            Err(_) => Err(OperationError::Other(
                "Failed to downcast array_function result".to_string(),
            )),
        }
    },
    "scirs2::array_protocol::ml_ops::self_attention"
);
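
// Self-attention sketch: identical rows in Q, K and V make the attention
// weights uniform, so the (head-averaged) output reproduces the value
// vector. Illustrative module; shapes and values are assumptions.
#[cfg(test)]
mod self_attention_sanity {
    use super::*;
    use ndarray::Array4;

    #[test]
    fn uniform_attention_reproduces_values() {
        // (batch, seq_len, num_heads, d_k) = (1, 2, 1, 2)
        let q = NdarrayWrapper::new(Array4::<f64>::ones((1, 2, 1, 2)));
        let k = NdarrayWrapper::new(Array4::<f64>::ones((1, 2, 1, 2)));
        let v = NdarrayWrapper::new(Array4::<f64>::from_elem((1, 2, 1, 2), 3.0));
        let out = self_attention(&q, &k, &v, None, None).unwrap();
        let out = out
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, Ix3>>()
            .expect("fallback path returns a rank-3 NdarrayWrapper");
        assert_eq!(out.as_array().shape(), &[1, 2, 2]);
        assert!(out.as_array().iter().all(|&x| (x - 3.0).abs() < 1e-12));
    }
}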