1use std::any::{Any, TypeId};
14use std::collections::HashMap;
15
16use ::ndarray::{Array, Axis, Ix1, Ix2, Ix3, Ix4, IxDyn};
17use rand::{Rng, RngExt, SeedableRng};
18
19use crate::array_protocol::operations::OperationError;
20use crate::array_protocol::{
21 array_function_dispatch, get_implementing_args, ArrayProtocol, NdarrayWrapper,
22};
23
/// Activation functions accepted by the `activation` dispatch operation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ActivationFunc {
    /// Rectified linear unit: `max(0, x)`.
    ReLU,

    /// Logistic sigmoid: `1 / (1 + e^-x)`.
    Sigmoid,

    /// Hyperbolic tangent.
    Tanh,

    /// Softmax normalization computed along the last axis of the input.
    Softmax,

    /// Leaky ReLU; the payload is the slope applied to negative inputs.
    LeakyReLU(f64),
}
42
43#[allow(dead_code)]
45fn apply_activation(
46 x: &crate::ndarray::ArrayBase<crate::ndarray::ViewRepr<&f64>, IxDyn>,
47 func: ActivationFunc,
48) -> Array<f64, IxDyn> {
49 match func {
50 ActivationFunc::ReLU => x.mapv(|v| v.max(0.0)),
51 ActivationFunc::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
52 ActivationFunc::Tanh => x.mapv(|v| v.tanh()),
53 ActivationFunc::LeakyReLU(alpha) => x.mapv(|v| if v > 0.0 { v } else { alpha * v }),
54 ActivationFunc::Softmax => {
55 let mut result = Array::zeros(x.raw_dim());
57
58 let last_dim = x.ndim() - 1;
60 let _last_dim_len = x.shape()[last_dim];
61
62 if x.ndim() == 1 {
63 let max_val = x.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
65 let exp_x = x.mapv(|v| (v - max_val).exp());
66 let sum_exp = exp_x.sum();
67 result.assign(&(exp_x / sum_exp));
68 } else {
69 for (i, mut slice) in result.lanes_mut(Axis(last_dim)).into_iter().enumerate() {
71 let x_slice = x.index_axis(Axis(last_dim), i);
73 let max_val = x_slice.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
74 let exp_x = x_slice.mapv(|v| (v - max_val).exp());
75 let sum_exp = exp_x.sum();
76 slice.assign(&(exp_x / sum_exp));
77 }
78 }
79
80 result
81 }
82 }
83}
84
85array_function_dispatch!(
88 fn activation(
89 x: &dyn ArrayProtocol,
90 func: ActivationFunc,
91 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
92 let boxed_x = Box::new(x.box_clone());
94 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_x];
95 let implementing_args = get_implementing_args(&boxed_args);
96 if implementing_args.is_empty() {
97 if let Some(x_array) = x.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
99 let x_array = x_array.as_array();
100 let result = apply_activation(&x_array.view(), func);
101 return Ok(Box::new(NdarrayWrapper::new(result)));
102 }
103 return Err(OperationError::NotImplemented(
104 "activation not implemented for this array type".to_string(),
105 ));
106 }
107
108 let array_ref = implementing_args[0].1;
110
111 let result = array_ref.array_function(
112 &crate::array_protocol::ArrayFunction::new(
113 "scirs2::array_protocol::ml_ops::activation",
114 ),
115 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
116 &[Box::new(x.box_clone())],
117 &HashMap::new(),
118 )?;
119
120 match result.downcast::<Box<dyn ArrayProtocol>>() {
122 Ok(array) => Ok(*array),
123 Err(_) => Err(OperationError::Other(
124 "Failed to downcast array_function result".to_string(),
125 )),
126 }
127 },
128 "scirs2::array_protocol::ml, ops: activation"
129);
130
131array_function_dispatch!(
132 fn conv2d(
133 input: &dyn ArrayProtocol,
134 filters: &dyn ArrayProtocol,
135 stride: (usize, usize),
136 padding: (usize, usize),
137 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
138 let boxed_input = Box::new(input.box_clone());
140 let boxed_filters = Box::new(filters.box_clone());
141 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_input, boxed_filters];
142 let implementing_args = get_implementing_args(&boxed_args);
143 if implementing_args.is_empty() {
144 if let (Some(inputarray), Some(filters_array)) = (
147 input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
148 filters.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
149 ) {
150 let input = inputarray.as_array();
151 let filters = filters_array.as_array();
152
153 let batch_size = input.shape()[0];
155 let input_height = input.shape()[1];
156 let input_width = input.shape()[2];
157 let input_channels = input.shape()[3];
158
159 let filter_height = filters.shape()[0];
160 let filter_width = filters.shape()[1];
161 let filter_in_channels = filters.shape()[2];
162 let filter_out_channels = filters.shape()[3];
163
164 if input_channels != filter_in_channels {
166 return Err(OperationError::ShapeMismatch(format!(
167 "Input channels ({input_channels}) doesn't match filter input channels ({filter_in_channels})"
168 )));
169 }
170
171 let out_height = (input_height - filter_height + 2 * padding.0) / stride.0 + 1;
173 let out_width = (input_width - filter_width + 2 * padding.1) / stride.1 + 1;
174
175 let mut output: Array<f64, Ix4> =
177 Array::zeros((batch_size, out_height, out_width, filter_out_channels));
178
179 for b in 0..batch_size {
181 for out_c in 0..filter_out_channels {
182 for out_h in 0..out_height {
183 for out_w in 0..out_width {
184 let mut sum = 0.0;
185
186 for f_h in 0..filter_height {
188 for f_w in 0..filter_width {
189 for in_c in 0..input_channels {
190 let in_h = (out_h * stride.0) as i32 + f_h as i32
192 - padding.0 as i32;
193 let in_w = (out_w * stride.1) as i32 + f_w as i32
194 - padding.1 as i32;
195
196 if in_h >= 0
198 && in_h < input_height as i32
199 && in_w >= 0
200 && in_w < input_width as i32
201 {
202 let input_val =
203 input[[b, in_h as usize, in_w as usize, in_c]];
204 let filter_val = filters[[f_h, f_w, in_c, out_c]];
205 sum += input_val * filter_val;
206 }
207 }
208 }
209 }
210
211 output[[b, out_h, out_w, out_c]] = sum;
212 }
213 }
214 }
215 }
216
217 return Ok(Box::new(NdarrayWrapper::new(output)));
218 }
219 return Err(OperationError::NotImplemented(
220 "conv2d not implemented for these array types".to_string(),
221 ));
222 }
223
224 let mut kwargs = HashMap::new();
226 kwargs.insert("stride".to_string(), Box::new(stride) as Box<dyn Any>);
227 kwargs.insert("padding".to_string(), Box::new(padding) as Box<dyn Any>);
228
229 let array_ref = implementing_args[0].1;
230
231 let result = array_ref.array_function(
232 &crate::array_protocol::ArrayFunction::new("scirs2::array_protocol::ml_ops::conv2d"),
233 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
234 &[Box::new(input.box_clone()), Box::new(filters.box_clone())],
235 &kwargs,
236 )?;
237
238 match result.downcast::<Box<dyn ArrayProtocol>>() {
240 Ok(array) => Ok(*array),
241 Err(_) => Err(OperationError::Other(
242 "Failed to downcast array_function result".to_string(),
243 )),
244 }
245 },
246 "scirs2::array_protocol::ml, ops: conv2d"
247);
248
249array_function_dispatch!(
250 fn max_pool2d(
251 input: &dyn ArrayProtocol,
252 kernel_size: (usize, usize),
253 stride: (usize, usize),
254 padding: (usize, usize),
255 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
256 let boxed_input = Box::new(input.box_clone());
258 let boxed_args: Vec<Box<dyn Any>> = vec![boxed_input];
259 let implementing_args = get_implementing_args(&boxed_args);
260 if implementing_args.is_empty() {
261 if let Some(inputarray) = input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>() {
263 let input = inputarray.as_array();
264
265 let batch_size = input.shape()[0];
267 let input_height = input.shape()[1];
268 let input_width = input.shape()[2];
269 let channels = input.shape()[3];
270
271 let out_height = (input_height - kernel_size.0 + 2 * padding.0) / stride.0 + 1;
273 let out_width = (input_width - kernel_size.1 + 2 * padding.1) / stride.1 + 1;
274
275 let mut output: Array<f64, Ix4> =
277 Array::zeros((batch_size, out_height, out_width, channels));
278
279 for b in 0..batch_size {
281 for c in 0..channels {
282 for out_h in 0..out_height {
283 for out_w in 0..out_width {
284 let mut max_val = f64::NEG_INFINITY;
285
286 for k_h in 0..kernel_size.0 {
288 for k_w in 0..kernel_size.1 {
289 let in_h = (out_h * stride.0) as i32 + k_h as i32
291 - padding.0 as i32;
292 let in_w = (out_w * stride.1) as i32 + k_w as i32
293 - padding.1 as i32;
294
295 if in_h >= 0
297 && in_h < input_height as i32
298 && in_w >= 0
299 && in_w < input_width as i32
300 {
301 let val = input[[b, in_h as usize, in_w as usize, c]];
302 if val > max_val {
303 max_val = val;
304 }
305 }
306 }
307 }
308
309 output[[b, out_h, out_w, c]] = if max_val == f64::NEG_INFINITY {
311 0.0
312 } else {
313 max_val
314 };
315 }
316 }
317 }
318 }
319
320 return Ok(Box::new(NdarrayWrapper::new(output)));
321 }
322 return Err(OperationError::NotImplemented(
323 "max_pool2d not implemented for this array type".to_string(),
324 ));
325 }
326
327 let mut kwargs = HashMap::new();
329 kwargs.insert(
330 "kernel_size".to_string(),
331 Box::new(kernel_size) as Box<dyn Any>,
332 );
333 kwargs.insert("stride".to_string(), Box::new(stride) as Box<dyn Any>);
334 kwargs.insert("padding".to_string(), Box::new(padding) as Box<dyn Any>);
335
336 let array_ref = implementing_args[0].1;
337
338 let result = array_ref.array_function(
339 &crate::array_protocol::ArrayFunction::new(
340 "scirs2::array_protocol::ml_ops::max_pool2d",
341 ),
342 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
343 &[Box::new(input.box_clone())],
344 &kwargs,
345 )?;
346
347 match result.downcast::<Box<dyn ArrayProtocol>>() {
349 Ok(array) => Ok(*array),
350 Err(_) => Err(OperationError::Other(
351 "Failed to downcast array_function result".to_string(),
352 )),
353 }
354 },
355 "scirs2::array_protocol::ml, ops: max_pool2d"
356);
357
358array_function_dispatch!(
359 fn batch_norm(
360 input: &dyn ArrayProtocol,
361 scale: &dyn ArrayProtocol,
362 offset: &dyn ArrayProtocol,
363 mean: &dyn ArrayProtocol,
364 variance: &dyn ArrayProtocol,
365 epsilon: f64,
366 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
367 let boxed_args: Vec<Box<dyn Any>> = vec![
369 Box::new(input.box_clone()),
370 Box::new(scale.box_clone()),
371 Box::new(offset.box_clone()),
372 Box::new(mean.box_clone()),
373 Box::new(variance.box_clone()),
374 ];
375 let implementing_args = get_implementing_args(&boxed_args);
376 if implementing_args.is_empty() {
377 if let (
379 Some(inputarray),
380 Some(scale_array),
381 Some(offset_array),
382 Some(mean_array),
383 Some(variance_array),
384 ) = (
385 input.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
386 scale.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
387 offset.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
388 mean.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
389 variance
390 .as_any()
391 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
392 ) {
393 let input = inputarray.as_array();
394 let scale = scale_array.as_array();
395 let offset = offset_array.as_array();
396 let mean = mean_array.as_array();
397 let variance = variance_array.as_array();
398
399 let _batch_size = input.shape()[0];
401 let _height = input.shape()[1];
402 let _width = input.shape()[2];
403 let channels = input.shape()[3];
404
405 if scale.shape()[0] != channels
407 || offset.shape()[0] != channels
408 || mean.shape()[0] != channels
409 || variance.shape()[0] != channels
410 {
411 return Err(OperationError::ShapeMismatch(
412 "Scale, offset, mean, and variance must match the number of channels"
413 .to_string(),
414 ));
415 }
416
417 let mut output: Array<f64, Ix4> = Array::zeros(input.raw_dim());
419
420 let batch_size = input.shape()[0];
425 let _height = input.shape()[1];
426 let _width = input.shape()[2];
427
428 for b in 0..batch_size {
429 for h in 0.._height {
430 for w in 0.._width {
431 for c in 0..channels {
432 let x = input[[b, h, w, c]];
433 let m = mean[[c]];
434 let v = variance[[c]];
435 let s = scale[[c]];
436 let o = offset[[c]];
437
438 let normalized = (x - m) / (v + epsilon).sqrt();
440
441 let result = s * normalized + o;
443
444 output[[b, h, w, c]] = result;
445 }
446 }
447 }
448 }
449
450 return Ok(Box::new(NdarrayWrapper::new(output)));
451 }
452 return Err(OperationError::NotImplemented(
453 "batch_norm not implemented for these array types".to_string(),
454 ));
455 }
456
457 let mut kwargs = HashMap::new();
459 kwargs.insert("epsilon".to_string(), Box::new(epsilon) as Box<dyn Any>);
460
461 let array_ref = implementing_args[0].1;
462
463 let result = array_ref.array_function(
464 &crate::array_protocol::ArrayFunction::new(
465 "scirs2::array_protocol::ml_ops::batch_norm",
466 ),
467 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
468 &[
469 Box::new(input.box_clone()),
470 Box::new(scale.box_clone()),
471 Box::new(offset.box_clone()),
472 Box::new(mean.box_clone()),
473 Box::new(variance.box_clone()),
474 ],
475 &kwargs,
476 )?;
477
478 match result.downcast::<Box<dyn ArrayProtocol>>() {
480 Ok(array) => Ok(*array),
481 Err(_) => Err(OperationError::Other(
482 "Failed to downcast array_function result".to_string(),
483 )),
484 }
485 },
486 "scirs2::array_protocol::ml, ops: batch_norm"
487);
488
489array_function_dispatch!(
490 fn cross_entropy(
491 logits: &dyn ArrayProtocol,
492 labels: &dyn ArrayProtocol,
493 reduction: &str,
494 ) -> Result<Box<dyn Any>, OperationError> {
495 let boxed_args: Vec<Box<dyn Any>> =
497 vec![Box::new(logits.box_clone()), Box::new(labels.box_clone())];
498 let implementing_args = get_implementing_args(&boxed_args);
499 if implementing_args.is_empty() {
500 if let (Some(logits_array), Some(labels_array)) = (
502 logits.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
503 labels.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>(),
504 ) {
505 let logits = logits_array.as_array();
506 let labels = labels_array.as_array();
507
508 if logits.shape() != labels.shape() {
510 return Err(OperationError::ShapeMismatch(format!(
511 "Logits shape {logitsshape:?} doesn't match labels shape {labelsshape:?}",
512 logitsshape = logits.shape(),
513 labelsshape = labels.shape()
514 )));
515 }
516
517 let mut softmax = Array::zeros(logits.raw_dim());
519
520 for (i, sample) in logits.outer_iter().enumerate() {
522 let max_val = sample.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
523 let exp_x = sample.mapv(|v| (v - max_val).exp());
524 let sum_exp = exp_x.sum();
525
526 for (j, val) in exp_x.iter().enumerate() {
527 softmax[[i, j]] = val / sum_exp;
528 }
529 }
530
531 let mut sample_losses = Array::zeros(logits.shape()[0]);
534
535 for (i, (s, l)) in softmax.outer_iter().zip(labels.outer_iter()).enumerate() {
536 let mut loss = 0.0;
537 for (s_val, l_val) in s.iter().zip(l.iter()) {
538 loss -= l_val * (s_val + 1e-10).ln();
540 }
541 sample_losses[i] = loss;
542 }
543
544 let loss = match reduction {
546 "none" => sample_losses,
547 "mean" => {
548 let mean = sample_losses.sum() / sample_losses.len() as f64;
549 Array::from_elem(Ix1(1), mean)
551 }
552 "sum" => {
553 let sum = sample_losses.sum();
554 Array::from_elem(Ix1(1), sum)
556 }
557 _ => {
558 return Err(OperationError::ShapeMismatch(format!(
559 "Unknown reduction method: {reduction}"
560 )))
561 }
562 };
563
564 return Ok(Box::new(loss) as Box<dyn Any>);
565 }
566 return Err(OperationError::NotImplemented(
567 "cross_entropy not implemented for these array types".to_string(),
568 ));
569 }
570
571 let mut kwargs = HashMap::new();
573 kwargs.insert(
574 "reduction".to_string(),
575 Box::new(reduction.to_string()) as Box<dyn Any>,
576 );
577
578 let array_ref = implementing_args[0].1;
579
580 let result = array_ref.array_function(
581 &crate::array_protocol::ArrayFunction::new(
582 "scirs2::array_protocol::ml_ops::cross_entropy",
583 ),
584 &[TypeId::of::<Box<dyn Any>>()],
585 &[Box::new(logits.box_clone()), Box::new(labels.box_clone())],
586 &kwargs,
587 )?;
588
589 Ok(result)
590 },
591 "scirs2::array_protocol::ml, ops: cross_entropy"
592);
593
594array_function_dispatch!(
595 fn dropout(
596 input: &dyn ArrayProtocol,
597 rate: f64,
598 training: bool,
599 seed: Option<u64>,
600 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
601 let boxed_args: Vec<Box<dyn Any>> = vec![Box::new(input.box_clone())];
603 let implementing_args = get_implementing_args(&boxed_args);
604 if implementing_args.is_empty() {
605 if let Some(inputarray) = input.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
607 let input = inputarray.as_array();
608
609 if !training {
610 return Ok(Box::new(NdarrayWrapper::new(input.clone())));
612 }
613
614 let mut rng = match seed {
616 Some(s) => rand::rngs::StdRng::seed_from_u64(s),
617 None => {
618 let mut rng = rand::rng();
619 let random_seed: u64 = rng.random();
621 rand::rngs::StdRng::seed_from_u64(random_seed)
622 }
623 };
624
625 let mask = Array::from_shape_fn(input.raw_dim(), |_| {
626 if rng.random::<f64>() >= rate {
627 1.0
628 } else {
629 0.0
630 }
631 });
632
633 let scale = 1.0 / (1.0 - rate);
635 let result = input.clone() * &mask * scale;
636
637 return Ok(Box::new(NdarrayWrapper::new(result)));
638 }
639 return Err(OperationError::NotImplemented(
640 "dropout not implemented for this array type".to_string(),
641 ));
642 }
643
644 let mut kwargs = HashMap::new();
646 kwargs.insert("rate".to_string(), Box::new(rate) as Box<dyn Any>);
647 kwargs.insert("training".to_string(), Box::new(training) as Box<dyn Any>);
648 if let Some(s) = seed {
649 kwargs.insert("seed".to_string(), Box::new(s) as Box<dyn Any>);
650 }
651
652 let array_ref = implementing_args[0].1;
653
654 let result = array_ref.array_function(
655 &crate::array_protocol::ArrayFunction::new("scirs2::array_protocol::ml_ops::dropout"),
656 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
657 &[Box::new(input.box_clone())],
658 &kwargs,
659 )?;
660
661 match result.downcast::<Box<dyn ArrayProtocol>>() {
663 Ok(array) => Ok(*array),
664 Err(_) => Err(OperationError::Other(
665 "Failed to downcast array_function result".to_string(),
666 )),
667 }
668 },
669 "scirs2::array_protocol::ml, ops: dropout"
670);
671
672array_function_dispatch!(
673 fn self_attention(
674 queries: &dyn ArrayProtocol,
675 keys: &dyn ArrayProtocol,
676 values: &dyn ArrayProtocol,
677 mask: Option<&dyn ArrayProtocol>,
678 scale: Option<f64>,
679 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
680 let mut boxed_args: Vec<Box<dyn Any>> = vec![
682 Box::new(queries.box_clone()),
683 Box::new(keys.box_clone()),
684 Box::new(values.box_clone()),
685 ];
686 if let Some(m) = mask {
687 boxed_args.push(Box::new(m.box_clone()));
688 }
689
690 let implementing_args = get_implementing_args(&boxed_args);
691 if implementing_args.is_empty() {
692 if let (Some(q_array), Some(k_array), Some(v_array)) = (
694 queries.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
695 keys.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
696 values.as_any().downcast_ref::<NdarrayWrapper<f64, Ix4>>(),
697 ) {
698 let q = q_array.as_array();
699 let k = k_array.as_array();
700 let v = v_array.as_array();
701
702 let batch_size = q.shape()[0];
705 let q_len = q.shape()[1];
706 let num_heads = q.shape()[2];
707 let d_k = q.shape()[3];
708
709 let k_len = k.shape()[1];
710
711 if k.shape()[0] != batch_size
713 || k.shape()[2] != num_heads
714 || k.shape()[3] != d_k
715 || v.shape()[0] != batch_size
716 || v.shape()[1] != k_len
717 || v.shape()[2] != num_heads
718 {
719 return Err(OperationError::ShapeMismatch(
720 "Incompatible shapes for self-attention".to_string(),
721 ));
722 }
723
724 let _scale_factor = scale.unwrap_or_else(|| {
726 let d_k_f64 = d_k as f64;
728 if d_k_f64 > 0.0 {
729 d_k_f64.sqrt()
730 } else {
731 1.0 }
733 });
734
735 let scale_factor = scale.unwrap_or(1.0 / (d_k as f64).sqrt());
742 let mut output: Array<f64, Ix3> = Array::zeros((batch_size, q_len, d_k));
743
744 for b in 0..batch_size {
745 let q_batch = q.slice(crate::s![b, .., .., ..]);
747 let k_batch = k.slice(crate::s![b, .., .., ..]);
748 let v_batch = v.slice(crate::s![b, .., .., ..]);
749
750 let mut head_outputs = Array::zeros((q_len, num_heads, d_k));
752
753 for h in 0..num_heads {
754 let mut scores = Array::zeros((q_len, k_len));
755 for i in 0..q_len {
756 for j in 0..k_len {
757 let mut dot_product = 0.0;
758 for k in 0..d_k {
759 dot_product += q_batch[[i, h, k]] * k_batch[[j, h, k]];
760 }
761 scores[[i, j]] = dot_product * scale_factor;
762 }
763 }
764
765 if let Some(mask_array) = mask {
767 if let Some(mask_wrapper) = mask_array
768 .as_any()
769 .downcast_ref::<NdarrayWrapper<f64, Ix3>>()
770 {
771 let mask_batch =
772 mask_wrapper.as_array().slice(crate::s![b, .., ..]);
773 for i in 0..q_len {
774 for j in 0..k_len {
775 if mask_batch[[i, j]] == 0.0 {
776 scores[[i, j]] = f64::NEG_INFINITY;
777 }
778 }
779 }
780 }
781 }
782
783 let mut attention = Array::zeros((q_len, k_len));
785 for i in 0..q_len {
786 let mut max_score = f64::NEG_INFINITY;
788 for j in 0..k_len {
789 if scores[[i, j]] > max_score {
790 max_score = scores[[i, j]];
791 }
792 }
793
794 let mut exp_sum = 0.0;
796 for j in 0..k_len {
797 let exp_val = (scores[[i, j]] - max_score).exp();
798 attention[[i, j]] = exp_val;
799 exp_sum += exp_val;
800 }
801
802 for j in 0..k_len {
804 attention[[i, j]] /= exp_sum;
805 }
806 }
807
808 for i in 0..q_len {
810 for k in 0..d_k {
811 let mut weighted_sum = 0.0;
812 for j in 0..k_len {
813 weighted_sum += attention[[i, j]] * v_batch[[j, h, k]];
814 }
815 head_outputs[[i, h, k]] = weighted_sum;
816 }
817 }
818 }
819
820 for i in 0..q_len {
822 for k in 0..d_k {
823 let mut sum = 0.0;
824 for h in 0..num_heads {
825 sum += head_outputs[[i, h, k]];
826 }
827 output[[b, i, k]] = sum / num_heads as f64;
828 }
829 }
830 }
831
832 return Ok(Box::new(NdarrayWrapper::new(output)));
833 }
834 return Err(OperationError::NotImplemented(
835 "self_attention not implemented for these array types".to_string(),
836 ));
837 }
838
839 let mut kwargs = HashMap::new();
841 if let Some(s) = scale {
842 kwargs.insert("scale".to_string(), Box::new(s) as Box<dyn Any>);
843 }
844 if let Some(m) = mask {
845 kwargs.insert("mask".to_string(), Box::new(m.box_clone()) as Box<dyn Any>);
846 }
847
848 let array_ref = implementing_args[0].1;
849
850 let result = array_ref.array_function(
851 &crate::array_protocol::ArrayFunction::new(
852 "scirs2::array_protocol::ml_ops::self_attention",
853 ),
854 &[TypeId::of::<Box<dyn ArrayProtocol>>()],
855 &[
856 Box::new(queries.box_clone()),
857 Box::new(keys.box_clone()),
858 Box::new(values.box_clone()),
859 ],
860 &kwargs,
861 )?;
862
863 match result.downcast::<Box<dyn ArrayProtocol>>() {
865 Ok(array) => Ok(*array),
866 Err(_) => Err(OperationError::Other(
867 "Failed to downcast array_function result".to_string(),
868 )),
869 }
870 },
871 "scirs2::array_protocol::ml, ops: self_attention"
872);