Skip to main content

torsh_tensor/
conv.rs

1//! Convolution and signal processing operations for tensors
2
3use crate::{FloatElement, Tensor};
4use torsh_core::error::{Result, TorshError};
5use torsh_core::TensorElement;
6
7impl<T: FloatElement> Tensor<T> {
8    /// 1D convolution operation
9    pub fn conv1d(
10        &self,
11        weight: &Self,
12        bias: Option<&Self>,
13        stride: usize,
14        padding: usize,
15        dilation: usize,
16        groups: usize,
17    ) -> Result<Self> {
18        // Input shape: (N, C_in, L)
19        // Weight shape: (C_out, C_in/groups, kernel_size)
20        // Output shape: (N, C_out, L_out)
21
22        let input_shape_obj = self.shape();
23        let input_shape = input_shape_obj.dims();
24        let weight_shape_obj = weight.shape();
25        let weight_shape = weight_shape_obj.dims();
26
27        if input_shape.len() != 3 {
28            return Err(TorshError::InvalidArgument(format!(
29                "Expected 3D input tensor for conv1d, got {}D",
30                input_shape.len()
31            )));
32        }
33
34        if weight_shape.len() != 3 {
35            return Err(TorshError::InvalidArgument(format!(
36                "Expected 3D weight tensor for conv1d, got {}D",
37                weight_shape.len()
38            )));
39        }
40
41        let batch_size = input_shape[0];
42        let in_channels = input_shape[1];
43        let input_length = input_shape[2];
44
45        let out_channels = weight_shape[0];
46        let kernel_size = weight_shape[2];
47
48        // Check groups
49        if in_channels % groups != 0 || out_channels % groups != 0 {
50            return Err(TorshError::InvalidArgument(
51                "in_channels and out_channels must be divisible by groups".to_string(),
52            ));
53        }
54
55        if weight_shape[1] != in_channels / groups {
56            return Err(TorshError::InvalidArgument(format!(
57                "Weight tensor has wrong number of input channels: expected {}, got {}",
58                in_channels / groups,
59                weight_shape[1]
60            )));
61        }
62
63        // Calculate output length
64        let effective_kernel = (kernel_size - 1) * dilation + 1;
65        let padded_length = input_length + 2 * padding;
66        let output_length = (padded_length - effective_kernel) / stride + 1;
67
68        // Initialize output
69        let mut output_data =
70            vec![<T as TensorElement>::zero(); batch_size * out_channels * output_length];
71
72        // Perform convolution
73        for n in 0..batch_size {
74            for g in 0..groups {
75                let out_ch_start = g * (out_channels / groups);
76                let out_ch_end = (g + 1) * (out_channels / groups);
77                let in_ch_start = g * (in_channels / groups);
78                let in_ch_end = (g + 1) * (in_channels / groups);
79
80                for oc in out_ch_start..out_ch_end {
81                    for ol in 0..output_length {
82                        let mut sum = <T as TensorElement>::zero();
83
84                        for ic in in_ch_start..in_ch_end {
85                            let ic_rel = ic - in_ch_start;
86                            for k in 0..kernel_size {
87                                let il = (ol * stride + k * dilation) as i32 - padding as i32;
88
89                                if il >= 0 && (il as usize) < input_length {
90                                    let input_idx = n * in_channels * input_length
91                                        + ic * input_length
92                                        + il as usize;
93                                    let weight_idx = oc * (in_channels / groups) * kernel_size
94                                        + ic_rel * kernel_size
95                                        + k;
96
97                                    let input_val = self.storage.get(input_idx)?;
98                                    let weight_val = weight.storage.get(weight_idx)?;
99                                    sum = sum + input_val * weight_val;
100                                }
101                            }
102                        }
103
104                        let output_idx = n * out_channels * output_length + oc * output_length + ol;
105                        output_data[output_idx] = sum;
106                    }
107                }
108            }
109        }
110
111        // Create output tensor
112        let mut output = Tensor::from_data(
113            output_data,
114            vec![batch_size, out_channels, output_length],
115            self.device(),
116        )?;
117
118        // Add bias if provided
119        if let Some(b) = bias {
120            if b.shape().dims() != [out_channels] {
121                return Err(TorshError::InvalidArgument(format!(
122                    "Bias must have shape [{}], got {:?}",
123                    out_channels,
124                    b.shape().dims()
125                )));
126            }
127
128            // For now, use element-wise addition - TODO: implement efficient broadcasting
129            let bias_data = b.to_vec()?;
130            let mut output_data = output.to_vec()?;
131
132            for n in 0..batch_size {
133                #[allow(clippy::needless_range_loop)]
134                for oc in 0..out_channels {
135                    for ol in 0..output_length {
136                        let idx = n * out_channels * output_length + oc * output_length + ol;
137                        output_data[idx] = output_data[idx] + bias_data[oc];
138                    }
139                }
140            }
141
142            // Recreate tensor with modified data
143            output = Tensor::from_data(
144                output_data,
145                vec![batch_size, out_channels, output_length],
146                self.device(),
147            )?;
148        }
149
150        // Track operation for autograd
151        if self.requires_grad
152            || weight.requires_grad
153            || (bias.is_some() && bias.expect("bias checked with is_some").requires_grad)
154        {
155            use std::sync::Arc;
156            output.requires_grad = true;
157            output.operation = crate::Operation::Custom(
158                "conv1d".to_string(),
159                vec![
160                    Arc::downgrade(&Arc::new(self.clone())),
161                    Arc::downgrade(&Arc::new(weight.clone())),
162                ],
163            );
164        }
165
166        Ok(output)
167    }
168
169    /// 2D convolution operation
170    pub fn conv2d(
171        &self,
172        weight: &Self,
173        bias: Option<&Self>,
174        stride: (usize, usize),
175        padding: (usize, usize),
176        dilation: (usize, usize),
177        groups: usize,
178    ) -> Result<Self> {
179        // Input shape: (N, C_in, H, W)
180        // Weight shape: (C_out, C_in/groups, kernel_h, kernel_w)
181        // Output shape: (N, C_out, H_out, W_out)
182
183        let input_shape_obj = self.shape();
184        let input_shape = input_shape_obj.dims();
185        let weight_shape_obj = weight.shape();
186        let weight_shape = weight_shape_obj.dims();
187
188        if input_shape.len() != 4 {
189            return Err(TorshError::InvalidArgument(format!(
190                "Expected 4D input tensor for conv2d, got {}D",
191                input_shape.len()
192            )));
193        }
194
195        if weight_shape.len() != 4 {
196            return Err(TorshError::InvalidArgument(format!(
197                "Expected 4D weight tensor for conv2d, got {}D",
198                weight_shape.len()
199            )));
200        }
201
202        let batch_size = input_shape[0];
203        let in_channels = input_shape[1];
204        let input_height = input_shape[2];
205        let input_width = input_shape[3];
206
207        let out_channels = weight_shape[0];
208        let kernel_height = weight_shape[2];
209        let kernel_width = weight_shape[3];
210
211        // Check groups
212        if in_channels % groups != 0 || out_channels % groups != 0 {
213            return Err(TorshError::InvalidArgument(
214                "in_channels and out_channels must be divisible by groups".to_string(),
215            ));
216        }
217
218        if weight_shape[1] != in_channels / groups {
219            return Err(TorshError::InvalidArgument(format!(
220                "Weight tensor has wrong number of input channels: expected {}, got {}",
221                in_channels / groups,
222                weight_shape[1]
223            )));
224        }
225
226        // Calculate output dimensions
227        let effective_kernel_h = (kernel_height - 1) * dilation.0 + 1;
228        let effective_kernel_w = (kernel_width - 1) * dilation.1 + 1;
229        let padded_height = input_height + 2 * padding.0;
230        let padded_width = input_width + 2 * padding.1;
231        let output_height = (padded_height - effective_kernel_h) / stride.0 + 1;
232        let output_width = (padded_width - effective_kernel_w) / stride.1 + 1;
233
234        // Initialize output
235        let mut output_data = vec![
236            <T as TensorElement>::zero();
237            batch_size * out_channels * output_height * output_width
238        ];
239
240        let self_data = self.to_vec()?;
241        let weight_data = weight.to_vec()?;
242
243        // Perform convolution
244        for n in 0..batch_size {
245            for g in 0..groups {
246                let out_ch_start = g * (out_channels / groups);
247                let out_ch_end = (g + 1) * (out_channels / groups);
248                let in_ch_start = g * (in_channels / groups);
249                let in_ch_end = (g + 1) * (in_channels / groups);
250
251                for oc in out_ch_start..out_ch_end {
252                    for oh in 0..output_height {
253                        for ow in 0..output_width {
254                            let mut sum = <T as TensorElement>::zero();
255
256                            for ic in in_ch_start..in_ch_end {
257                                let ic_rel = ic - in_ch_start;
258                                for kh in 0..kernel_height {
259                                    for kw in 0..kernel_width {
260                                        let ih = (oh * stride.0 + kh * dilation.0) as i32
261                                            - padding.0 as i32;
262                                        let iw = (ow * stride.1 + kw * dilation.1) as i32
263                                            - padding.1 as i32;
264
265                                        if ih >= 0
266                                            && (ih as usize) < input_height
267                                            && iw >= 0
268                                            && (iw as usize) < input_width
269                                        {
270                                            let input_idx =
271                                                n * in_channels * input_height * input_width
272                                                    + ic * input_height * input_width
273                                                    + ih as usize * input_width
274                                                    + iw as usize;
275                                            let weight_idx = oc
276                                                * (in_channels / groups)
277                                                * kernel_height
278                                                * kernel_width
279                                                + ic_rel * kernel_height * kernel_width
280                                                + kh * kernel_width
281                                                + kw;
282
283                                            sum = sum
284                                                + self_data[input_idx] * weight_data[weight_idx];
285                                        }
286                                    }
287                                }
288                            }
289
290                            let output_idx = n * out_channels * output_height * output_width
291                                + oc * output_height * output_width
292                                + oh * output_width
293                                + ow;
294                            output_data[output_idx] = sum;
295                        }
296                    }
297                }
298            }
299        }
300
301        // Create output tensor
302        let mut output = Tensor::from_data(
303            output_data,
304            vec![batch_size, out_channels, output_height, output_width],
305            self.device(),
306        )?;
307
308        // Add bias if provided
309        if let Some(b) = bias {
310            if b.shape().dims() != [out_channels] {
311                return Err(TorshError::InvalidArgument(format!(
312                    "Bias must have shape [{}], got {:?}",
313                    out_channels,
314                    b.shape().dims()
315                )));
316            }
317
318            let bias_data = b.to_vec()?;
319
320            let mut output_data = output.to_vec()?;
321
322            for n in 0..batch_size {
323                #[allow(clippy::needless_range_loop)]
324                for oc in 0..out_channels {
325                    for oh in 0..output_height {
326                        for ow in 0..output_width {
327                            let idx = n * out_channels * output_height * output_width
328                                + oc * output_height * output_width
329                                + oh * output_width
330                                + ow;
331                            output_data[idx] = output_data[idx] + bias_data[oc];
332                        }
333                    }
334                }
335            }
336
337            // Create new output tensor with bias added
338            output = Tensor::from_data(
339                output_data,
340                vec![batch_size, out_channels, output_height, output_width],
341                self.device(),
342            )?;
343        }
344
345        // Track operation for autograd
346        if self.requires_grad
347            || weight.requires_grad
348            || (bias.is_some() && bias.expect("bias checked with is_some").requires_grad)
349        {
350            use std::sync::Arc;
351            output.requires_grad = true;
352            output.operation = crate::Operation::Custom(
353                "conv2d".to_string(),
354                vec![
355                    Arc::downgrade(&Arc::new(self.clone())),
356                    Arc::downgrade(&Arc::new(weight.clone())),
357                ],
358            );
359        }
360
361        Ok(output)
362    }
363
364    /// 3D convolution operation
365    pub fn conv3d(
366        &self,
367        weight: &Self,
368        bias: Option<&Self>,
369        stride: (usize, usize, usize),
370        padding: (usize, usize, usize),
371        dilation: (usize, usize, usize),
372        groups: usize,
373    ) -> Result<Self> {
374        // Input shape: (N, C_in, D, H, W)
375        // Weight shape: (C_out, C_in/groups, kernel_d, kernel_h, kernel_w)
376        // Output shape: (N, C_out, D_out, H_out, W_out)
377
378        let input_shape_obj = self.shape();
379        let input_shape = input_shape_obj.dims();
380        let weight_shape_obj = weight.shape();
381        let weight_shape = weight_shape_obj.dims();
382
383        if input_shape.len() != 5 {
384            return Err(TorshError::InvalidArgument(format!(
385                "Expected 5D input tensor for conv3d, got {}D",
386                input_shape.len()
387            )));
388        }
389
390        if weight_shape.len() != 5 {
391            return Err(TorshError::InvalidArgument(format!(
392                "Expected 5D weight tensor for conv3d, got {}D",
393                weight_shape.len()
394            )));
395        }
396
397        let batch_size = input_shape[0];
398        let in_channels = input_shape[1];
399        let input_depth = input_shape[2];
400        let input_height = input_shape[3];
401        let input_width = input_shape[4];
402
403        let out_channels = weight_shape[0];
404        let kernel_depth = weight_shape[2];
405        let kernel_height = weight_shape[3];
406        let kernel_width = weight_shape[4];
407
408        // Check groups
409        if in_channels % groups != 0 || out_channels % groups != 0 {
410            return Err(TorshError::InvalidArgument(
411                "in_channels and out_channels must be divisible by groups".to_string(),
412            ));
413        }
414
415        if weight_shape[1] != in_channels / groups {
416            return Err(TorshError::InvalidArgument(format!(
417                "Weight tensor has wrong number of input channels: expected {}, got {}",
418                in_channels / groups,
419                weight_shape[1]
420            )));
421        }
422
423        // Calculate output dimensions
424        let effective_kernel_d = (kernel_depth - 1) * dilation.0 + 1;
425        let effective_kernel_h = (kernel_height - 1) * dilation.1 + 1;
426        let effective_kernel_w = (kernel_width - 1) * dilation.2 + 1;
427        let padded_depth = input_depth + 2 * padding.0;
428        let padded_height = input_height + 2 * padding.1;
429        let padded_width = input_width + 2 * padding.2;
430        let output_depth = (padded_depth - effective_kernel_d) / stride.0 + 1;
431        let output_height = (padded_height - effective_kernel_h) / stride.1 + 1;
432        let output_width = (padded_width - effective_kernel_w) / stride.2 + 1;
433
434        // Initialize output
435        let output_size = batch_size * out_channels * output_depth * output_height * output_width;
436        let mut output_data = vec![<T as TensorElement>::zero(); output_size];
437
438        let self_data = self.to_vec()?;
439        let weight_data = weight.to_vec()?;
440
441        // Perform convolution
442        for n in 0..batch_size {
443            for g in 0..groups {
444                let out_ch_start = g * (out_channels / groups);
445                let out_ch_end = (g + 1) * (out_channels / groups);
446                let in_ch_start = g * (in_channels / groups);
447                let in_ch_end = (g + 1) * (in_channels / groups);
448
449                for oc in out_ch_start..out_ch_end {
450                    for od in 0..output_depth {
451                        for oh in 0..output_height {
452                            for ow in 0..output_width {
453                                let mut sum = <T as TensorElement>::zero();
454
455                                for ic in in_ch_start..in_ch_end {
456                                    let ic_rel = ic - in_ch_start;
457                                    for kd in 0..kernel_depth {
458                                        for kh in 0..kernel_height {
459                                            for kw in 0..kernel_width {
460                                                let id = (od * stride.0 + kd * dilation.0) as i32
461                                                    - padding.0 as i32;
462                                                let ih = (oh * stride.1 + kh * dilation.1) as i32
463                                                    - padding.1 as i32;
464                                                let iw = (ow * stride.2 + kw * dilation.2) as i32
465                                                    - padding.2 as i32;
466
467                                                if id >= 0
468                                                    && (id as usize) < input_depth
469                                                    && ih >= 0
470                                                    && (ih as usize) < input_height
471                                                    && iw >= 0
472                                                    && (iw as usize) < input_width
473                                                {
474                                                    let input_idx = n
475                                                        * in_channels
476                                                        * input_depth
477                                                        * input_height
478                                                        * input_width
479                                                        + ic * input_depth
480                                                            * input_height
481                                                            * input_width
482                                                        + id as usize * input_height * input_width
483                                                        + ih as usize * input_width
484                                                        + iw as usize;
485                                                    let weight_idx = oc
486                                                        * (in_channels / groups)
487                                                        * kernel_depth
488                                                        * kernel_height
489                                                        * kernel_width
490                                                        + ic_rel
491                                                            * kernel_depth
492                                                            * kernel_height
493                                                            * kernel_width
494                                                        + kd * kernel_height * kernel_width
495                                                        + kh * kernel_width
496                                                        + kw;
497
498                                                    sum = sum
499                                                        + self_data[input_idx]
500                                                            * weight_data[weight_idx];
501                                                }
502                                            }
503                                        }
504                                    }
505                                }
506
507                                let output_idx =
508                                    n * out_channels * output_depth * output_height * output_width
509                                        + oc * output_depth * output_height * output_width
510                                        + od * output_height * output_width
511                                        + oh * output_width
512                                        + ow;
513                                output_data[output_idx] = sum;
514                            }
515                        }
516                    }
517                }
518            }
519        }
520
521        // Create output tensor
522        let mut output = Tensor::from_data(
523            output_data,
524            vec![
525                batch_size,
526                out_channels,
527                output_depth,
528                output_height,
529                output_width,
530            ],
531            self.device(),
532        )?;
533
534        // Add bias if provided
535        if let Some(b) = bias {
536            if b.shape().dims() != [out_channels] {
537                return Err(TorshError::InvalidArgument(format!(
538                    "Bias must have shape [{}], got {:?}",
539                    out_channels,
540                    b.shape().dims()
541                )));
542            }
543
544            let bias_data = b.to_vec()?;
545
546            let mut output_data = output.to_vec()?;
547
548            for n in 0..batch_size {
549                #[allow(clippy::needless_range_loop)]
550                for oc in 0..out_channels {
551                    for od in 0..output_depth {
552                        for oh in 0..output_height {
553                            for ow in 0..output_width {
554                                let idx =
555                                    n * out_channels * output_depth * output_height * output_width
556                                        + oc * output_depth * output_height * output_width
557                                        + od * output_height * output_width
558                                        + oh * output_width
559                                        + ow;
560                                output_data[idx] = output_data[idx] + bias_data[oc];
561                            }
562                        }
563                    }
564                }
565            }
566
567            // Create new output tensor with bias added
568            output = Tensor::from_data(
569                output_data,
570                vec![
571                    batch_size,
572                    out_channels,
573                    output_depth,
574                    output_height,
575                    output_width,
576                ],
577                self.device(),
578            )?;
579        }
580
581        // Track operation for autograd
582        if self.requires_grad
583            || weight.requires_grad
584            || (bias.is_some() && bias.expect("bias checked with is_some").requires_grad)
585        {
586            use std::sync::Arc;
587            output.requires_grad = true;
588            output.operation = crate::Operation::Custom(
589                "conv3d".to_string(),
590                vec![
591                    Arc::downgrade(&Arc::new(self.clone())),
592                    Arc::downgrade(&Arc::new(weight.clone())),
593                ],
594            );
595        }
596
597        Ok(output)
598    }
599
600    /// Depthwise 2D convolution operation
601    /// Each input channel is convolved with its own kernel independently
602    pub fn depthwise_conv2d(
603        &self,
604        weight: &Self,
605        bias: Option<&Self>,
606        stride: (usize, usize),
607        padding: (usize, usize),
608        dilation: (usize, usize),
609    ) -> Result<Self> {
610        // Input shape: (N, C_in, H, W)
611        // Weight shape: (C_in, 1, kernel_h, kernel_w) - each channel has its own kernel
612        // Output shape: (N, C_in, H_out, W_out)
613
614        let input_shape_obj = self.shape();
615        let input_shape = input_shape_obj.dims();
616        let weight_shape_obj = weight.shape();
617        let weight_shape = weight_shape_obj.dims();
618
619        if input_shape.len() != 4 {
620            return Err(TorshError::InvalidArgument(format!(
621                "Expected 4D input tensor for depthwise_conv2d, got {}D",
622                input_shape.len()
623            )));
624        }
625
626        if weight_shape.len() != 4 {
627            return Err(TorshError::InvalidArgument(format!(
628                "Expected 4D weight tensor for depthwise_conv2d, got {}D",
629                weight_shape.len()
630            )));
631        }
632
633        let batch_size = input_shape[0];
634        let in_channels = input_shape[1];
635        let input_height = input_shape[2];
636        let input_width = input_shape[3];
637
638        let kernel_height = weight_shape[2];
639        let kernel_width = weight_shape[3];
640
641        // For depthwise conv, weight should have shape (C_in, 1, kernel_h, kernel_w)
642        if weight_shape[0] != in_channels || weight_shape[1] != 1 {
643            return Err(TorshError::InvalidArgument(format!(
644                "Weight tensor must have shape ({}, 1, kernel_h, kernel_w), got ({}, {}, {}, {})",
645                in_channels, weight_shape[0], weight_shape[1], weight_shape[2], weight_shape[3]
646            )));
647        }
648
649        // Calculate output dimensions
650        let effective_kernel_h = (kernel_height - 1) * dilation.0 + 1;
651        let effective_kernel_w = (kernel_width - 1) * dilation.1 + 1;
652        let padded_height = input_height + 2 * padding.0;
653        let padded_width = input_width + 2 * padding.1;
654        let output_height = (padded_height - effective_kernel_h) / stride.0 + 1;
655        let output_width = (padded_width - effective_kernel_w) / stride.1 + 1;
656
657        // Initialize output
658        let mut output_data = vec![
659            <T as TensorElement>::zero();
660            batch_size * in_channels * output_height * output_width
661        ];
662
663        let _self_data = self.to_vec()?;
664        let _weight_data = weight.to_vec()?;
665
666        // Perform depthwise convolution
667        for n in 0..batch_size {
668            for c in 0..in_channels {
669                for oh in 0..output_height {
670                    for ow in 0..output_width {
671                        let mut sum = <T as TensorElement>::zero();
672
673                        for kh in 0..kernel_height {
674                            for kw in 0..kernel_width {
675                                let ih =
676                                    (oh * stride.0 + kh * dilation.0) as i32 - padding.0 as i32;
677                                let iw =
678                                    (ow * stride.1 + kw * dilation.1) as i32 - padding.1 as i32;
679
680                                if ih >= 0
681                                    && (ih as usize) < input_height
682                                    && iw >= 0
683                                    && (iw as usize) < input_width
684                                {
685                                    let input_idx = n * in_channels * input_height * input_width
686                                        + c * input_height * input_width
687                                        + ih as usize * input_width
688                                        + iw as usize;
689                                    let weight_idx =
690                                        c * kernel_height * kernel_width + kh * kernel_width + kw;
691
692                                    let input_val = self.storage.get(input_idx)?;
693                                    let weight_val = weight.storage.get(weight_idx)?;
694                                    sum = sum + input_val * weight_val;
695                                }
696                            }
697                        }
698
699                        let output_idx = n * in_channels * output_height * output_width
700                            + c * output_height * output_width
701                            + oh * output_width
702                            + ow;
703                        output_data[output_idx] = sum;
704                    }
705                }
706            }
707        }
708
709        // Create output tensor
710        let mut output = Tensor::from_data(
711            output_data,
712            vec![batch_size, in_channels, output_height, output_width],
713            self.device(),
714        )?;
715
716        // Add bias if provided
717        if let Some(b) = bias {
718            if b.shape().dims() != [in_channels] {
719                return Err(TorshError::InvalidArgument(format!(
720                    "Bias must have shape [{}], got {:?}",
721                    in_channels,
722                    b.shape().dims()
723                )));
724            }
725
726            let bias_data = b.to_vec()?;
727
728            let mut output_data = output.to_vec()?;
729
730            for n in 0..batch_size {
731                #[allow(clippy::needless_range_loop)]
732                for c in 0..in_channels {
733                    for oh in 0..output_height {
734                        for ow in 0..output_width {
735                            let idx = n * in_channels * output_height * output_width
736                                + c * output_height * output_width
737                                + oh * output_width
738                                + ow;
739                            output_data[idx] = output_data[idx] + bias_data[c];
740                        }
741                    }
742                }
743            }
744
745            // Create new output tensor with bias added
746            output = Tensor::from_data(
747                output_data,
748                vec![batch_size, in_channels, output_height, output_width],
749                self.device(),
750            )?;
751        }
752
753        // Track operation for autograd
754        if self.requires_grad
755            || weight.requires_grad
756            || (bias.is_some() && bias.expect("bias checked with is_some").requires_grad)
757        {
758            use std::sync::Arc;
759            output.requires_grad = true;
760            output.operation = crate::Operation::Custom(
761                "depthwise_conv2d".to_string(),
762                vec![
763                    Arc::downgrade(&Arc::new(self.clone())),
764                    Arc::downgrade(&Arc::new(weight.clone())),
765                ],
766            );
767        }
768
769        Ok(output)
770    }
771
772    /// Separable 2D convolution operation
773    /// Factorized into depthwise convolution followed by pointwise (1x1) convolution
774    pub fn separable_conv2d(
775        &self,
776        depthwise_weight: &Self,
777        pointwise_weight: &Self,
778        bias: Option<&Self>,
779        stride: (usize, usize),
780        padding: (usize, usize),
781        dilation: (usize, usize),
782    ) -> Result<Self> {
783        // Step 1: Depthwise convolution
784        let depthwise_output = self.depthwise_conv2d(
785            depthwise_weight,
786            None, // No bias in depthwise step
787            stride,
788            padding,
789            dilation,
790        )?;
791
792        // Step 2: Pointwise (1x1) convolution
793        let output = depthwise_output.conv2d(
794            pointwise_weight,
795            bias,
796            (1, 1), // stride = 1 for pointwise
797            (0, 0), // padding = 0 for pointwise
798            (1, 1), // dilation = 1 for pointwise
799            1,      // groups = 1 for pointwise
800        )?;
801
802        // Track operation for autograd
803        if self.requires_grad
804            || depthwise_weight.requires_grad
805            || pointwise_weight.requires_grad
806            || (bias.is_some() && bias.expect("bias checked with is_some").requires_grad)
807        {
808            use std::sync::Arc;
809            let mut tracked_output = output;
810            tracked_output.requires_grad = true;
811            tracked_output.operation = crate::Operation::Custom(
812                "separable_conv2d".to_string(),
813                vec![
814                    Arc::downgrade(&Arc::new(self.clone())),
815                    Arc::downgrade(&Arc::new(depthwise_weight.clone())),
816                    Arc::downgrade(&Arc::new(pointwise_weight.clone())),
817                ],
818            );
819            Ok(tracked_output)
820        } else {
821            Ok(output)
822        }
823    }
824
    /// Transposed (deconvolution) 2D convolution operation
    ///
    /// The adjoint of a regular `conv2d`: every input element is scattered
    /// through the kernel into the output, upsampling the spatial dimensions.
    ///
    /// Shapes:
    /// * input:  `(N, C_in, H, W)`
    /// * weight: `(C_in, C_out / groups, kernel_h, kernel_w)`
    /// * bias:   `[C_out]` if provided
    /// * output: `(N, C_out, H_out, W_out)` with
    ///   `H_out = (H - 1) * stride.0 - 2 * padding.0 + (kernel_h - 1) * dilation.0 + 1 + output_padding.0`
    ///   (and analogously for the width).
    ///
    /// # Errors
    /// Returns `InvalidArgument` if either tensor is not 4D, channel counts are
    /// not divisible by `groups`, the weight's input-channel dim disagrees with
    /// the input, or the bias shape is not `[C_out]`.
    ///
    /// NOTE(review): the output-size computation uses unsigned arithmetic, so a
    /// padding large enough to make it negative underflows (panics in debug
    /// builds) — confirm callers validate padding upstream.
    #[allow(clippy::too_many_arguments)]
    pub fn conv_transpose2d(
        &self,
        weight: &Self,
        bias: Option<&Self>,
        stride: (usize, usize),
        padding: (usize, usize),
        output_padding: (usize, usize),
        dilation: (usize, usize),
        groups: usize,
    ) -> Result<Self> {
        // Input shape: (N, C_in, H, W)
        // Weight shape: (C_in, C_out/groups, kernel_h, kernel_w)
        // Output shape: (N, C_out, H_out, W_out)

        let input_shape_obj = self.shape();
        let input_shape = input_shape_obj.dims();
        let weight_shape_obj = weight.shape();
        let weight_shape = weight_shape_obj.dims();

        if input_shape.len() != 4 {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 4D input tensor for conv_transpose2d, got {}D",
                input_shape.len()
            )));
        }

        if weight_shape.len() != 4 {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 4D weight tensor for conv_transpose2d, got {}D",
                weight_shape.len()
            )));
        }

        let batch_size = input_shape[0];
        let in_channels = input_shape[1];
        let input_height = input_shape[2];
        let input_width = input_shape[3];

        // Unlike conv2d, the weight's SECOND dim holds the per-group output
        // channels, so the total output channel count is dim1 * groups.
        let out_channels = weight_shape[1] * groups;
        let kernel_height = weight_shape[2];
        let kernel_width = weight_shape[3];

        // Check groups
        if in_channels % groups != 0 || out_channels % groups != 0 {
            return Err(TorshError::InvalidArgument(
                "in_channels and out_channels must be divisible by groups".to_string(),
            ));
        }

        if weight_shape[0] != in_channels {
            return Err(TorshError::InvalidArgument(format!(
                "Weight tensor has wrong number of input channels: expected {}, got {}",
                in_channels, weight_shape[0]
            )));
        }

        // Calculate output dimensions (dilation stretches the kernel extent).
        let effective_kernel_h = (kernel_height - 1) * dilation.0 + 1;
        let effective_kernel_w = (kernel_width - 1) * dilation.1 + 1;
        let output_height =
            (input_height - 1) * stride.0 - 2 * padding.0 + effective_kernel_h + output_padding.0;
        let output_width =
            (input_width - 1) * stride.1 - 2 * padding.1 + effective_kernel_w + output_padding.1;

        // Initialize output
        let mut output_data = vec![
            <T as TensorElement>::zero();
            batch_size * out_channels * output_height * output_width
        ];

        let self_data = self.to_vec()?;
        let weight_data = weight.to_vec()?;

        // Perform transposed convolution: scatter each input element through
        // the kernel into the output (the adjoint of conv2d's gather).
        for n in 0..batch_size {
            for g in 0..groups {
                // Channel ranges owned by this group.
                let in_ch_start = g * (in_channels / groups);
                let in_ch_end = (g + 1) * (in_channels / groups);
                let out_ch_start = g * (out_channels / groups);
                let out_ch_end = (g + 1) * (out_channels / groups);

                for ic in in_ch_start..in_ch_end {
                    for ih in 0..input_height {
                        for iw in 0..input_width {
                            let input_val = self_data[n * in_channels * input_height * input_width
                                + ic * input_height * input_width
                                + ih * input_width
                                + iw];

                            for oc in out_ch_start..out_ch_end {
                                // Weight is indexed by the group-relative output channel.
                                let oc_rel = oc - out_ch_start;
                                for kh in 0..kernel_height {
                                    for kw in 0..kernel_width {
                                        // Output position this tap lands on, before
                                        // subtracting the padding offset.
                                        let oh = ih * stride.0 + kh * dilation.0;
                                        let ow = iw * stride.1 + kw * dilation.1;

                                        if oh >= padding.0 && ow >= padding.1 {
                                            let oh_final = oh - padding.0;
                                            let ow_final = ow - padding.1;

                                            // Taps falling outside the output are dropped.
                                            if oh_final < output_height && ow_final < output_width {
                                                let weight_idx = ic
                                                    * (out_channels / groups)
                                                    * kernel_height
                                                    * kernel_width
                                                    + oc_rel * kernel_height * kernel_width
                                                    + kh * kernel_width
                                                    + kw;
                                                let output_idx =
                                                    n * out_channels * output_height * output_width
                                                        + oc * output_height * output_width
                                                        + oh_final * output_width
                                                        + ow_final;

                                                output_data[output_idx] = output_data[output_idx]
                                                    + input_val * weight_data[weight_idx];
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // Create output tensor
        let mut output = Tensor::from_data(
            output_data,
            vec![batch_size, out_channels, output_height, output_width],
            self.device(),
        )?;

        // Add bias if provided
        if let Some(b) = bias {
            if b.shape().dims() != [out_channels] {
                return Err(TorshError::InvalidArgument(format!(
                    "Bias must have shape [{}], got {:?}",
                    out_channels,
                    b.shape().dims()
                )));
            }

            let bias_data = b.to_vec()?;

            let mut output_data = output.to_vec()?;

            // Broadcast the per-channel bias over batch and spatial dims.
            for n in 0..batch_size {
                #[allow(clippy::needless_range_loop)]
                for oc in 0..out_channels {
                    for oh in 0..output_height {
                        for ow in 0..output_width {
                            let idx = n * out_channels * output_height * output_width
                                + oc * output_height * output_width
                                + oh * output_width
                                + ow;
                            output_data[idx] = output_data[idx] + bias_data[oc];
                        }
                    }
                }
            }

            // Create new output tensor with bias added
            output = Tensor::from_data(
                output_data,
                vec![batch_size, out_channels, output_height, output_width],
                self.device(),
            )?;
        }

        // Track operation for autograd
        // NOTE(review): the weak refs below point at Arcs that are dropped
        // immediately, so they dangle right away — confirm this matches how
        // Operation::Custom inputs are consumed by the autograd engine.
        if self.requires_grad
            || weight.requires_grad
            || (bias.is_some() && bias.expect("bias checked with is_some").requires_grad)
        {
            use std::sync::Arc;
            output.requires_grad = true;
            output.operation = crate::Operation::Custom(
                "conv_transpose2d".to_string(),
                vec![
                    Arc::downgrade(&Arc::new(self.clone())),
                    Arc::downgrade(&Arc::new(weight.clone())),
                ],
            );
        }

        Ok(output)
    }
1016
1017    /// 1D cross-correlation operation
1018    /// Computes the cross-correlation between two 1D signals
1019    #[allow(clippy::needless_range_loop)]
1020    pub fn xcorr1d(&self, other: &Self, mode: CorrelationMode) -> Result<Self> {
1021        let self_shape_ref = self.shape();
1022        let other_shape_ref = other.shape();
1023        let self_shape = self_shape_ref.dims();
1024        let other_shape = other_shape_ref.dims();
1025
1026        if self_shape.len() != 1 || other_shape.len() != 1 {
1027            return Err(TorshError::InvalidArgument(
1028                "xcorr1d requires 1D tensors".to_string(),
1029            ));
1030        }
1031
1032        let n = self_shape[0];
1033        let m = other_shape[0];
1034
1035        let (output_size, lag_start) = match mode {
1036            CorrelationMode::Full => (n + m - 1, -(m as i32 - 1)),
1037            CorrelationMode::Valid => {
1038                if n < m || m < n {
1039                    return Err(TorshError::InvalidArgument(
1040                        "Valid mode requires both tensors to have the same size or one to be smaller".to_string(),
1041                    ));
1042                }
1043                (std::cmp::max(n, m) - std::cmp::min(n, m) + 1, 0)
1044            }
1045            CorrelationMode::Same => (n, -((m as i32 - 1) / 2)),
1046        };
1047
1048        let mut output_data = vec![<T as TensorElement>::zero(); output_size];
1049        let self_data = self.to_vec()?;
1050        let other_data = other.to_vec()?;
1051
1052        // Compute cross-correlation: (f ★ g)[lag] = Σ_i f[i] * g[i - lag]
1053        for i in 0..output_size {
1054            let mut sum = <T as TensorElement>::zero();
1055            let lag = lag_start + i as i32;
1056
1057            for j in 0..n {
1058                let other_idx = j as i32 - lag;
1059                if other_idx >= 0 && (other_idx as usize) < m {
1060                    sum = sum + self_data[j] * other_data[other_idx as usize];
1061                }
1062            }
1063            output_data[i] = sum;
1064        }
1065
1066        let output = Tensor::from_data(output_data, vec![output_size], self.device())?;
1067
1068        Ok(output)
1069    }
1070
1071    /// 1D auto-correlation operation
1072    /// Computes the auto-correlation of a 1D signal
1073    pub fn autocorr1d(&self, max_lag: Option<usize>) -> Result<Self> {
1074        let shape_ref = self.shape();
1075        let shape = shape_ref.dims();
1076        if shape.len() != 1 {
1077            return Err(TorshError::InvalidArgument(
1078                "autocorr1d requires 1D tensor".to_string(),
1079            ));
1080        }
1081
1082        let n = shape[0];
1083        let max_lag = max_lag.unwrap_or(n - 1).min(n - 1);
1084
1085        let self_data = self.to_vec()?;
1086        let mut output_data = Vec::with_capacity(max_lag + 1);
1087
1088        // Directly compute auto-correlation: R[k] = Σ_n x[n] * x[n-k]
1089        for lag in 0..=max_lag {
1090            let mut sum = <T as TensorElement>::zero();
1091
1092            for i in lag..n {
1093                sum = sum + self_data[i] * self_data[i - lag];
1094            }
1095
1096            output_data.push(sum);
1097        }
1098
1099        let output = Tensor::from_data(output_data, vec![max_lag + 1], self.device())?;
1100        Ok(output)
1101    }
1102
1103    /// 2D cross-correlation operation
1104    /// Computes the 2D cross-correlation between two signals
1105    pub fn xcorr2d(&self, other: &Self, mode: CorrelationMode) -> Result<Self> {
1106        let self_shape_ref = self.shape();
1107        let other_shape_ref = other.shape();
1108        let self_shape = self_shape_ref.dims();
1109        let other_shape = other_shape_ref.dims();
1110
1111        if self_shape.len() != 2 || other_shape.len() != 2 {
1112            return Err(TorshError::InvalidArgument(
1113                "xcorr2d requires 2D tensors".to_string(),
1114            ));
1115        }
1116
1117        let (h1, w1) = (self_shape[0], self_shape[1]);
1118        let (h2, w2) = (other_shape[0], other_shape[1]);
1119
1120        let (out_h, out_w, start_h, start_w) = match mode {
1121            CorrelationMode::Full => (h1 + h2 - 1, w1 + w2 - 1, 0, 0),
1122            CorrelationMode::Valid => {
1123                if h1 < h2 || w1 < w2 {
1124                    return Err(TorshError::InvalidArgument(
1125                        "Valid mode requires first tensor to be larger than or equal to second"
1126                            .to_string(),
1127                    ));
1128                }
1129                (h1 - h2 + 1, w1 - w2 + 1, h2 - 1, w2 - 1)
1130            }
1131            CorrelationMode::Same => (h1, w1, (h2 - 1) / 2, (w2 - 1) / 2),
1132        };
1133
1134        let mut output_data = vec![<T as TensorElement>::zero(); out_h * out_w];
1135        let self_data = self.to_vec()?;
1136        let other_data = other.to_vec()?;
1137
1138        // Compute 2D cross-correlation
1139        for i in 0..out_h {
1140            for j in 0..out_w {
1141                let mut sum = <T as TensorElement>::zero();
1142                let actual_i = i + start_h;
1143                let actual_j = j + start_w;
1144
1145                for ki in 0..h2 {
1146                    for kj in 0..w2 {
1147                        let src_i = actual_i as i32 - ki as i32;
1148                        let src_j = actual_j as i32 - kj as i32;
1149
1150                        if src_i >= 0
1151                            && (src_i as usize) < h1
1152                            && src_j >= 0
1153                            && (src_j as usize) < w1
1154                        {
1155                            let self_idx = src_i as usize * w1 + src_j as usize;
1156                            let other_idx = ki * w2 + kj;
1157                            sum = sum + self_data[self_idx] * other_data[other_idx];
1158                        }
1159                    }
1160                }
1161                output_data[i * out_w + j] = sum;
1162            }
1163        }
1164
1165        let output = Tensor::from_data(output_data, vec![out_h, out_w], self.device())?;
1166        Ok(output)
1167    }
1168
1169    /// 1D median filter
1170    /// Applies a median filter with the specified window size
1171    pub fn median_filter1d(&self, window_size: usize) -> Result<Self> {
1172        let shape_ref = self.shape();
1173        let shape = shape_ref.dims();
1174        if shape.len() != 1 {
1175            return Err(TorshError::InvalidArgument(
1176                "median_filter1d requires 1D tensor".to_string(),
1177            ));
1178        }
1179
1180        if window_size == 0 || window_size % 2 == 0 {
1181            return Err(TorshError::InvalidArgument(
1182                "Window size must be odd and greater than 0".to_string(),
1183            ));
1184        }
1185
1186        let n = shape[0];
1187        let half_window = window_size / 2;
1188        let mut output_data = Vec::with_capacity(n);
1189        let self_data = self.to_vec()?;
1190
1191        for i in 0..n {
1192            let mut window_values = Vec::new();
1193
1194            // Collect values in the window (with padding by repeating edge values)
1195            for j in 0..window_size {
1196                let idx = i as i32 + j as i32 - half_window as i32;
1197                let actual_idx = if idx < 0 {
1198                    0
1199                } else if idx >= n as i32 {
1200                    n - 1
1201                } else {
1202                    idx as usize
1203                };
1204                window_values.push(self_data[actual_idx]);
1205            }
1206
1207            // Sort to find median
1208            window_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1209            output_data.push(window_values[half_window]);
1210        }
1211
1212        let output = Tensor::from_data(output_data, vec![n], self.device())?;
1213        Ok(output)
1214    }
1215
1216    /// 2D median filter
1217    /// Applies a 2D median filter with the specified window size
1218    pub fn median_filter2d(&self, window_size: (usize, usize)) -> Result<Self> {
1219        let shape_ref = self.shape();
1220        let shape = shape_ref.dims();
1221        if shape.len() != 2 {
1222            return Err(TorshError::InvalidArgument(
1223                "median_filter2d requires 2D tensor".to_string(),
1224            ));
1225        }
1226
1227        let (window_h, window_w) = window_size;
1228        if window_h == 0 || window_w == 0 || window_h % 2 == 0 || window_w % 2 == 0 {
1229            return Err(TorshError::InvalidArgument(
1230                "Window dimensions must be odd and greater than 0".to_string(),
1231            ));
1232        }
1233
1234        let (h, w) = (shape[0], shape[1]);
1235        let half_h = window_h / 2;
1236        let half_w = window_w / 2;
1237        let mut output_data = Vec::with_capacity(h * w);
1238        let self_data = self.to_vec()?;
1239
1240        for i in 0..h {
1241            for j in 0..w {
1242                let mut window_values = Vec::new();
1243
1244                // Collect values in the 2D window
1245                for di in 0..window_h {
1246                    for dj in 0..window_w {
1247                        let row = i as i32 + di as i32 - half_h as i32;
1248                        let col = j as i32 + dj as i32 - half_w as i32;
1249
1250                        // Handle boundaries by clamping
1251                        let actual_row = row.max(0).min(h as i32 - 1) as usize;
1252                        let actual_col = col.max(0).min(w as i32 - 1) as usize;
1253
1254                        window_values.push(self_data[actual_row * w + actual_col]);
1255                    }
1256                }
1257
1258                // Sort to find median
1259                window_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1260                output_data.push(window_values[window_values.len() / 2]);
1261            }
1262        }
1263
1264        let output = Tensor::from_data(output_data, vec![h, w], self.device())?;
1265        Ok(output)
1266    }
1267
1268    /// 1D Gaussian filter
1269    /// Applies a Gaussian filter with specified sigma (standard deviation)
1270    pub fn gaussian_filter1d(&self, sigma: f32, kernel_size: Option<usize>) -> Result<Self> {
1271        let tensor_shape = self.shape();
1272        let shape = tensor_shape.dims();
1273        if shape.len() != 1 {
1274            return Err(TorshError::InvalidArgument(
1275                "gaussian_filter1d requires 1D tensor".to_string(),
1276            ));
1277        }
1278
1279        if sigma <= 0.0 {
1280            return Err(TorshError::InvalidArgument(
1281                "Sigma must be positive".to_string(),
1282            ));
1283        }
1284
1285        // Calculate kernel size if not provided (6 sigma rule)
1286        let kernel_size = kernel_size.unwrap_or(((6.0 * sigma) as usize).max(3));
1287        let kernel_size = if kernel_size % 2 == 0 {
1288            kernel_size + 1
1289        } else {
1290            kernel_size
1291        };
1292
1293        // Generate Gaussian kernel
1294        let half_size = kernel_size / 2;
1295        let mut kernel = Vec::with_capacity(kernel_size);
1296        let mut sum = 0.0f32;
1297
1298        for i in 0..kernel_size {
1299            let x = i as f32 - half_size as f32;
1300            let value = (-0.5 * (x / sigma).powi(2)).exp();
1301            kernel.push(value);
1302            sum += value;
1303        }
1304
1305        // Normalize kernel
1306        for value in &mut kernel {
1307            *value /= sum;
1308        }
1309
1310        // Create kernel tensor
1311        let kernel_data: Vec<T> = kernel
1312            .into_iter()
1313            .map(|v| {
1314                T::from(v as f64)
1315                    .unwrap_or_else(|| T::from(0.0).expect("numeric conversion should succeed"))
1316            })
1317            .collect();
1318        let kernel_tensor = Tensor::from_data(kernel_data, vec![kernel_size], self.device())?;
1319
1320        // Apply convolution (which is equivalent to correlation for symmetric kernels)
1321        self.xcorr1d(&kernel_tensor, CorrelationMode::Same)
1322    }
1323
1324    /// 2D Gaussian filter
1325    /// Applies a 2D Gaussian filter with specified sigma values
1326    pub fn gaussian_filter2d(
1327        &self,
1328        sigma: (f32, f32),
1329        kernel_size: Option<(usize, usize)>,
1330    ) -> Result<Self> {
1331        let tensor_shape = self.shape();
1332        let shape = tensor_shape.dims();
1333        if shape.len() != 2 {
1334            return Err(TorshError::InvalidArgument(
1335                "gaussian_filter2d requires 2D tensor".to_string(),
1336            ));
1337        }
1338
1339        let (sigma_x, sigma_y) = sigma;
1340        if sigma_x <= 0.0 || sigma_y <= 0.0 {
1341            return Err(TorshError::InvalidArgument(
1342                "Sigma values must be positive".to_string(),
1343            ));
1344        }
1345
1346        // Calculate kernel sizes if not provided
1347        let (kernel_h, kernel_w) = kernel_size.unwrap_or((
1348            ((6.0 * sigma_y) as usize).max(3),
1349            ((6.0 * sigma_x) as usize).max(3),
1350        ));
1351        let kernel_h = if kernel_h % 2 == 0 {
1352            kernel_h + 1
1353        } else {
1354            kernel_h
1355        };
1356        let kernel_w = if kernel_w % 2 == 0 {
1357            kernel_w + 1
1358        } else {
1359            kernel_w
1360        };
1361
1362        // Generate 2D Gaussian kernel
1363        let half_h = kernel_h / 2;
1364        let half_w = kernel_w / 2;
1365        let mut kernel = Vec::with_capacity(kernel_h * kernel_w);
1366        let mut sum = 0.0f32;
1367
1368        for i in 0..kernel_h {
1369            for j in 0..kernel_w {
1370                let y = i as f32 - half_h as f32;
1371                let x = j as f32 - half_w as f32;
1372                let value = (-0.5 * ((x / sigma_x).powi(2) + (y / sigma_y).powi(2))).exp();
1373                kernel.push(value);
1374                sum += value;
1375            }
1376        }
1377
1378        // Normalize kernel
1379        for value in &mut kernel {
1380            *value /= sum;
1381        }
1382
1383        // Create kernel tensor
1384        let kernel_data: Vec<T> = kernel
1385            .into_iter()
1386            .map(|v| {
1387                T::from(v as f64)
1388                    .unwrap_or_else(|| T::from(0.0).expect("numeric conversion should succeed"))
1389            })
1390            .collect();
1391        let kernel_tensor =
1392            Tensor::from_data(kernel_data, vec![kernel_h, kernel_w], self.device())?;
1393
1394        // Apply 2D correlation
1395        self.xcorr2d(&kernel_tensor, CorrelationMode::Same)
1396    }
1397}
1398
/// Correlation modes for signal processing operations
///
/// Selects how much of the sliding-window overlap the correlation routines
/// (`xcorr1d`, `xcorr2d`) return. `Eq` and `Hash` are derived (the enum is
/// fieldless) so the mode can be used as a map/set key and in exact
/// comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CorrelationMode {
    /// Full correlation output
    Full,
    /// Valid correlation output (no padding)
    Valid,
    /// Same size as input (with padding)
    Same,
}