//! Convolution, cross-correlation, and filtering operations for tensors.

use crate::{FloatElement, Tensor};
use torsh_core::error::{Result, TorshError};
use torsh_core::TensorElement;

impl<T: FloatElement> Tensor<T> {
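    /// 1-D convolution over a `[batch, in_channels, length]` input with a
    /// `[out_channels, in_channels / groups, kernel_size]` weight.
    ///
    /// The output length is
    /// `(length + 2 * padding - (kernel_size - 1) * dilation - 1) / stride + 1`.
    ///
    /// Illustrative shape sketch; `input` and `weight` are assumed, pre-constructed
    /// tensors (construction is elided because it is crate-specific):
    ///
    /// ```ignore
    /// // input: [1, 3, 8], weight: [6, 3, 3], stride 1, padding 1, dilation 1, groups 1
    /// let out = input.conv1d(&weight, None, 1, 1, 1, 1)?;
    /// assert_eq!(out.shape().dims(), [1, 6, 8]);
    /// ```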
    pub fn conv1d(
        &self,
        weight: &Self,
        bias: Option<&Self>,
        stride: usize,
        padding: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let input_shape_obj = self.shape();
        let input_shape = input_shape_obj.dims();
        let weight_shape_obj = weight.shape();
        let weight_shape = weight_shape_obj.dims();

        if input_shape.len() != 3 {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 3D input tensor for conv1d, got {}D",
                input_shape.len()
            )));
        }

        if weight_shape.len() != 3 {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 3D weight tensor for conv1d, got {}D",
                weight_shape.len()
            )));
        }

        let batch_size = input_shape[0];
        let in_channels = input_shape[1];
        let input_length = input_shape[2];

        let out_channels = weight_shape[0];
        let kernel_size = weight_shape[2];

        if in_channels % groups != 0 || out_channels % groups != 0 {
            return Err(TorshError::InvalidArgument(
                "in_channels and out_channels must be divisible by groups".to_string(),
            ));
        }

        if weight_shape[1] != in_channels / groups {
            return Err(TorshError::InvalidArgument(format!(
                "Weight tensor has wrong number of input channels: expected {}, got {}",
                in_channels / groups,
                weight_shape[1]
            )));
        }

        // Output length for a dilated, strided, zero-padded convolution.
        let effective_kernel = (kernel_size - 1) * dilation + 1;
        let padded_length = input_length + 2 * padding;
        let output_length = (padded_length - effective_kernel) / stride + 1;

        let mut output_data =
            vec![<T as TensorElement>::zero(); batch_size * out_channels * output_length];

        // Direct (naive) convolution over batch, groups, output channels, and output positions.
        for n in 0..batch_size {
            for g in 0..groups {
                let out_ch_start = g * (out_channels / groups);
                let out_ch_end = (g + 1) * (out_channels / groups);
                let in_ch_start = g * (in_channels / groups);
                let in_ch_end = (g + 1) * (in_channels / groups);

                for oc in out_ch_start..out_ch_end {
                    for ol in 0..output_length {
                        let mut sum = <T as TensorElement>::zero();

                        for ic in in_ch_start..in_ch_end {
                            let ic_rel = ic - in_ch_start;
                            for k in 0..kernel_size {
                                let il = (ol * stride + k * dilation) as i32 - padding as i32;

                                // Positions that fall into the zero padding contribute nothing.
                                if il >= 0 && (il as usize) < input_length {
                                    let input_idx = n * in_channels * input_length
                                        + ic * input_length
                                        + il as usize;
                                    let weight_idx = oc * (in_channels / groups) * kernel_size
                                        + ic_rel * kernel_size
                                        + k;

                                    let input_val = self.storage.get(input_idx)?;
                                    let weight_val = weight.storage.get(weight_idx)?;
                                    sum = sum + input_val * weight_val;
                                }
                            }
                        }

                        let output_idx =
                            n * out_channels * output_length + oc * output_length + ol;
                        output_data[output_idx] = sum;
                    }
                }
            }
        }

        let mut output = Tensor::from_data(
            output_data,
            vec![batch_size, out_channels, output_length],
            self.device(),
        )?;

        if let Some(b) = bias {
            if b.shape().dims() != [out_channels] {
                return Err(TorshError::InvalidArgument(format!(
                    "Bias must have shape [{}], got {:?}",
                    out_channels,
                    b.shape().dims()
                )));
            }

            let bias_data = b.to_vec()?;
            let mut output_data = output.to_vec()?;

            for n in 0..batch_size {
                #[allow(clippy::needless_range_loop)]
                for oc in 0..out_channels {
                    for ol in 0..output_length {
                        let idx = n * out_channels * output_length + oc * output_length + ol;
                        output_data[idx] = output_data[idx] + bias_data[oc];
                    }
                }
            }

            output = Tensor::from_data(
                output_data,
                vec![batch_size, out_channels, output_length],
                self.device(),
            )?;
        }

        if self.requires_grad
            || weight.requires_grad
            || bias.is_some_and(|b| b.requires_grad)
        {
            use std::sync::Arc;
            output.requires_grad = true;
            output.operation = crate::Operation::Custom(
                "conv1d".to_string(),
                vec![
                    Arc::downgrade(&Arc::new(self.clone())),
                    Arc::downgrade(&Arc::new(weight.clone())),
                ],
            );
        }

        Ok(output)
    }

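    /// 2-D convolution over a `[batch, in_channels, height, width]` input with a
    /// `[out_channels, in_channels / groups, kernel_h, kernel_w]` weight.
    ///
    /// Each spatial dimension follows the same size rule as `conv1d`:
    /// `out = (in + 2 * padding - (kernel - 1) * dilation - 1) / stride + 1`.
    ///
    /// Illustrative shape sketch (tensor construction elided, crate-specific):
    ///
    /// ```ignore
    /// // input: [1, 3, 32, 32], weight: [8, 3, 3, 3]
    /// let out = input.conv2d(&weight, None, (1, 1), (1, 1), (1, 1), 1)?;
    /// assert_eq!(out.shape().dims(), [1, 8, 32, 32]);
    /// ```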
    pub fn conv2d(
        &self,
        weight: &Self,
        bias: Option<&Self>,
        stride: (usize, usize),
        padding: (usize, usize),
        dilation: (usize, usize),
        groups: usize,
    ) -> Result<Self> {
        let input_shape_obj = self.shape();
        let input_shape = input_shape_obj.dims();
        let weight_shape_obj = weight.shape();
        let weight_shape = weight_shape_obj.dims();

        if input_shape.len() != 4 {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 4D input tensor for conv2d, got {}D",
                input_shape.len()
            )));
        }

        if weight_shape.len() != 4 {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 4D weight tensor for conv2d, got {}D",
                weight_shape.len()
            )));
        }

        let batch_size = input_shape[0];
        let in_channels = input_shape[1];
        let input_height = input_shape[2];
        let input_width = input_shape[3];

        let out_channels = weight_shape[0];
        let kernel_height = weight_shape[2];
        let kernel_width = weight_shape[3];

        if in_channels % groups != 0 || out_channels % groups != 0 {
            return Err(TorshError::InvalidArgument(
                "in_channels and out_channels must be divisible by groups".to_string(),
            ));
        }

        if weight_shape[1] != in_channels / groups {
            return Err(TorshError::InvalidArgument(format!(
                "Weight tensor has wrong number of input channels: expected {}, got {}",
                in_channels / groups,
                weight_shape[1]
            )));
        }

        // Output size per spatial dimension for a dilated, strided, zero-padded convolution.
        let effective_kernel_h = (kernel_height - 1) * dilation.0 + 1;
        let effective_kernel_w = (kernel_width - 1) * dilation.1 + 1;
        let padded_height = input_height + 2 * padding.0;
        let padded_width = input_width + 2 * padding.1;
        let output_height = (padded_height - effective_kernel_h) / stride.0 + 1;
        let output_width = (padded_width - effective_kernel_w) / stride.1 + 1;

        let mut output_data = vec![
            <T as TensorElement>::zero();
            batch_size * out_channels * output_height * output_width
        ];

        let self_data = self.to_vec()?;
        let weight_data = weight.to_vec()?;

        // Direct (naive) convolution over batch, groups, output channels, and output positions.
        for n in 0..batch_size {
            for g in 0..groups {
                let out_ch_start = g * (out_channels / groups);
                let out_ch_end = (g + 1) * (out_channels / groups);
                let in_ch_start = g * (in_channels / groups);
                let in_ch_end = (g + 1) * (in_channels / groups);

                for oc in out_ch_start..out_ch_end {
                    for oh in 0..output_height {
                        for ow in 0..output_width {
                            let mut sum = <T as TensorElement>::zero();

                            for ic in in_ch_start..in_ch_end {
                                let ic_rel = ic - in_ch_start;
                                for kh in 0..kernel_height {
                                    for kw in 0..kernel_width {
                                        let ih = (oh * stride.0 + kh * dilation.0) as i32
                                            - padding.0 as i32;
                                        let iw = (ow * stride.1 + kw * dilation.1) as i32
                                            - padding.1 as i32;

                                        // Positions in the zero padding contribute nothing.
                                        if ih >= 0
                                            && (ih as usize) < input_height
                                            && iw >= 0
                                            && (iw as usize) < input_width
                                        {
                                            let input_idx =
                                                n * in_channels * input_height * input_width
                                                    + ic * input_height * input_width
                                                    + ih as usize * input_width
                                                    + iw as usize;
                                            let weight_idx = oc
                                                * (in_channels / groups)
                                                * kernel_height
                                                * kernel_width
                                                + ic_rel * kernel_height * kernel_width
                                                + kh * kernel_width
                                                + kw;

                                            sum = sum
                                                + self_data[input_idx] * weight_data[weight_idx];
                                        }
                                    }
                                }
                            }

                            let output_idx = n * out_channels * output_height * output_width
                                + oc * output_height * output_width
                                + oh * output_width
                                + ow;
                            output_data[output_idx] = sum;
                        }
                    }
                }
            }
        }

        let mut output = Tensor::from_data(
            output_data,
            vec![batch_size, out_channels, output_height, output_width],
            self.device(),
        )?;

        if let Some(b) = bias {
            if b.shape().dims() != [out_channels] {
                return Err(TorshError::InvalidArgument(format!(
                    "Bias must have shape [{}], got {:?}",
                    out_channels,
                    b.shape().dims()
                )));
            }

            let bias_data = b.to_vec()?;
            let mut output_data = output.to_vec()?;

            for n in 0..batch_size {
                #[allow(clippy::needless_range_loop)]
                for oc in 0..out_channels {
                    for oh in 0..output_height {
                        for ow in 0..output_width {
                            let idx = n * out_channels * output_height * output_width
                                + oc * output_height * output_width
                                + oh * output_width
                                + ow;
                            output_data[idx] = output_data[idx] + bias_data[oc];
                        }
                    }
                }
            }

            output = Tensor::from_data(
                output_data,
                vec![batch_size, out_channels, output_height, output_width],
                self.device(),
            )?;
        }

        if self.requires_grad
            || weight.requires_grad
            || bias.is_some_and(|b| b.requires_grad)
        {
            use std::sync::Arc;
            output.requires_grad = true;
            output.operation = crate::Operation::Custom(
                "conv2d".to_string(),
                vec![
                    Arc::downgrade(&Arc::new(self.clone())),
                    Arc::downgrade(&Arc::new(weight.clone())),
                ],
            );
        }

        Ok(output)
    }

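    /// 3-D convolution over a `[batch, in_channels, depth, height, width]` input with a
    /// `[out_channels, in_channels / groups, kernel_d, kernel_h, kernel_w]` weight.
    /// Each of the three spatial dimensions follows the same size rule as `conv1d`.
    ///
    /// Illustrative shape sketch (tensor construction elided, crate-specific):
    ///
    /// ```ignore
    /// // input: [1, 2, 8, 16, 16], weight: [4, 2, 3, 3, 3]
    /// let out = input.conv3d(&weight, None, (1, 1, 1), (1, 1, 1), (1, 1, 1), 1)?;
    /// assert_eq!(out.shape().dims(), [1, 4, 8, 16, 16]);
    /// ```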
    pub fn conv3d(
        &self,
        weight: &Self,
        bias: Option<&Self>,
        stride: (usize, usize, usize),
        padding: (usize, usize, usize),
        dilation: (usize, usize, usize),
        groups: usize,
    ) -> Result<Self> {
        let input_shape_obj = self.shape();
        let input_shape = input_shape_obj.dims();
        let weight_shape_obj = weight.shape();
        let weight_shape = weight_shape_obj.dims();

        if input_shape.len() != 5 {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 5D input tensor for conv3d, got {}D",
                input_shape.len()
            )));
        }

        if weight_shape.len() != 5 {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 5D weight tensor for conv3d, got {}D",
                weight_shape.len()
            )));
        }

        let batch_size = input_shape[0];
        let in_channels = input_shape[1];
        let input_depth = input_shape[2];
        let input_height = input_shape[3];
        let input_width = input_shape[4];

        let out_channels = weight_shape[0];
        let kernel_depth = weight_shape[2];
        let kernel_height = weight_shape[3];
        let kernel_width = weight_shape[4];

        if in_channels % groups != 0 || out_channels % groups != 0 {
            return Err(TorshError::InvalidArgument(
                "in_channels and out_channels must be divisible by groups".to_string(),
            ));
        }

        if weight_shape[1] != in_channels / groups {
            return Err(TorshError::InvalidArgument(format!(
                "Weight tensor has wrong number of input channels: expected {}, got {}",
                in_channels / groups,
                weight_shape[1]
            )));
        }

        // Output size per spatial dimension for a dilated, strided, zero-padded convolution.
        let effective_kernel_d = (kernel_depth - 1) * dilation.0 + 1;
        let effective_kernel_h = (kernel_height - 1) * dilation.1 + 1;
        let effective_kernel_w = (kernel_width - 1) * dilation.2 + 1;
        let padded_depth = input_depth + 2 * padding.0;
        let padded_height = input_height + 2 * padding.1;
        let padded_width = input_width + 2 * padding.2;
        let output_depth = (padded_depth - effective_kernel_d) / stride.0 + 1;
        let output_height = (padded_height - effective_kernel_h) / stride.1 + 1;
        let output_width = (padded_width - effective_kernel_w) / stride.2 + 1;

        let output_size = batch_size * out_channels * output_depth * output_height * output_width;
        let mut output_data = vec![<T as TensorElement>::zero(); output_size];

        let self_data = self.to_vec()?;
        let weight_data = weight.to_vec()?;

        // Direct (naive) convolution over batch, groups, output channels, and output positions.
        for n in 0..batch_size {
            for g in 0..groups {
                let out_ch_start = g * (out_channels / groups);
                let out_ch_end = (g + 1) * (out_channels / groups);
                let in_ch_start = g * (in_channels / groups);
                let in_ch_end = (g + 1) * (in_channels / groups);

                for oc in out_ch_start..out_ch_end {
                    for od in 0..output_depth {
                        for oh in 0..output_height {
                            for ow in 0..output_width {
                                let mut sum = <T as TensorElement>::zero();

                                for ic in in_ch_start..in_ch_end {
                                    let ic_rel = ic - in_ch_start;
                                    for kd in 0..kernel_depth {
                                        for kh in 0..kernel_height {
                                            for kw in 0..kernel_width {
                                                let id = (od * stride.0 + kd * dilation.0) as i32
                                                    - padding.0 as i32;
                                                let ih = (oh * stride.1 + kh * dilation.1) as i32
                                                    - padding.1 as i32;
                                                let iw = (ow * stride.2 + kw * dilation.2) as i32
                                                    - padding.2 as i32;

                                                if id >= 0
                                                    && (id as usize) < input_depth
                                                    && ih >= 0
                                                    && (ih as usize) < input_height
                                                    && iw >= 0
                                                    && (iw as usize) < input_width
                                                {
                                                    let input_idx = n
                                                        * in_channels
                                                        * input_depth
                                                        * input_height
                                                        * input_width
                                                        + ic * input_depth
                                                            * input_height
                                                            * input_width
                                                        + id as usize * input_height * input_width
                                                        + ih as usize * input_width
                                                        + iw as usize;
                                                    let weight_idx = oc
                                                        * (in_channels / groups)
                                                        * kernel_depth
                                                        * kernel_height
                                                        * kernel_width
                                                        + ic_rel
                                                            * kernel_depth
                                                            * kernel_height
                                                            * kernel_width
                                                        + kd * kernel_height * kernel_width
                                                        + kh * kernel_width
                                                        + kw;

                                                    sum = sum
                                                        + self_data[input_idx]
                                                            * weight_data[weight_idx];
                                                }
                                            }
                                        }
                                    }
                                }

                                let output_idx = n
                                    * out_channels
                                    * output_depth
                                    * output_height
                                    * output_width
                                    + oc * output_depth * output_height * output_width
                                    + od * output_height * output_width
                                    + oh * output_width
                                    + ow;
                                output_data[output_idx] = sum;
                            }
                        }
                    }
                }
            }
        }

        let mut output = Tensor::from_data(
            output_data,
            vec![
                batch_size,
                out_channels,
                output_depth,
                output_height,
                output_width,
            ],
            self.device(),
        )?;

        if let Some(b) = bias {
            if b.shape().dims() != [out_channels] {
                return Err(TorshError::InvalidArgument(format!(
                    "Bias must have shape [{}], got {:?}",
                    out_channels,
                    b.shape().dims()
                )));
            }

            let bias_data = b.to_vec()?;
            let mut output_data = output.to_vec()?;

            for n in 0..batch_size {
                #[allow(clippy::needless_range_loop)]
                for oc in 0..out_channels {
                    for od in 0..output_depth {
                        for oh in 0..output_height {
                            for ow in 0..output_width {
                                let idx = n
                                    * out_channels
                                    * output_depth
                                    * output_height
                                    * output_width
                                    + oc * output_depth * output_height * output_width
                                    + od * output_height * output_width
                                    + oh * output_width
                                    + ow;
                                output_data[idx] = output_data[idx] + bias_data[oc];
                            }
                        }
                    }
                }
            }

            output = Tensor::from_data(
                output_data,
                vec![
                    batch_size,
                    out_channels,
                    output_depth,
                    output_height,
                    output_width,
                ],
                self.device(),
            )?;
        }

        if self.requires_grad
            || weight.requires_grad
            || bias.is_some_and(|b| b.requires_grad)
        {
            use std::sync::Arc;
            output.requires_grad = true;
            output.operation = crate::Operation::Custom(
                "conv3d".to_string(),
                vec![
                    Arc::downgrade(&Arc::new(self.clone())),
                    Arc::downgrade(&Arc::new(weight.clone())),
                ],
            );
        }

        Ok(output)
    }

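    /// Depthwise 2-D convolution: each input channel is filtered independently with
    /// its own kernel. The weight must have shape `[in_channels, 1, kernel_h, kernel_w]`
    /// and the output keeps the input's channel count.
    ///
    /// Illustrative shape sketch (tensor construction elided, crate-specific):
    ///
    /// ```ignore
    /// // input: [1, 3, 16, 16], weight: [3, 1, 3, 3]
    /// let out = input.depthwise_conv2d(&weight, None, (1, 1), (1, 1), (1, 1))?;
    /// assert_eq!(out.shape().dims(), [1, 3, 16, 16]);
    /// ```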
    pub fn depthwise_conv2d(
        &self,
        weight: &Self,
        bias: Option<&Self>,
        stride: (usize, usize),
        padding: (usize, usize),
        dilation: (usize, usize),
    ) -> Result<Self> {
        let input_shape_obj = self.shape();
        let input_shape = input_shape_obj.dims();
        let weight_shape_obj = weight.shape();
        let weight_shape = weight_shape_obj.dims();

        if input_shape.len() != 4 {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 4D input tensor for depthwise_conv2d, got {}D",
                input_shape.len()
            )));
        }

        if weight_shape.len() != 4 {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 4D weight tensor for depthwise_conv2d, got {}D",
                weight_shape.len()
            )));
        }

        let batch_size = input_shape[0];
        let in_channels = input_shape[1];
        let input_height = input_shape[2];
        let input_width = input_shape[3];

        let kernel_height = weight_shape[2];
        let kernel_width = weight_shape[3];

        // Depthwise convolution uses one single-channel kernel per input channel.
        if weight_shape[0] != in_channels || weight_shape[1] != 1 {
            return Err(TorshError::InvalidArgument(format!(
                "Weight tensor must have shape ({}, 1, kernel_h, kernel_w), got ({}, {}, {}, {})",
                in_channels, weight_shape[0], weight_shape[1], weight_shape[2], weight_shape[3]
            )));
        }

        let effective_kernel_h = (kernel_height - 1) * dilation.0 + 1;
        let effective_kernel_w = (kernel_width - 1) * dilation.1 + 1;
        let padded_height = input_height + 2 * padding.0;
        let padded_width = input_width + 2 * padding.1;
        let output_height = (padded_height - effective_kernel_h) / stride.0 + 1;
        let output_width = (padded_width - effective_kernel_w) / stride.1 + 1;

        let mut output_data = vec![
            <T as TensorElement>::zero();
            batch_size * in_channels * output_height * output_width
        ];

        // Each channel is filtered independently with its own kernel.
        for n in 0..batch_size {
            for c in 0..in_channels {
                for oh in 0..output_height {
                    for ow in 0..output_width {
                        let mut sum = <T as TensorElement>::zero();

                        for kh in 0..kernel_height {
                            for kw in 0..kernel_width {
                                let ih =
                                    (oh * stride.0 + kh * dilation.0) as i32 - padding.0 as i32;
                                let iw =
                                    (ow * stride.1 + kw * dilation.1) as i32 - padding.1 as i32;

                                if ih >= 0
                                    && (ih as usize) < input_height
                                    && iw >= 0
                                    && (iw as usize) < input_width
                                {
                                    let input_idx = n * in_channels * input_height * input_width
                                        + c * input_height * input_width
                                        + ih as usize * input_width
                                        + iw as usize;
                                    let weight_idx =
                                        c * kernel_height * kernel_width + kh * kernel_width + kw;

                                    let input_val = self.storage.get(input_idx)?;
                                    let weight_val = weight.storage.get(weight_idx)?;
                                    sum = sum + input_val * weight_val;
                                }
                            }
                        }

                        let output_idx = n * in_channels * output_height * output_width
                            + c * output_height * output_width
                            + oh * output_width
                            + ow;
                        output_data[output_idx] = sum;
                    }
                }
            }
        }

        let mut output = Tensor::from_data(
            output_data,
            vec![batch_size, in_channels, output_height, output_width],
            self.device(),
        )?;

        if let Some(b) = bias {
            if b.shape().dims() != [in_channels] {
                return Err(TorshError::InvalidArgument(format!(
                    "Bias must have shape [{}], got {:?}",
                    in_channels,
                    b.shape().dims()
                )));
            }

            let bias_data = b.to_vec()?;
            let mut output_data = output.to_vec()?;

            for n in 0..batch_size {
                #[allow(clippy::needless_range_loop)]
                for c in 0..in_channels {
                    for oh in 0..output_height {
                        for ow in 0..output_width {
                            let idx = n * in_channels * output_height * output_width
                                + c * output_height * output_width
                                + oh * output_width
                                + ow;
                            output_data[idx] = output_data[idx] + bias_data[c];
                        }
                    }
                }
            }

            output = Tensor::from_data(
                output_data,
                vec![batch_size, in_channels, output_height, output_width],
                self.device(),
            )?;
        }

        if self.requires_grad
            || weight.requires_grad
            || bias.is_some_and(|b| b.requires_grad)
        {
            use std::sync::Arc;
            output.requires_grad = true;
            output.operation = crate::Operation::Custom(
                "depthwise_conv2d".to_string(),
                vec![
                    Arc::downgrade(&Arc::new(self.clone())),
                    Arc::downgrade(&Arc::new(weight.clone())),
                ],
            );
        }

        Ok(output)
    }

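    /// Depthwise-separable 2-D convolution: a depthwise pass with `depthwise_weight`
    /// (`[in_channels, 1, kernel_h, kernel_w]`) followed by a pointwise 1x1 convolution
    /// with `pointwise_weight` (intended shape `[out_channels, in_channels, 1, 1]`)
    /// that mixes channels and applies the optional bias.
    ///
    /// Illustrative shape sketch (tensor construction elided, crate-specific):
    ///
    /// ```ignore
    /// // input: [1, 3, 16, 16], dw: [3, 1, 3, 3], pw: [8, 3, 1, 1]
    /// let out = input.separable_conv2d(&dw, &pw, None, (1, 1), (1, 1), (1, 1))?;
    /// assert_eq!(out.shape().dims(), [1, 8, 16, 16]);
    /// ```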
    pub fn separable_conv2d(
        &self,
        depthwise_weight: &Self,
        pointwise_weight: &Self,
        bias: Option<&Self>,
        stride: (usize, usize),
        padding: (usize, usize),
        dilation: (usize, usize),
    ) -> Result<Self> {
        // Depthwise stage: spatial filtering of each channel on its own; the bias
        // is deferred to the pointwise stage.
        let depthwise_output =
            self.depthwise_conv2d(depthwise_weight, None, stride, padding, dilation)?;

        // Pointwise stage: a 1x1 convolution that mixes channels and applies the bias.
        let output = depthwise_output.conv2d(
            pointwise_weight,
            bias,
            (1, 1), // stride
            (0, 0), // padding
            (1, 1), // dilation
            1,      // groups
        )?;

        if self.requires_grad
            || depthwise_weight.requires_grad
            || pointwise_weight.requires_grad
            || bias.is_some_and(|b| b.requires_grad)
        {
            use std::sync::Arc;
            let mut tracked_output = output;
            tracked_output.requires_grad = true;
            tracked_output.operation = crate::Operation::Custom(
                "separable_conv2d".to_string(),
                vec![
                    Arc::downgrade(&Arc::new(self.clone())),
                    Arc::downgrade(&Arc::new(depthwise_weight.clone())),
                    Arc::downgrade(&Arc::new(pointwise_weight.clone())),
                ],
            );
            Ok(tracked_output)
        } else {
            Ok(output)
        }
    }

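    /// Transposed (fractionally strided) 2-D convolution. The weight has shape
    /// `[in_channels, out_channels / groups, kernel_h, kernel_w]` and each spatial
    /// dimension grows according to
    /// `out = (in - 1) * stride - 2 * padding + (kernel - 1) * dilation + 1 + output_padding`.
    ///
    /// Illustrative shape sketch (tensor construction elided, crate-specific):
    ///
    /// ```ignore
    /// // input: [1, 4, 8, 8], weight: [4, 2, 2, 2], stride (2, 2)
    /// let out = input.conv_transpose2d(&weight, None, (2, 2), (0, 0), (0, 0), (1, 1), 1)?;
    /// assert_eq!(out.shape().dims(), [1, 2, 16, 16]);
    /// ```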
    #[allow(clippy::too_many_arguments)]
    pub fn conv_transpose2d(
        &self,
        weight: &Self,
        bias: Option<&Self>,
        stride: (usize, usize),
        padding: (usize, usize),
        output_padding: (usize, usize),
        dilation: (usize, usize),
        groups: usize,
    ) -> Result<Self> {
        let input_shape_obj = self.shape();
        let input_shape = input_shape_obj.dims();
        let weight_shape_obj = weight.shape();
        let weight_shape = weight_shape_obj.dims();

        if input_shape.len() != 4 {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 4D input tensor for conv_transpose2d, got {}D",
                input_shape.len()
            )));
        }

        if weight_shape.len() != 4 {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 4D weight tensor for conv_transpose2d, got {}D",
                weight_shape.len()
            )));
        }

        let batch_size = input_shape[0];
        let in_channels = input_shape[1];
        let input_height = input_shape[2];
        let input_width = input_shape[3];

        // Transposed convolution weights are laid out as
        // [in_channels, out_channels / groups, kernel_h, kernel_w].
        let out_channels = weight_shape[1] * groups;
        let kernel_height = weight_shape[2];
        let kernel_width = weight_shape[3];

        if in_channels % groups != 0 || out_channels % groups != 0 {
            return Err(TorshError::InvalidArgument(
                "in_channels and out_channels must be divisible by groups".to_string(),
            ));
        }

        if weight_shape[0] != in_channels {
            return Err(TorshError::InvalidArgument(format!(
                "Weight tensor has wrong number of input channels: expected {}, got {}",
                in_channels, weight_shape[0]
            )));
        }

        // Output size per spatial dimension for a transposed convolution.
        let effective_kernel_h = (kernel_height - 1) * dilation.0 + 1;
        let effective_kernel_w = (kernel_width - 1) * dilation.1 + 1;
        let output_height =
            (input_height - 1) * stride.0 - 2 * padding.0 + effective_kernel_h + output_padding.0;
        let output_width =
            (input_width - 1) * stride.1 - 2 * padding.1 + effective_kernel_w + output_padding.1;

        let mut output_data = vec![
            <T as TensorElement>::zero();
            batch_size * out_channels * output_height * output_width
        ];

        let self_data = self.to_vec()?;
        let weight_data = weight.to_vec()?;

        // Scatter each input element into every output position it influences.
        for n in 0..batch_size {
            for g in 0..groups {
                let in_ch_start = g * (in_channels / groups);
                let in_ch_end = (g + 1) * (in_channels / groups);
                let out_ch_start = g * (out_channels / groups);
                let out_ch_end = (g + 1) * (out_channels / groups);

                for ic in in_ch_start..in_ch_end {
                    for ih in 0..input_height {
                        for iw in 0..input_width {
                            let input_val = self_data[n * in_channels * input_height * input_width
                                + ic * input_height * input_width
                                + ih * input_width
                                + iw];

                            for oc in out_ch_start..out_ch_end {
                                let oc_rel = oc - out_ch_start;
                                for kh in 0..kernel_height {
                                    for kw in 0..kernel_width {
                                        let oh = ih * stride.0 + kh * dilation.0;
                                        let ow = iw * stride.1 + kw * dilation.1;

                                        if oh >= padding.0 && ow >= padding.1 {
                                            let oh_final = oh - padding.0;
                                            let ow_final = ow - padding.1;

                                            if oh_final < output_height && ow_final < output_width {
                                                let weight_idx = ic
                                                    * (out_channels / groups)
                                                    * kernel_height
                                                    * kernel_width
                                                    + oc_rel * kernel_height * kernel_width
                                                    + kh * kernel_width
                                                    + kw;
                                                let output_idx =
                                                    n * out_channels * output_height * output_width
                                                        + oc * output_height * output_width
                                                        + oh_final * output_width
                                                        + ow_final;

                                                output_data[output_idx] = output_data[output_idx]
                                                    + input_val * weight_data[weight_idx];
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        let mut output = Tensor::from_data(
            output_data,
            vec![batch_size, out_channels, output_height, output_width],
            self.device(),
        )?;

        if let Some(b) = bias {
            if b.shape().dims() != [out_channels] {
                return Err(TorshError::InvalidArgument(format!(
                    "Bias must have shape [{}], got {:?}",
                    out_channels,
                    b.shape().dims()
                )));
            }

            let bias_data = b.to_vec()?;
            let mut output_data = output.to_vec()?;

            for n in 0..batch_size {
                #[allow(clippy::needless_range_loop)]
                for oc in 0..out_channels {
                    for oh in 0..output_height {
                        for ow in 0..output_width {
                            let idx = n * out_channels * output_height * output_width
                                + oc * output_height * output_width
                                + oh * output_width
                                + ow;
                            output_data[idx] = output_data[idx] + bias_data[oc];
                        }
                    }
                }
            }

            output = Tensor::from_data(
                output_data,
                vec![batch_size, out_channels, output_height, output_width],
                self.device(),
            )?;
        }

        if self.requires_grad
            || weight.requires_grad
            || bias.is_some_and(|b| b.requires_grad)
        {
            use std::sync::Arc;
            output.requires_grad = true;
            output.operation = crate::Operation::Custom(
                "conv_transpose2d".to_string(),
                vec![
                    Arc::downgrade(&Arc::new(self.clone())),
                    Arc::downgrade(&Arc::new(weight.clone())),
                ],
            );
        }

        Ok(output)
    }

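    /// Cross-correlation of two 1-D tensors. The output size depends on `mode`:
    /// `Full` gives `n + m - 1` lags, `Valid` keeps only fully overlapping lags
    /// (`n - m + 1`), and `Same` returns `n` values aligned with `self`.
    ///
    /// Illustrative shape sketch (tensor construction elided, crate-specific):
    ///
    /// ```ignore
    /// // a has 8 elements, b has 3 elements
    /// assert_eq!(a.xcorr1d(&b, CorrelationMode::Full)?.shape().dims(), [10]);
    /// assert_eq!(a.xcorr1d(&b, CorrelationMode::Valid)?.shape().dims(), [6]);
    /// assert_eq!(a.xcorr1d(&b, CorrelationMode::Same)?.shape().dims(), [8]);
    /// ```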
    #[allow(clippy::needless_range_loop)]
    pub fn xcorr1d(&self, other: &Self, mode: CorrelationMode) -> Result<Self> {
        let self_shape_ref = self.shape();
        let other_shape_ref = other.shape();
        let self_shape = self_shape_ref.dims();
        let other_shape = other_shape_ref.dims();

        if self_shape.len() != 1 || other_shape.len() != 1 {
            return Err(TorshError::InvalidArgument(
                "xcorr1d requires 1D tensors".to_string(),
            ));
        }

        let n = self_shape[0];
        let m = other_shape[0];

        // Output size and starting lag for each correlation mode.
        let (output_size, lag_start) = match mode {
            CorrelationMode::Full => (n + m - 1, -(m as i32 - 1)),
            CorrelationMode::Valid => {
                // Valid mode keeps only lags where `other` lies entirely inside `self`,
                // so the first tensor must be at least as long as the second.
                if n < m {
                    return Err(TorshError::InvalidArgument(
                        "Valid mode requires the first tensor to be at least as long as the second"
                            .to_string(),
                    ));
                }
                (n - m + 1, 0)
            }
            CorrelationMode::Same => (n, -((m as i32 - 1) / 2)),
        };

        let mut output_data = vec![<T as TensorElement>::zero(); output_size];
        let self_data = self.to_vec()?;
        let other_data = other.to_vec()?;

        // r[lag] = sum_j self[j] * other[j - lag], skipping out-of-range indices.
        for i in 0..output_size {
            let mut sum = <T as TensorElement>::zero();
            let lag = lag_start + i as i32;

            for j in 0..n {
                let other_idx = j as i32 - lag;
                if other_idx >= 0 && (other_idx as usize) < m {
                    sum = sum + self_data[j] * other_data[other_idx as usize];
                }
            }
            output_data[i] = sum;
        }

        let output = Tensor::from_data(output_data, vec![output_size], self.device())?;

        Ok(output)
    }

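    /// Unnormalized autocorrelation of a 1-D tensor: for each lag `k` in `0..=max_lag`
    /// the result holds `sum_i x[i] * x[i - k]`. `max_lag` defaults to `n - 1` and is
    /// clamped to that value.
    ///
    /// Illustrative shape sketch (tensor construction elided, crate-specific):
    ///
    /// ```ignore
    /// // x has 100 samples
    /// let r = x.autocorr1d(Some(10))?; // lags 0..=10
    /// assert_eq!(r.shape().dims(), [11]);
    /// ```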
    pub fn autocorr1d(&self, max_lag: Option<usize>) -> Result<Self> {
        let shape_ref = self.shape();
        let shape = shape_ref.dims();
        if shape.len() != 1 {
            return Err(TorshError::InvalidArgument(
                "autocorr1d requires 1D tensor".to_string(),
            ));
        }

        let n = shape[0];
        let max_lag = max_lag.unwrap_or(n - 1).min(n - 1);

        let self_data = self.to_vec()?;
        let mut output_data = Vec::with_capacity(max_lag + 1);

        for lag in 0..=max_lag {
            let mut sum = <T as TensorElement>::zero();

            for i in lag..n {
                sum = sum + self_data[i] * self_data[i - lag];
            }

            output_data.push(sum);
        }

        let output = Tensor::from_data(output_data, vec![max_lag + 1], self.device())?;
        Ok(output)
    }

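    /// Cross-correlation of two 2-D tensors. `Full` yields an `(h1 + h2 - 1) x (w1 + w2 - 1)`
    /// output, `Valid` requires `self` to be at least as large as `other` and yields
    /// `(h1 - h2 + 1) x (w1 - w2 + 1)`, and `Same` keeps the shape of `self`.
    ///
    /// Illustrative shape sketch (tensor construction elided, crate-specific):
    ///
    /// ```ignore
    /// // image: [16, 16], template: [3, 3]
    /// let out = image.xcorr2d(&template, CorrelationMode::Valid)?;
    /// assert_eq!(out.shape().dims(), [14, 14]);
    /// ```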
    pub fn xcorr2d(&self, other: &Self, mode: CorrelationMode) -> Result<Self> {
        let self_shape_ref = self.shape();
        let other_shape_ref = other.shape();
        let self_shape = self_shape_ref.dims();
        let other_shape = other_shape_ref.dims();

        if self_shape.len() != 2 || other_shape.len() != 2 {
            return Err(TorshError::InvalidArgument(
                "xcorr2d requires 2D tensors".to_string(),
            ));
        }

        let (h1, w1) = (self_shape[0], self_shape[1]);
        let (h2, w2) = (other_shape[0], other_shape[1]);

        let (out_h, out_w, start_h, start_w) = match mode {
            CorrelationMode::Full => (h1 + h2 - 1, w1 + w2 - 1, 0, 0),
            CorrelationMode::Valid => {
                if h1 < h2 || w1 < w2 {
                    return Err(TorshError::InvalidArgument(
                        "Valid mode requires first tensor to be larger than or equal to second"
                            .to_string(),
                    ));
                }
                (h1 - h2 + 1, w1 - w2 + 1, h2 - 1, w2 - 1)
            }
            CorrelationMode::Same => (h1, w1, (h2 - 1) / 2, (w2 - 1) / 2),
        };

        let mut output_data = vec![<T as TensorElement>::zero(); out_h * out_w];
        let self_data = self.to_vec()?;
        let other_data = other.to_vec()?;

        for i in 0..out_h {
            for j in 0..out_w {
                let mut sum = <T as TensorElement>::zero();
                let actual_i = i + start_h;
                let actual_j = j + start_w;

                for ki in 0..h2 {
                    for kj in 0..w2 {
                        let src_i = actual_i as i32 - ki as i32;
                        let src_j = actual_j as i32 - kj as i32;

                        if src_i >= 0
                            && (src_i as usize) < h1
                            && src_j >= 0
                            && (src_j as usize) < w1
                        {
                            let self_idx = src_i as usize * w1 + src_j as usize;
                            let other_idx = ki * w2 + kj;
                            sum = sum + self_data[self_idx] * other_data[other_idx];
                        }
                    }
                }
                output_data[i * out_w + j] = sum;
            }
        }

        let output = Tensor::from_data(output_data, vec![out_h, out_w], self.device())?;
        Ok(output)
    }

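    /// Sliding-window median filter for a 1-D tensor. `window_size` must be odd;
    /// borders are handled by clamping indices to the ends of the signal
    /// (replicate padding), so the output has the same length as the input.
    ///
    /// Illustrative shape sketch (tensor construction elided, crate-specific):
    ///
    /// ```ignore
    /// // signal has 100 samples with occasional spikes
    /// let smoothed = signal.median_filter1d(5)?;
    /// assert_eq!(smoothed.shape().dims(), [100]);
    /// ```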
    pub fn median_filter1d(&self, window_size: usize) -> Result<Self> {
        let shape_ref = self.shape();
        let shape = shape_ref.dims();
        if shape.len() != 1 {
            return Err(TorshError::InvalidArgument(
                "median_filter1d requires 1D tensor".to_string(),
            ));
        }

        if window_size == 0 || window_size % 2 == 0 {
            return Err(TorshError::InvalidArgument(
                "Window size must be odd and greater than 0".to_string(),
            ));
        }

        let n = shape[0];
        let half_window = window_size / 2;
        let mut output_data = Vec::with_capacity(n);
        let self_data = self.to_vec()?;

        for i in 0..n {
            let mut window_values = Vec::new();

            // Gather the window, clamping indices at the borders (replicate padding).
            for j in 0..window_size {
                let idx = i as i32 + j as i32 - half_window as i32;
                let actual_idx = if idx < 0 {
                    0
                } else if idx >= n as i32 {
                    n - 1
                } else {
                    idx as usize
                };
                window_values.push(self_data[actual_idx]);
            }

            window_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            output_data.push(window_values[half_window]);
        }

        let output = Tensor::from_data(output_data, vec![n], self.device())?;
        Ok(output)
    }

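    /// Sliding-window median filter for a 2-D tensor. Both window dimensions must be
    /// odd; border pixels are clamped to the image edges, so the output shape matches
    /// the input.
    ///
    /// Illustrative shape sketch (tensor construction elided, crate-specific):
    ///
    /// ```ignore
    /// // image: [32, 32]
    /// let denoised = image.median_filter2d((3, 3))?;
    /// assert_eq!(denoised.shape().dims(), [32, 32]);
    /// ```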
    pub fn median_filter2d(&self, window_size: (usize, usize)) -> Result<Self> {
        let shape_ref = self.shape();
        let shape = shape_ref.dims();
        if shape.len() != 2 {
            return Err(TorshError::InvalidArgument(
                "median_filter2d requires 2D tensor".to_string(),
            ));
        }

        let (window_h, window_w) = window_size;
        if window_h == 0 || window_w == 0 || window_h % 2 == 0 || window_w % 2 == 0 {
            return Err(TorshError::InvalidArgument(
                "Window dimensions must be odd and greater than 0".to_string(),
            ));
        }

        let (h, w) = (shape[0], shape[1]);
        let half_h = window_h / 2;
        let half_w = window_w / 2;
        let mut output_data = Vec::with_capacity(h * w);
        let self_data = self.to_vec()?;

        for i in 0..h {
            for j in 0..w {
                let mut window_values = Vec::new();

                for di in 0..window_h {
                    for dj in 0..window_w {
                        let row = i as i32 + di as i32 - half_h as i32;
                        let col = j as i32 + dj as i32 - half_w as i32;

                        let actual_row = row.clamp(0, h as i32 - 1) as usize;
                        let actual_col = col.clamp(0, w as i32 - 1) as usize;

                        window_values.push(self_data[actual_row * w + actual_col]);
                    }
                }

                window_values
                    .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                output_data.push(window_values[window_values.len() / 2]);
            }
        }

        let output = Tensor::from_data(output_data, vec![h, w], self.device())?;
        Ok(output)
    }

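    /// Gaussian smoothing of a 1-D tensor. A normalized Gaussian kernel is built from
    /// `sigma` (the size defaults to roughly `6 * sigma`, at least 3, adjusted to be odd)
    /// and applied via `xcorr1d` in `Same` mode, so the output length matches the input.
    ///
    /// Illustrative shape sketch (tensor construction elided, crate-specific):
    ///
    /// ```ignore
    /// let smoothed = signal.gaussian_filter1d(2.0, None)?;
    /// assert_eq!(smoothed.shape().dims(), signal.shape().dims());
    /// ```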
    pub fn gaussian_filter1d(&self, sigma: f32, kernel_size: Option<usize>) -> Result<Self> {
        let tensor_shape = self.shape();
        let shape = tensor_shape.dims();
        if shape.len() != 1 {
            return Err(TorshError::InvalidArgument(
                "gaussian_filter1d requires 1D tensor".to_string(),
            ));
        }

        if sigma <= 0.0 {
            return Err(TorshError::InvalidArgument(
                "Sigma must be positive".to_string(),
            ));
        }

        let kernel_size = kernel_size.unwrap_or(((6.0 * sigma) as usize).max(3));
        let kernel_size = if kernel_size % 2 == 0 {
            kernel_size + 1
        } else {
            kernel_size
        };

        let half_size = kernel_size / 2;
        let mut kernel = Vec::with_capacity(kernel_size);
        let mut sum = 0.0f32;

        for i in 0..kernel_size {
            let x = i as f32 - half_size as f32;
            let value = (-0.5 * (x / sigma).powi(2)).exp();
            kernel.push(value);
            sum += value;
        }

        for value in &mut kernel {
            *value /= sum;
        }

        let kernel_data: Vec<T> = kernel
            .into_iter()
            .map(|v| {
                T::from(v as f64)
                    .unwrap_or_else(|| T::from(0.0).expect("numeric conversion should succeed"))
            })
            .collect();
        let kernel_tensor = Tensor::from_data(kernel_data, vec![kernel_size], self.device())?;

        self.xcorr1d(&kernel_tensor, CorrelationMode::Same)
    }

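    /// Gaussian smoothing of a 2-D tensor with per-axis standard deviations
    /// `sigma = (sigma_x, sigma_y)` (columns and rows respectively). A normalized 2-D
    /// Gaussian kernel is built and applied via `xcorr2d` in `Same` mode, so the output
    /// shape matches the input.
    ///
    /// Illustrative shape sketch (tensor construction elided, crate-specific):
    ///
    /// ```ignore
    /// let blurred = image.gaussian_filter2d((1.5, 1.5), None)?;
    /// assert_eq!(blurred.shape().dims(), image.shape().dims());
    /// ```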
    pub fn gaussian_filter2d(
        &self,
        sigma: (f32, f32),
        kernel_size: Option<(usize, usize)>,
    ) -> Result<Self> {
        let tensor_shape = self.shape();
        let shape = tensor_shape.dims();
        if shape.len() != 2 {
            return Err(TorshError::InvalidArgument(
                "gaussian_filter2d requires 2D tensor".to_string(),
            ));
        }

        let (sigma_x, sigma_y) = sigma;
        if sigma_x <= 0.0 || sigma_y <= 0.0 {
            return Err(TorshError::InvalidArgument(
                "Sigma values must be positive".to_string(),
            ));
        }

        let (kernel_h, kernel_w) = kernel_size.unwrap_or((
            ((6.0 * sigma_y) as usize).max(3),
            ((6.0 * sigma_x) as usize).max(3),
        ));
        let kernel_h = if kernel_h % 2 == 0 {
            kernel_h + 1
        } else {
            kernel_h
        };
        let kernel_w = if kernel_w % 2 == 0 {
            kernel_w + 1
        } else {
            kernel_w
        };

        let half_h = kernel_h / 2;
        let half_w = kernel_w / 2;
        let mut kernel = Vec::with_capacity(kernel_h * kernel_w);
        let mut sum = 0.0f32;

        for i in 0..kernel_h {
            for j in 0..kernel_w {
                let y = i as f32 - half_h as f32;
                let x = j as f32 - half_w as f32;
                let value = (-0.5 * ((x / sigma_x).powi(2) + (y / sigma_y).powi(2))).exp();
                kernel.push(value);
                sum += value;
            }
        }

        for value in &mut kernel {
            *value /= sum;
        }

        let kernel_data: Vec<T> = kernel
            .into_iter()
            .map(|v| {
                T::from(v as f64)
                    .unwrap_or_else(|| T::from(0.0).expect("numeric conversion should succeed"))
            })
            .collect();
        let kernel_tensor =
            Tensor::from_data(kernel_data, vec![kernel_h, kernel_w], self.device())?;

        self.xcorr2d(&kernel_tensor, CorrelationMode::Same)
    }
}

/// Overlap handling for the cross-correlation routines (`xcorr1d`, `xcorr2d`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CorrelationMode {
    /// Keep every lag with any overlap; the output is larger than the first tensor.
    Full,
    /// Keep only lags where the second tensor lies entirely inside the first.
    Valid,
    /// Keep an output with the same size as the first tensor.
    Same,
}