svod_tensor/matmul.rs
1//! Matrix multiplication and linear transformations.
2//!
3//! This module provides dot product and matrix multiplication operations
4//! following Tinygrad's implementation strategy.
5
6use std::iter;
7
8use bon::bon;
9use snafu::{ResultExt, ensure};
10use svod_dtype::DType;
11use svod_ir::{SInt, shape::Shape};
12
13use crate::{Result, Tensor, UOpSnafu, error::*};
14
15impl Tensor {
16 /// Dot product / matrix multiplication.
17 ///
18 /// Core method following Tinygrad's API:
19 /// - 1D @ 1D: dot product (scalar)
20 /// - 2D @ 2D: matrix multiplication
21 /// - 1D @ 2D: vector @ matrix
22 /// - 2D @ 1D: matrix @ vector
23 /// - 3D+: batched matmul (batch dims broadcast)
24 ///
25 /// # Arguments
26 /// * `other` - Right-hand tensor
27 ///
28 /// # Examples
29 /// ```ignore
30 /// // Vector dot product
31 /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
32 /// let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0]);
33 /// let result = a.dot(&b)?; // scalar: 32.0
34 ///
35 /// // Matrix multiplication
36 /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]).try_reshape(&[2, 2])?;
37 /// let b = Tensor::from_slice(&[5.0f32, 6.0, 7.0, 8.0]).try_reshape(&[2, 2])?;
38 /// let result = a.dot(&b)?; // [2, 2]
39 /// ```
40 pub fn dot(&self, other: &Tensor) -> Result<Tensor> {
41 self.matmul_with().other(other).call()
42 }
43
44 /// Matrix multiplication (alias for dot).
45 ///
46 /// Matches PyTorch API. Equivalent to `self.dot(other)`.
47 ///
48 /// # Examples
49 /// ```ignore
50 /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]).try_reshape(&[2, 2])?;
51 /// let b = Tensor::from_slice(&[5.0f32, 6.0, 7.0, 8.0]).try_reshape(&[2, 2])?;
52 /// let result = a.matmul(&b)?;
53 /// ```
54 pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
55 self.matmul_with().other(other).call()
56 }
57}
58
59/// Build matmul broadcast shape by inserting broadcast dimensions.
60///
61/// Constructs: shape[..prefix_len] + [1; broadcast_dims] + shape[tail_start..]
62fn build_matmul_broadcast_shape(shape: &Shape, prefix_len: usize, broadcast_dims: usize, tail_start: usize) -> Shape {
63 shape[..prefix_len]
64 .iter()
65 .cloned()
66 .chain(iter::repeat_n(SInt::Const(1), broadcast_dims))
67 .chain(shape[tail_start..].iter().cloned())
68 .collect()
69}
70
71#[bon]
72impl Tensor {
73 /// Matrix multiplication with optional dtype.
74 ///
75 /// # Examples
76 /// ```ignore
77 /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]).try_reshape(&[2, 2])?;
78 /// let b = Tensor::from_slice(&[5.0f32, 6.0, 7.0, 8.0]).try_reshape(&[2, 2])?;
79 /// let result = a.matmul_with(&b).dtype(DType::Float64).call()?;
80 /// ```
81 #[builder]
82 pub fn matmul_with(&self, other: &Tensor, dtype: Option<DType>) -> Result<Tensor> {
83 // Step 1: Check dimensions
84 let (dx, dw) = (self.ndim()?, other.ndim()?);
85 ensure!(dx != 0 && dw != 0, DotDimensionSnafu { lhs_dims: dx, rhs_dims: dw });
86
87 let x_shape = self.shape()?;
88 let w_shape = other.shape()?;
89
90 // Step 2: Determine contraction axis and validate
91 let axis_w = -(dw.min(2) as isize);
92 ensure!(self.dim(-1)? == other.dim(axis_w)?, DotShapeMismatchSnafu { lhs_shape: x_shape, rhs_shape: w_shape });
93
94 // Step 3: Reshape for broadcasting
95 let broadcast_dims = (dx - 1).min(dw - 1).min(1);
96
97 // Reshape x: [..., K] → [..., 1, K]
98 let x_new_shape = build_matmul_broadcast_shape(&x_shape, dx - 1, broadcast_dims, dx - 1);
99 let x_reshaped = self.uop().try_reshape(&x_new_shape).map(Self::new).context(UOpSnafu)?;
100
101 // Reshape w: [..., K, N] → [..., 1, K, N]
102 let axis_w_pos = Tensor::normalize_axis(axis_w, dw)?;
103 let w_new_shape = build_matmul_broadcast_shape(&w_shape, dw.saturating_sub(2), broadcast_dims, axis_w_pos);
104 let w_reshaped = other.uop().try_reshape(&w_new_shape).map(Self::new).context(UOpSnafu)?;
105
106 // Step 4: Transpose, multiply, and sum
107 let product = x_reshaped.try_mul(&w_reshaped.try_transpose(-1, axis_w)?)?;
108
109 if let Some(dt) = dtype { product.sum_with().axes(-1).dtype(dt).call() } else { product.sum(-1) }
110 }
111
112 /// General Matrix Multiplication: alpha * A @ B + beta * C
113 #[builder]
114 pub fn gemm(
115 &self,
116 b: &Tensor,
117 #[builder(default = 1.0)] alpha: f32,
118 #[builder(default = 1.0)] beta: f32,
119 #[builder(default = false)] trans_a: bool,
120 #[builder(default = false)] trans_b: bool,
121 c: Option<&Tensor>,
122 ) -> Result<Tensor> {
123 let a = if trans_a { self.try_transpose(0, 1)? } else { self.clone() };
124 let b = if trans_b { b.try_transpose(0, 1)? } else { b.clone() };
125 let mut result = a.matmul(&b)?;
126 if alpha != 1.0 {
127 result = result.try_mul(&Tensor::from_slice([alpha]))?;
128 }
129 if let Some(c) = c {
130 let c = if beta != 1.0 { c.try_mul(&Tensor::from_slice([beta]))? } else { c.clone() };
131 result = result.try_add(&c)?;
132 }
133 Ok(result)
134 }
135
136 /// Linear transformation: `self @ weight.T + bias`.
137 ///
138 /// Common operation in neural networks (fully connected layers).
139 /// Follows PyTorch convention where weight has shape `[out_features, in_features]`
140 /// and is transposed before multiplication.
141 ///
142 /// # Arguments
143 /// * `weight` - Weight matrix (shape: `[out_features, in_features]`)
144 /// * `bias` - Optional bias vector (shape: `[out_features]`)
145 ///
146 /// # Shape Requirements
147 /// - self: `[..., in_features]`
148 /// - weight: `[out_features, in_features]`
149 /// - bias: `[out_features]` or None
150 /// - result: `[..., out_features]`
151 ///
152 /// # Examples
153 /// ```ignore
154 /// let input = Tensor::from_slice(&[1.0f32, 2.0, 3.0]).try_reshape(&[1, 3])?;
155 /// let weight = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).try_reshape(&[2, 3])?;
156 /// let bias = Tensor::from_slice(&[0.1f32, 0.2f32]);
157 /// let result = input.linear().weight(&weight).bias(&bias).call()?;
158 /// // result shape: [1, 2]
159 /// ```
160 #[builder]
161 pub fn linear(&self, weight: &Tensor, bias: Option<&Tensor>, dtype: Option<DType>) -> Result<Tensor> {
162 let weight_shape = weight.shape()?;
163
164 // For 1D weight, use element-wise multiply (broadcast)
165 let result = if weight_shape.len() == 1 {
166 if let Some(dt) = dtype {
167 let casted = self.cast(dt)?;
168 casted.try_mul(weight)?
169 } else {
170 self.try_mul(weight)?
171 }
172 } else {
173 // For 2D+ weight, transpose it first (PyTorch convention)
174 // PyTorch Linear layer: x @ weight.T
175 let weight_t = weight.try_transpose(-1, -2)?;
176 self.matmul_with().other(&weight_t).maybe_dtype(dtype).call()?
177 };
178
179 // Add bias if provided
180 if let Some(bias_tensor) = bias {
181 let result_shape = result.shape()?;
182 let bias_broadcasted = bias_tensor.broadcast_to(&result_shape)?;
183 result.try_add(&bias_broadcasted)
184 } else {
185 Ok(result)
186 }
187 }
188}