tiny_recursive_rs/layers/
positional.rs1use candle_core::{Result, Tensor, Device, DType, D};
5
6fn rotate_half(x: &Tensor) -> Result<Tensor> {
11 let last_dim = x.dims().len() - 1;
12 let dim_size = x.dim(last_dim)?;
13 let half = dim_size / 2;
14
15 let x1 = x.narrow(last_dim, 0, half)?;
17 let x2 = x.narrow(last_dim, half, half)?;
18
19 Tensor::cat(&[&x2.neg()?, &x1], last_dim)
21}
22
23pub fn apply_rotary_pos_emb(
34 q: &Tensor,
35 k: &Tensor,
36 cos: &Tensor,
37 sin: &Tensor,
38) -> Result<(Tensor, Tensor)> {
39 let orig_dtype = q.dtype();
40
41 let q = if q.dtype() != cos.dtype() {
43 q.to_dtype(cos.dtype())?
44 } else {
45 q.clone()
46 };
47
48 let k = if k.dtype() != cos.dtype() {
49 k.to_dtype(cos.dtype())?
50 } else {
51 k.clone()
52 };
53
54 let cos = cos.unsqueeze(0)?.unsqueeze(2)?;
56 let sin = sin.unsqueeze(0)?.unsqueeze(2)?;
57
58 let q_rotated = rotate_half(&q)?;
60 let q_embed = q.broadcast_mul(&cos)?.add(&q_rotated.broadcast_mul(&sin)?)?;
61
62 let k_rotated = rotate_half(&k)?;
63 let k_embed = k.broadcast_mul(&cos)?.add(&k_rotated.broadcast_mul(&sin)?)?;
64
65 let q_embed = if q_embed.dtype() != orig_dtype {
67 q_embed.to_dtype(orig_dtype)?
68 } else {
69 q_embed
70 };
71
72 let k_embed = if k_embed.dtype() != orig_dtype {
73 k_embed.to_dtype(orig_dtype)?
74 } else {
75 k_embed
76 };
77
78 Ok((q_embed, k_embed))
79}
80
81pub struct RotaryEmbedding {
85 cos_cached: Tensor,
86 sin_cached: Tensor,
87}
88
89impl RotaryEmbedding {
90 pub fn new(
98 dim: usize,
99 max_position_embeddings: usize,
100 base: f32,
101 device: &Device,
102 ) -> Result<Self> {
103 let inv_freq: Vec<f32> = (0..dim)
105 .step_by(2)
106 .map(|i| {
107 let exponent = i as f32 / dim as f32;
108 1.0 / base.powf(exponent)
109 })
110 .collect();
111
112 let inv_freq = Tensor::new(inv_freq.as_slice(), device)?;
113
114 let t: Vec<f32> = (0..max_position_embeddings).map(|i| i as f32).collect();
116 let t = Tensor::new(t.as_slice(), device)?;
117
118 let freqs = t.unsqueeze(1)?.broadcast_mul(&inv_freq.unsqueeze(0)?)?;
121
122 let emb = Tensor::cat(&[&freqs, &freqs], 1)?;
125
126 let cos_cached = emb.cos()?;
128 let sin_cached = emb.sin()?;
129
130 Ok(Self {
131 cos_cached,
132 sin_cached,
133 })
134 }
135
136 pub fn forward(&self) -> Result<(Tensor, Tensor)> {
141 Ok((self.cos_cached.clone(), self.sin_cached.clone()))
142 }
143
144 pub fn forward_with_len(&self, seq_len: usize) -> Result<(Tensor, Tensor)> {
152 let cos = self.cos_cached.narrow(0, 0, seq_len)?;
153 let sin = self.sin_cached.narrow(0, 0, seq_len)?;
154 Ok((cos, sin))
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the embedding used by the shape tests: 64 channels, 512 positions,
    /// base 10000.
    fn make_rope(device: &Device) -> Result<RotaryEmbedding> {
        RotaryEmbedding::new(64, 512, 10000.0, device)
    }

    /// `[1, 2 | 3, 4]` must rotate to `[-3, -4 | 1, 2]`.
    #[test]
    fn test_rotate_half() -> Result<()> {
        let cpu = Device::Cpu;

        let input = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &cpu)?.reshape((1, 4))?;
        let got = rotate_half(&input)?;

        let want = Tensor::new(&[-3.0f32, -4.0, 1.0, 2.0], &cpu)?.reshape((1, 4))?;

        // Sum of absolute element-wise differences.
        let total_err = got.sub(&want)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        assert!(total_err < 1e-6, "rotate_half failed");

        Ok(())
    }

    /// The full tables have one row per position and one column per channel.
    #[test]
    fn test_rotary_embedding_shape() -> Result<()> {
        let cpu = Device::Cpu;

        let (cos, sin) = make_rope(&cpu)?.forward()?;

        for table in [&cos, &sin] {
            assert_eq!(table.dims(), &[512, 64]);
        }

        Ok(())
    }

    /// Requesting 128 positions yields tables truncated to 128 rows.
    #[test]
    fn test_rotary_embedding_with_len() -> Result<()> {
        let cpu = Device::Cpu;

        let (cos, sin) = make_rope(&cpu)?.forward_with_len(128)?;

        for table in [&cos, &sin] {
            assert_eq!(table.dims(), &[128, 64]);
        }

        Ok(())
    }

    /// Applying the embedding must not change the shape of `q` or `k`.
    #[test]
    fn test_apply_rotary_pos_emb_shape() -> Result<()> {
        let cpu = Device::Cpu;
        let shape = (2, 16, 8, 64);

        let q = Tensor::randn(0f32, 1.0, shape, &cpu)?;
        let k = Tensor::randn(0f32, 1.0, shape, &cpu)?;

        let (cos, sin) = make_rope(&cpu)?.forward_with_len(16)?;

        let (q_embed, k_embed) = apply_rotary_pos_emb(&q, &k, &cos, &sin)?;

        assert_eq!(q_embed.dims(), q.dims());
        assert_eq!(k_embed.dims(), k.dims());

        Ok(())
    }
}