Skip to main content

svod_tensor/
transformer.rs

1//! Transformer building blocks: embedding, attention, rotary position embeddings.
2
3use crate::Tensor;
4use bon::bon;
5use snafu::ensure;
6use svod_dtype::DType;
7use svod_ir::ConstValue;
8
9use crate::error::FloatDTypeRequiredSnafu;
10
11type Result<T> = crate::Result<T>;
12
13impl Tensor {
14    /// Embedding lookup: `self` is the weight table `[vocab_size, embed_dim]`.
15    /// Returns `self[indices]` with shape `[*indices.shape, embed_dim]`.
16    pub fn embedding(&self, indices: &Tensor) -> Result<Tensor> {
17        let weight_shape = self.shape()?;
18        let embed_dim = weight_shape[1].as_const().expect("embedding weight dim 1 must be concrete") as isize;
19        let idx_shape = indices.shape()?;
20
21        let flat = indices.try_reshape([-1])?;
22        let expanded = flat.try_unsqueeze(-1)?.try_expand([-1, embed_dim])?;
23        let gathered = self.gather(0, &expanded)?;
24
25        let mut out_shape: Vec<isize> =
26            idx_shape.iter().map(|d| d.as_const().expect("embedding index dims must be concrete") as isize).collect();
27        out_shape.push(embed_dim);
28        gathered.try_reshape(&out_shape)
29    }
30
31    /// Apply rotary position embedding rotation.
32    /// `self`: `[..., rot_dim]` tensor to rotate.
33    /// `cos`, `sin`: broadcastable to `self`'s shape `[..., rot_dim/2]`.
34    /// If interleaved: pairs are (even, odd) indices.
35    /// If not interleaved: pairs are (first_half, second_half).
36    pub fn apply_rotary_emb(&self, cos: &Tensor, sin: &Tensor, interleaved: bool) -> Result<Tensor> {
37        let shape = self.shape()?;
38        let last_dim = shape
39            .last()
40            .expect("apply_rotary_emb requires non-scalar input")
41            .as_const()
42            .expect("last dim must be concrete");
43        let half = last_dim / 2;
44
45        let (x1, x2) = if interleaved {
46            let mut rs: Vec<isize> = shape
47                .iter()
48                .take(shape.len() - 1)
49                .map(|d| d.as_const().expect("dims must be concrete") as isize)
50                .collect();
51            rs.push(half as isize);
52            rs.push(2);
53            let r = self.try_reshape(&rs)?;
54            let p = r.split(&[1, 1], -1)?;
55            (p[0].try_squeeze(Some(-1))?, p[1].try_squeeze(Some(-1))?)
56        } else {
57            let p = self.split(&[half, half], -1)?;
58            (p[0].clone(), p[1].clone())
59        };
60
61        let real = x1.try_mul(cos)?.try_sub(&x2.try_mul(sin)?)?;
62        let imag = x1.try_mul(sin)?.try_add(&x2.try_mul(cos)?)?;
63
64        if interleaved {
65            let stacked = Tensor::stack(&[&real, &imag], -1)?;
66            let mut fs: Vec<isize> = shape.iter().map(|d| d.as_const().unwrap() as isize).collect();
67            // Last dim already correct from original shape
68            let _ = fs.last_mut().map(|d| *d = last_dim as isize);
69            stacked.try_reshape(&fs)
70        } else {
71            Tensor::cat(&[&real, &imag], -1)
72        }
73    }
74}
75
76#[bon]
77impl Tensor {
78    /// Scaled dot-product attention.
79    /// `self` (Q): `[B, H, Sq, D]`, `key` (K): `[B, H, Sk, D]`, `value` (V): `[B, H, Sk, Dv]`.
80    /// Returns `[B, H, Sq, Dv]`.
81    #[builder]
82    pub fn scaled_dot_product_attention(
83        &self,
84        key: &Tensor,
85        value: &Tensor,
86        attn_mask: Option<&Tensor>,
87        scale: Option<f64>,
88        #[builder(default)] is_causal: bool,
89        softcap: Option<f64>,
90    ) -> Result<Tensor> {
91        let q_dtype = self.uop().dtype();
92        ensure!(
93            q_dtype.is_float(),
94            FloatDTypeRequiredSnafu { op: "scaled_dot_product_attention", arg: "query", dtype: q_dtype.clone() }
95        );
96        let k_dtype = key.uop().dtype();
97        ensure!(
98            k_dtype.is_float(),
99            FloatDTypeRequiredSnafu { op: "scaled_dot_product_attention", arg: "key", dtype: k_dtype.clone() }
100        );
101        let v_dtype = value.uop().dtype();
102        ensure!(
103            v_dtype.is_float(),
104            FloatDTypeRequiredSnafu { op: "scaled_dot_product_attention", arg: "value", dtype: v_dtype.clone() }
105        );
106
107        let q_shape = self.shape()?;
108        let k_shape = key.shape()?;
109        let head_dim = q_shape[q_shape.len() - 1].as_const().expect("Q head_dim must be concrete");
110        let scale_val = scale.unwrap_or(1.0 / (head_dim as f64).sqrt());
111
112        let scores_dtype = self.uop().dtype();
113
114        // Q @ K^T
115        let kt = key.try_transpose(-1, -2)?;
116        let mut scores = self.matmul(&kt)?;
117
118        // Scale
119        let scale_t = Tensor::const_(scale_val, scores_dtype.clone());
120        scores = scores.try_mul(&scale_t)?;
121
122        // Causal mask
123        if is_causal {
124            let q_len = q_shape[q_shape.len() - 2].as_const().expect("Q seq_len must be concrete");
125            let k_len = k_shape[k_shape.len() - 2].as_const().expect("K seq_len must be concrete");
126            let causal = Tensor::full(&[q_len, k_len], true, DType::Bool)?.tril(0)?;
127            let neg_large = Tensor::const_(ConstValue::min(scores_dtype.base()), scores_dtype.clone());
128            scores = scores.where_(&causal, &neg_large)?;
129        }
130
131        // Attention mask
132        let mut bool_mask: Option<Tensor> = None;
133        if let Some(mask) = attn_mask {
134            let mask_dtype = mask.uop().dtype();
135            if mask_dtype == DType::Bool {
136                // Bool mask: True = mask out, False = keep.
137                let neg_large = Tensor::const_(ConstValue::min(scores_dtype.base()), scores_dtype.clone());
138                let zero = Tensor::const_(ConstValue::zero(scores_dtype.base()), scores_dtype.clone());
139                let additive = neg_large.where_(mask, &zero)?;
140                scores = scores.try_add(&additive)?;
141                bool_mask = Some(mask.clone());
142            } else {
143                // Float additive mask
144                scores = scores.try_add(mask)?;
145            }
146        }
147
148        // Softcap
149        if let Some(cap) = softcap
150            && cap > 0.0
151        {
152            let cap_t = Tensor::const_(cap, scores_dtype.clone());
153            scores = scores.try_div(&cap_t)?.tanh()?.try_mul(&cap_t)?;
154        }
155
156        // Softmax + output
157        let mut attn_weights = scores.softmax(-1isize)?;
158        if let Some(mask) = bool_mask.as_ref() {
159            let zero = Tensor::const_(ConstValue::zero(scores_dtype.base()), scores_dtype);
160            attn_weights = zero.where_(mask, &attn_weights)?;
161        }
162        attn_weights.matmul(value)
163    }
164}