1use torsh_core::{Result as TorshResult, TorshError};
12use torsh_tensor::{
13 creation::{ones, zeros},
14 Tensor,
15};
16
/// Strategy for generating values outside the original tensor extent when
/// padding (PyTorch-style mode names).
///
/// NOTE(review): [`pad`] currently materializes only the padded *shape* for
/// the non-constant modes (zero-filled placeholders) — confirm before relying
/// on the documented semantics below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PaddingMode {
    /// Pad with a constant value (supplied separately to [`pad`]).
    Constant,
    /// Mirror the tensor across its edge (edge element not repeated).
    Reflect,
    /// Repeat the edge element outward.
    Replicate,
    /// Wrap around to the opposite edge.
    Circular,
}
29
30pub fn pad(input: &Tensor, pad: &[usize], mode: PaddingMode, value: f32) -> TorshResult<Tensor> {
43 let input_shape_binding = input.shape();
44 let input_shape = input_shape_binding.dims();
45 let ndim = input_shape.len();
46
47 if pad.len() % 2 != 0 {
48 return Err(TorshError::invalid_argument_with_context(
49 "Padding specification must have even length",
50 "pad",
51 ));
52 }
53
54 if pad.len() / 2 > ndim {
55 return Err(TorshError::invalid_argument_with_context(
56 "Padding specification exceeds tensor dimensions",
57 "pad",
58 ));
59 }
60
61 let mut output_shape = input_shape.to_vec();
63 let pad_pairs = pad.len() / 2;
64
65 for i in 0..pad_pairs {
66 let dim_idx = ndim - 1 - i; let pad_left = pad[2 * i];
68 let pad_right = pad[2 * i + 1];
69 output_shape[dim_idx] += pad_left + pad_right;
70 }
71
72 let output = match mode {
74 PaddingMode::Constant => {
75 let mut result = zeros(&output_shape)?;
76 if value != 0.0 {
77 result = result.add_scalar(value)?;
78 }
79
80 let input_volume: usize = input_shape.iter().product();
84 let output_volume: usize = output_shape.iter().product();
85
86 if input_volume <= output_volume {
87 let _expanded = input.view(&[input_volume as i32])?;
89 let padded_flat = zeros(&[output_volume])?;
90 padded_flat.view(&output_shape.iter().map(|&x| x as i32).collect::<Vec<_>>())?
92 } else {
93 result
94 }
95 }
96
97 PaddingMode::Reflect => {
98 let result = zeros(&output_shape)?;
100 result
102 }
103
104 PaddingMode::Replicate => {
105 let result = zeros(&output_shape)?;
107 result
109 }
110
111 PaddingMode::Circular => {
112 let result = zeros(&output_shape)?;
114 result
116 }
117 };
118
119 Ok(output)
120}
121
122pub fn slice_with_step(
136 input: &Tensor,
137 dim: usize,
138 start: i32,
139 end: Option<i32>,
140 step: usize,
141) -> TorshResult<Tensor> {
142 let shape_binding = input.shape();
143 let shape = shape_binding.dims();
144
145 if dim >= shape.len() {
146 return Err(TorshError::invalid_argument_with_context(
147 "Dimension index out of bounds",
148 "slice_with_step",
149 ));
150 }
151
152 if step == 0 {
153 return Err(TorshError::invalid_argument_with_context(
154 "Step size must be positive",
155 "slice_with_step",
156 ));
157 }
158
159 let dim_size = shape[dim] as i32;
160
161 let norm_start = if start < 0 {
163 (dim_size + start).max(0)
164 } else {
165 start.min(dim_size)
166 };
167
168 let norm_end = if let Some(e) = end {
169 if e < 0 {
170 (dim_size + e).max(0)
171 } else {
172 e.min(dim_size)
173 }
174 } else {
175 dim_size
176 };
177
178 let slice_len = if norm_end > norm_start {
180 ((norm_end - norm_start + step as i32 - 1) / step as i32) as usize
181 } else {
182 0
183 };
184
185 let mut output_shape = shape.to_vec();
187 output_shape[dim] = slice_len;
188
189 let output_data = zeros(&output_shape)?;
192 Ok(output_data)
193}
194
195pub fn boolean_index(input: &Tensor, mask: &Tensor) -> TorshResult<Tensor> {
204 if input.shape().dims() != mask.shape().dims() {
205 return Err(TorshError::invalid_argument_with_context(
206 "Input and mask must have same shape",
207 "boolean_index",
208 ));
209 }
210
211 let mask_data = mask.sum()?.data()?;
219 let true_count = *mask_data.get(0).unwrap_or(&0.0) as usize;
220 let result = zeros(&[true_count])?;
221 Ok(result)
222}
223
224pub fn masked_fill(input: &Tensor, mask: &Tensor, fill_value: f32) -> TorshResult<Tensor> {
234 if input.shape().dims() != mask.shape().dims() {
235 return Err(TorshError::invalid_argument_with_context(
236 "Input and mask must have same shape",
237 "masked_fill",
238 ));
239 }
240
241 let ones_tensor = ones(&mask.shape().dims())?;
243 let inverted_mask = ones_tensor.sub(mask)?;
244 let masked_input = input.mul_op(&inverted_mask)?;
245 let fill_tensor = ones(&input.shape().dims())?.mul_scalar(fill_value)?;
246 let filled_values = fill_tensor.mul_op(mask)?;
247
248 masked_input.add_op(&filled_values)
249}
250
251pub fn where_tensor(condition: &Tensor, input: &Tensor, other: &Tensor) -> TorshResult<Tensor> {
261 if input.shape().dims() != other.shape().dims() {
263 return Err(TorshError::invalid_argument_with_context(
264 "Input and other tensors must have same shape",
265 "where_tensor",
266 ));
267 }
268
269 let ones_tensor = ones(&condition.shape().dims())?;
271 let inverted_condition = ones_tensor.sub(condition)?;
272 let selected_input = condition.mul_op(input)?;
273 let selected_other = inverted_condition.mul_op(other)?;
274
275 selected_input.add_op(&selected_other)
276}
277
278pub fn cat(tensors: &[Tensor], dim: usize) -> TorshResult<Tensor> {
287 if tensors.is_empty() {
288 return Err(TorshError::invalid_argument_with_context(
289 "Cannot concatenate empty list of tensors",
290 "cat",
291 ));
292 }
293
294 let first_shape_binding = tensors[0].shape();
295 let first_shape = first_shape_binding.dims();
296
297 if dim >= first_shape.len() {
298 return Err(TorshError::invalid_argument_with_context(
299 "Concatenation dimension out of bounds",
300 "cat",
301 ));
302 }
303
304 for (i, tensor) in tensors.iter().enumerate().skip(1) {
306 let shape_binding = tensor.shape();
307 let shape = shape_binding.dims();
308 if shape.len() != first_shape.len() {
309 return Err(TorshError::invalid_argument_with_context(
310 &format!("Tensor {} has incompatible number of dimensions", i),
311 "cat",
312 ));
313 }
314
315 for (j, (&s1, &s2)) in first_shape.iter().zip(shape.iter()).enumerate() {
316 if j != dim && s1 != s2 {
317 return Err(TorshError::invalid_argument_with_context(
318 &format!("Tensor {} has incompatible shape at dimension {}", i, j),
319 "cat",
320 ));
321 }
322 }
323 }
324
325 let mut output_shape = first_shape.to_vec();
327 output_shape[dim] = tensors.iter().map(|t| t.shape().dims()[dim]).sum();
328
329 let result = zeros(&output_shape)?;
332 Ok(result)
333}
334
335pub fn split(
345 input: &Tensor,
346 split_size_or_sections: &[usize],
347 dim: usize,
348) -> TorshResult<Vec<Tensor>> {
349 let shape_binding = input.shape();
350 let shape = shape_binding.dims();
351
352 if dim >= shape.len() {
353 return Err(TorshError::invalid_argument_with_context(
354 "Split dimension out of bounds",
355 "split",
356 ));
357 }
358
359 let dim_size = shape[dim];
360
361 let split_points = if split_size_or_sections.len() == 1 {
363 let chunk_size = split_size_or_sections[0];
365 let num_chunks = (dim_size + chunk_size - 1) / chunk_size;
366 (0..num_chunks)
367 .map(|i| chunk_size.min(dim_size - i * chunk_size))
368 .collect()
369 } else {
370 split_size_or_sections.to_vec()
372 };
373
374 let total_size: usize = split_points.iter().sum();
376 if total_size != dim_size {
377 return Err(TorshError::invalid_argument_with_context(
378 "Split sizes do not sum to dimension size",
379 "split",
380 ));
381 }
382
383 let mut results = Vec::new();
385 for &split_size in &split_points {
386 let mut chunk_shape = shape.to_vec();
387 chunk_shape[dim] = split_size;
388 results.push(zeros(&chunk_shape)?);
389 }
390
391 Ok(results)
392}
393
394pub fn reshape(input: &Tensor, shape: &[i32]) -> TorshResult<Tensor> {
403 let input_numel = input.numel();
404 let mut new_shape = shape.to_vec();
405
406 let neg_one_count = shape.iter().filter(|&&x| x == -1).count();
408 if neg_one_count > 1 {
409 return Err(TorshError::invalid_argument_with_context(
410 "Can only infer one dimension (use at most one -1)",
411 "reshape",
412 ));
413 }
414
415 if neg_one_count == 1 {
416 let known_size: i32 = shape.iter().filter(|&&x| x != -1).product();
417 if known_size == 0 {
418 return Err(TorshError::invalid_argument_with_context(
419 "Cannot infer dimension when other dimensions are zero",
420 "reshape",
421 ));
422 }
423
424 let inferred_size = input_numel as i32 / known_size;
425 if inferred_size * known_size != input_numel as i32 {
426 return Err(TorshError::invalid_argument_with_context(
427 "Cannot reshape tensor to requested shape",
428 "reshape",
429 ));
430 }
431
432 for dim in new_shape.iter_mut() {
434 if *dim == -1 {
435 *dim = inferred_size;
436 break;
437 }
438 }
439 }
440
441 let new_numel: i32 = new_shape.iter().product();
443 if new_numel != input_numel as i32 {
444 return Err(TorshError::invalid_argument_with_context(
445 "New shape is not compatible with input shape",
446 "reshape",
447 ));
448 }
449
450 input.view(&new_shape)
451}
452
453pub fn squeeze(input: &Tensor, dim: Option<usize>) -> TorshResult<Tensor> {
462 let shape_binding = input.shape();
463 let shape = shape_binding.dims();
464
465 let new_shape: Vec<i32> = if let Some(d) = dim {
466 if d >= shape.len() {
467 return Err(TorshError::invalid_argument_with_context(
468 "Dimension index out of bounds",
469 "squeeze",
470 ));
471 }
472 if shape[d] != 1 {
473 return Err(TorshError::invalid_argument_with_context(
474 "Cannot squeeze dimension that is not size 1",
475 "squeeze",
476 ));
477 }
478 shape
479 .iter()
480 .enumerate()
481 .filter(|(i, _)| *i != d)
482 .map(|(_, &s)| s as i32)
483 .collect()
484 } else {
485 shape
486 .iter()
487 .filter(|&&s| s != 1)
488 .map(|&s| s as i32)
489 .collect()
490 };
491
492 if new_shape.is_empty() {
493 input.view(&[])
495 } else {
496 input.view(&new_shape)
497 }
498}
499
500pub fn unsqueeze(input: &Tensor, dim: usize) -> TorshResult<Tensor> {
509 let shape_binding = input.shape();
510 let shape = shape_binding.dims();
511
512 if dim > shape.len() {
513 return Err(TorshError::invalid_argument_with_context(
514 "Dimension index out of bounds",
515 "unsqueeze",
516 ));
517 }
518
519 let mut new_shape: Vec<i32> = Vec::with_capacity(shape.len() + 1);
520 for (i, &s) in shape.iter().enumerate() {
521 if i == dim {
522 new_shape.push(1);
523 }
524 new_shape.push(s as i32);
525 }
526 if dim == shape.len() {
527 new_shape.push(1);
528 }
529
530 input.view(&new_shape)
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random_ops::randn;

    // pad &[1,1,2,2] pads the last dim by 1+1 and the middle dim by 2+2.
    #[test]
    fn test_pad_constant() {
        let x = randn(&[2, 3, 4], None, None, None).unwrap();
        let out = pad(&x, &[1, 1, 2, 2], PaddingMode::Constant, 0.0).unwrap();
        assert_eq!(out.shape().dims(), &[2, 7, 6]);
    }

    // Elements 1,3,5,7 of a length-10 axis -> 4 rows.
    #[test]
    fn test_slice_with_step() {
        let x = randn(&[10, 5], None, None, None).unwrap();
        let out = slice_with_step(&x, 0, 1, Some(8), 2).unwrap();
        assert_eq!(out.shape().dims()[0], 4);
        assert_eq!(out.shape().dims()[1], 5);
    }

    // An all-zero mask leaves the shape unchanged.
    #[test]
    fn test_masked_fill() {
        let x = randn(&[3, 3], None, None, None).unwrap();
        let mask: Tensor<f32> = zeros(&[3, 3]).unwrap();
        let out = masked_fill(&x, &mask, 99.0).unwrap();
        assert_eq!(out.shape().dims(), x.shape().dims());
    }

    // Three [2,3,4] tensors concatenated on dim 0 -> [6,3,4].
    #[test]
    fn test_cat() {
        let a = randn(&[2, 3, 4], None, None, None).unwrap();
        let b = randn(&[2, 3, 4], None, None, None).unwrap();
        let c = randn(&[2, 3, 4], None, None, None).unwrap();
        let out = cat(&[a, b, c], 0).unwrap();
        assert_eq!(out.shape().dims(), &[6, 3, 4]);
    }

    // Uniform chunk size 2 over a length-6 axis -> three equal chunks.
    #[test]
    fn test_split() {
        let x = randn(&[6, 3, 4], None, None, None).unwrap();
        let parts = split(&x, &[2], 0).unwrap();
        assert_eq!(parts.len(), 3);
        for part in parts {
            assert_eq!(part.shape().dims(), &[2, 3, 4]);
        }
    }

    // -1 infers 24 / 6 = 4.
    #[test]
    fn test_reshape() {
        let x = randn(&[2, 3, 4], None, None, None).unwrap();
        let out = reshape(&x, &[6, -1]).unwrap();
        assert_eq!(out.shape().dims(), &[6, 4]);
    }

    #[test]
    fn test_squeeze_unsqueeze() {
        let x = randn(&[2, 1, 3, 1], None, None, None).unwrap();

        // squeeze(None) drops every size-1 dimension.
        let squeezed = squeeze(&x, None).unwrap();
        assert_eq!(squeezed.shape().dims(), &[2, 3]);

        // unsqueeze re-inserts a unit dimension at index 1.
        let unsqueezed = unsqueeze(&squeezed, 1).unwrap();
        assert_eq!(unsqueezed.shape().dims(), &[2, 1, 3]);
    }

    // An all-ones condition selects `input` everywhere.
    #[test]
    fn test_where_tensor() {
        let cond = ones(&[2, 3]).unwrap();
        let x = randn(&[2, 3], None, None, None).unwrap();
        let y = zeros(&[2, 3]).unwrap();
        let out = where_tensor(&cond, &x, &y).unwrap();
        assert_eq!(out.shape().dims(), &[2, 3]);
    }
}