svod_tensor/
transformer.rs1use crate::Tensor;
4use bon::bon;
5use snafu::ensure;
6use svod_dtype::DType;
7use svod_ir::ConstValue;
8
9use crate::error::FloatDTypeRequiredSnafu;
10
11type Result<T> = crate::Result<T>;
12
13impl Tensor {
14 pub fn embedding(&self, indices: &Tensor) -> Result<Tensor> {
17 let weight_shape = self.shape()?;
18 let embed_dim = weight_shape[1].as_const().expect("embedding weight dim 1 must be concrete") as isize;
19 let idx_shape = indices.shape()?;
20
21 let flat = indices.try_reshape([-1])?;
22 let expanded = flat.try_unsqueeze(-1)?.try_expand([-1, embed_dim])?;
23 let gathered = self.gather(0, &expanded)?;
24
25 let mut out_shape: Vec<isize> =
26 idx_shape.iter().map(|d| d.as_const().expect("embedding index dims must be concrete") as isize).collect();
27 out_shape.push(embed_dim);
28 gathered.try_reshape(&out_shape)
29 }
30
31 pub fn apply_rotary_emb(&self, cos: &Tensor, sin: &Tensor, interleaved: bool) -> Result<Tensor> {
37 let shape = self.shape()?;
38 let last_dim = shape
39 .last()
40 .expect("apply_rotary_emb requires non-scalar input")
41 .as_const()
42 .expect("last dim must be concrete");
43 let half = last_dim / 2;
44
45 let (x1, x2) = if interleaved {
46 let mut rs: Vec<isize> = shape
47 .iter()
48 .take(shape.len() - 1)
49 .map(|d| d.as_const().expect("dims must be concrete") as isize)
50 .collect();
51 rs.push(half as isize);
52 rs.push(2);
53 let r = self.try_reshape(&rs)?;
54 let p = r.split(&[1, 1], -1)?;
55 (p[0].try_squeeze(Some(-1))?, p[1].try_squeeze(Some(-1))?)
56 } else {
57 let p = self.split(&[half, half], -1)?;
58 (p[0].clone(), p[1].clone())
59 };
60
61 let real = x1.try_mul(cos)?.try_sub(&x2.try_mul(sin)?)?;
62 let imag = x1.try_mul(sin)?.try_add(&x2.try_mul(cos)?)?;
63
64 if interleaved {
65 let stacked = Tensor::stack(&[&real, &imag], -1)?;
66 let mut fs: Vec<isize> = shape.iter().map(|d| d.as_const().unwrap() as isize).collect();
67 let _ = fs.last_mut().map(|d| *d = last_dim as isize);
69 stacked.try_reshape(&fs)
70 } else {
71 Tensor::cat(&[&real, &imag], -1)
72 }
73 }
74}
75
76#[bon]
77impl Tensor {
78 #[builder]
82 pub fn scaled_dot_product_attention(
83 &self,
84 key: &Tensor,
85 value: &Tensor,
86 attn_mask: Option<&Tensor>,
87 scale: Option<f64>,
88 #[builder(default)] is_causal: bool,
89 softcap: Option<f64>,
90 ) -> Result<Tensor> {
91 let q_dtype = self.uop().dtype();
92 ensure!(
93 q_dtype.is_float(),
94 FloatDTypeRequiredSnafu { op: "scaled_dot_product_attention", arg: "query", dtype: q_dtype.clone() }
95 );
96 let k_dtype = key.uop().dtype();
97 ensure!(
98 k_dtype.is_float(),
99 FloatDTypeRequiredSnafu { op: "scaled_dot_product_attention", arg: "key", dtype: k_dtype.clone() }
100 );
101 let v_dtype = value.uop().dtype();
102 ensure!(
103 v_dtype.is_float(),
104 FloatDTypeRequiredSnafu { op: "scaled_dot_product_attention", arg: "value", dtype: v_dtype.clone() }
105 );
106
107 let q_shape = self.shape()?;
108 let k_shape = key.shape()?;
109 let head_dim = q_shape[q_shape.len() - 1].as_const().expect("Q head_dim must be concrete");
110 let scale_val = scale.unwrap_or(1.0 / (head_dim as f64).sqrt());
111
112 let scores_dtype = self.uop().dtype();
113
114 let kt = key.try_transpose(-1, -2)?;
116 let mut scores = self.matmul(&kt)?;
117
118 let scale_t = Tensor::const_(scale_val, scores_dtype.clone());
120 scores = scores.try_mul(&scale_t)?;
121
122 if is_causal {
124 let q_len = q_shape[q_shape.len() - 2].as_const().expect("Q seq_len must be concrete");
125 let k_len = k_shape[k_shape.len() - 2].as_const().expect("K seq_len must be concrete");
126 let causal = Tensor::full(&[q_len, k_len], true, DType::Bool)?.tril(0)?;
127 let neg_large = Tensor::const_(ConstValue::min(scores_dtype.base()), scores_dtype.clone());
128 scores = scores.where_(&causal, &neg_large)?;
129 }
130
131 let mut bool_mask: Option<Tensor> = None;
133 if let Some(mask) = attn_mask {
134 let mask_dtype = mask.uop().dtype();
135 if mask_dtype == DType::Bool {
136 let neg_large = Tensor::const_(ConstValue::min(scores_dtype.base()), scores_dtype.clone());
138 let zero = Tensor::const_(ConstValue::zero(scores_dtype.base()), scores_dtype.clone());
139 let additive = neg_large.where_(mask, &zero)?;
140 scores = scores.try_add(&additive)?;
141 bool_mask = Some(mask.clone());
142 } else {
143 scores = scores.try_add(mask)?;
145 }
146 }
147
148 if let Some(cap) = softcap
150 && cap > 0.0
151 {
152 let cap_t = Tensor::const_(cap, scores_dtype.clone());
153 scores = scores.try_div(&cap_t)?.tanh()?.try_mul(&cap_t)?;
154 }
155
156 let mut attn_weights = scores.softmax(-1isize)?;
158 if let Some(mask) = bool_mask.as_ref() {
159 let zero = Tensor::const_(ConstValue::zero(scores_dtype.base()), scores_dtype);
160 attn_weights = zero.where_(mask, &attn_weights)?;
161 }
162 attn_weights.matmul(value)
163 }
164}