Skip to main content

svod_tensor/
matmul.rs

1//! Matrix multiplication and linear transformations.
2//!
3//! This module provides dot product and matrix multiplication operations
4//! following Tinygrad's implementation strategy.
5
6use std::iter;
7
8use bon::bon;
9use snafu::{ResultExt, ensure};
10use svod_dtype::DType;
11use svod_ir::{SInt, shape::Shape};
12
13use crate::{Result, Tensor, UOpSnafu, error::*};
14
15impl Tensor {
16    /// Dot product / matrix multiplication.
17    ///
18    /// Core method following Tinygrad's API:
19    /// - 1D @ 1D: dot product (scalar)
20    /// - 2D @ 2D: matrix multiplication
21    /// - 1D @ 2D: vector @ matrix
22    /// - 2D @ 1D: matrix @ vector
23    /// - 3D+: batched matmul (batch dims broadcast)
24    ///
25    /// # Arguments
26    /// * `other` - Right-hand tensor
27    ///
28    /// # Examples
29    /// ```ignore
30    /// // Vector dot product
31    /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
32    /// let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0]);
33    /// let result = a.dot(&b)?; // scalar: 32.0
34    ///
35    /// // Matrix multiplication
36    /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]).try_reshape(&[2, 2])?;
37    /// let b = Tensor::from_slice(&[5.0f32, 6.0, 7.0, 8.0]).try_reshape(&[2, 2])?;
38    /// let result = a.dot(&b)?; // [2, 2]
39    /// ```
40    pub fn dot(&self, other: &Tensor) -> Result<Tensor> {
41        self.matmul_with().other(other).call()
42    }
43
44    /// Matrix multiplication (alias for dot).
45    ///
46    /// Matches PyTorch API. Equivalent to `self.dot(other)`.
47    ///
48    /// # Examples
49    /// ```ignore
50    /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]).try_reshape(&[2, 2])?;
51    /// let b = Tensor::from_slice(&[5.0f32, 6.0, 7.0, 8.0]).try_reshape(&[2, 2])?;
52    /// let result = a.matmul(&b)?;
53    /// ```
54    pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
55        self.matmul_with().other(other).call()
56    }
57}
58
59/// Build matmul broadcast shape by inserting broadcast dimensions.
60///
61/// Constructs: shape[..prefix_len] + [1; broadcast_dims] + shape[tail_start..]
62fn build_matmul_broadcast_shape(shape: &Shape, prefix_len: usize, broadcast_dims: usize, tail_start: usize) -> Shape {
63    shape[..prefix_len]
64        .iter()
65        .cloned()
66        .chain(iter::repeat_n(SInt::Const(1), broadcast_dims))
67        .chain(shape[tail_start..].iter().cloned())
68        .collect()
69}
70
71#[bon]
72impl Tensor {
73    /// Matrix multiplication with optional dtype.
74    ///
75    /// # Examples
76    /// ```ignore
77    /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]).try_reshape(&[2, 2])?;
78    /// let b = Tensor::from_slice(&[5.0f32, 6.0, 7.0, 8.0]).try_reshape(&[2, 2])?;
79    /// let result = a.matmul_with(&b).dtype(DType::Float64).call()?;
80    /// ```
81    #[builder]
82    pub fn matmul_with(&self, other: &Tensor, dtype: Option<DType>) -> Result<Tensor> {
83        // Step 1: Check dimensions
84        let (dx, dw) = (self.ndim()?, other.ndim()?);
85        ensure!(dx != 0 && dw != 0, DotDimensionSnafu { lhs_dims: dx, rhs_dims: dw });
86
87        let x_shape = self.shape()?;
88        let w_shape = other.shape()?;
89
90        // Step 2: Determine contraction axis and validate
91        let axis_w = -(dw.min(2) as isize);
92        ensure!(self.dim(-1)? == other.dim(axis_w)?, DotShapeMismatchSnafu { lhs_shape: x_shape, rhs_shape: w_shape });
93
94        // Step 3: Reshape for broadcasting
95        let broadcast_dims = (dx - 1).min(dw - 1).min(1);
96
97        // Reshape x: [..., K] → [..., 1, K]
98        let x_new_shape = build_matmul_broadcast_shape(&x_shape, dx - 1, broadcast_dims, dx - 1);
99        let x_reshaped = self.uop().try_reshape(&x_new_shape).map(Self::new).context(UOpSnafu)?;
100
101        // Reshape w: [..., K, N] → [..., 1, K, N]
102        let axis_w_pos = Tensor::normalize_axis(axis_w, dw)?;
103        let w_new_shape = build_matmul_broadcast_shape(&w_shape, dw.saturating_sub(2), broadcast_dims, axis_w_pos);
104        let w_reshaped = other.uop().try_reshape(&w_new_shape).map(Self::new).context(UOpSnafu)?;
105
106        // Step 4: Transpose, multiply, and sum
107        let product = x_reshaped.try_mul(&w_reshaped.try_transpose(-1, axis_w)?)?;
108
109        if let Some(dt) = dtype { product.sum_with().axes(-1).dtype(dt).call() } else { product.sum(-1) }
110    }
111
112    /// General Matrix Multiplication: alpha * A @ B + beta * C
113    #[builder]
114    pub fn gemm(
115        &self,
116        b: &Tensor,
117        #[builder(default = 1.0)] alpha: f32,
118        #[builder(default = 1.0)] beta: f32,
119        #[builder(default = false)] trans_a: bool,
120        #[builder(default = false)] trans_b: bool,
121        c: Option<&Tensor>,
122    ) -> Result<Tensor> {
123        let a = if trans_a { self.try_transpose(0, 1)? } else { self.clone() };
124        let b = if trans_b { b.try_transpose(0, 1)? } else { b.clone() };
125        let mut result = a.matmul(&b)?;
126        if alpha != 1.0 {
127            result = result.try_mul(&Tensor::from_slice([alpha]))?;
128        }
129        if let Some(c) = c {
130            let c = if beta != 1.0 { c.try_mul(&Tensor::from_slice([beta]))? } else { c.clone() };
131            result = result.try_add(&c)?;
132        }
133        Ok(result)
134    }
135
136    /// Linear transformation: `self @ weight.T + bias`.
137    ///
138    /// Common operation in neural networks (fully connected layers).
139    /// Follows PyTorch convention where weight has shape `[out_features, in_features]`
140    /// and is transposed before multiplication.
141    ///
142    /// # Arguments
143    /// * `weight` - Weight matrix (shape: `[out_features, in_features]`)
144    /// * `bias` - Optional bias vector (shape: `[out_features]`)
145    ///
146    /// # Shape Requirements
147    /// - self: `[..., in_features]`
148    /// - weight: `[out_features, in_features]`
149    /// - bias: `[out_features]` or None
150    /// - result: `[..., out_features]`
151    ///
152    /// # Examples
153    /// ```ignore
154    /// let input = Tensor::from_slice(&[1.0f32, 2.0, 3.0]).try_reshape(&[1, 3])?;
155    /// let weight = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).try_reshape(&[2, 3])?;
156    /// let bias = Tensor::from_slice(&[0.1f32, 0.2f32]);
157    /// let result = input.linear().weight(&weight).bias(&bias).call()?;
158    /// // result shape: [1, 2]
159    /// ```
160    #[builder]
161    pub fn linear(&self, weight: &Tensor, bias: Option<&Tensor>, dtype: Option<DType>) -> Result<Tensor> {
162        let weight_shape = weight.shape()?;
163
164        // For 1D weight, use element-wise multiply (broadcast)
165        let result = if weight_shape.len() == 1 {
166            if let Some(dt) = dtype {
167                let casted = self.cast(dt)?;
168                casted.try_mul(weight)?
169            } else {
170                self.try_mul(weight)?
171            }
172        } else {
173            // For 2D+ weight, transpose it first (PyTorch convention)
174            // PyTorch Linear layer: x @ weight.T
175            let weight_t = weight.try_transpose(-1, -2)?;
176            self.matmul_with().other(&weight_t).maybe_dtype(dtype).call()?
177        };
178
179        // Add bias if provided
180        if let Some(bias_tensor) = bias {
181            let result_shape = result.shape()?;
182            let bias_broadcasted = bias_tensor.broadcast_to(&result_shape)?;
183            result.try_add(&bias_broadcasted)
184        } else {
185            Ok(result)
186        }
187    }
188}