1use crate::retnet::config::RetNetConfig;
2use std::io::Read;
3use trustformers_core::{
4 device::Device,
5 errors::{tensor_op_error, Result},
6 layers::{Embedding, LayerNorm, Linear},
7 tensor::Tensor,
8 traits::{Config, Layer, Model},
9};
10
/// Rotary position embedding (RoPE) helper: caches per-dimension inverse
/// frequencies and produces per-position cos/sin tables for `q`/`k` rotation.
pub struct RotaryPositionEmbedding {
    // Size of the last dimension the rotation covers. `inv_freq` is shaped
    // `[dim / 2]`, so an odd `dim` would mismatch — assumed even; TODO confirm.
    dim: usize,
    #[allow(dead_code)]
    max_seq_len: usize,
    #[allow(dead_code)]
    base: f32,
    // Precomputed `[dim / 2]` tensor of 1 / base^(i/dim) for even i.
    inv_freq: Tensor,
    device: Device,
}
21
22impl RotaryPositionEmbedding {
23 pub fn new(dim: usize, max_seq_len: usize, base: f32) -> Result<Self> {
24 Self::new_with_device(dim, max_seq_len, base, Device::CPU)
25 }
26
27 pub fn new_with_device(
28 dim: usize,
29 max_seq_len: usize,
30 base: f32,
31 device: Device,
32 ) -> Result<Self> {
33 let mut inv_freq_vec = Vec::new();
34 for i in (0..dim).step_by(2) {
35 let freq = 1.0 / base.powf(i as f32 / dim as f32);
36 inv_freq_vec.push(freq);
37 }
38
39 let inv_freq = Tensor::from_vec(inv_freq_vec, &[dim / 2])?.to_device_enum(&device)?;
40
41 Ok(Self {
42 dim,
43 max_seq_len,
44 base,
45 inv_freq,
46 device,
47 })
48 }
49
50 pub fn device(&self) -> Device {
51 self.device
52 }
53
54 pub fn apply_rotary_pos_emb(
56 &self,
57 q: &Tensor,
58 k: &Tensor,
59 position: usize,
60 ) -> Result<(Tensor, Tensor)> {
61 let cos_sin = self.get_cos_sin(position)?;
62 let cos_emb = &cos_sin.0;
63 let sin_emb = &cos_sin.1;
64
65 let q_rot = self.rotate_half(q)?;
66 let k_rot = self.rotate_half(k)?;
67
68 let q_embed = q.mul(cos_emb)?.add(&q_rot.mul(sin_emb)?)?;
69 let k_embed = k.mul(cos_emb)?.add(&k_rot.mul(sin_emb)?)?;
70
71 Ok((q_embed, k_embed))
72 }
73
74 fn get_cos_sin(&self, position: usize) -> Result<(Tensor, Tensor)> {
75 let pos = position as f32;
76 let mut cos_vals = Vec::new();
77 let mut sin_vals = Vec::new();
78
79 for i in 0..self.dim / 2 {
80 let freq = self.inv_freq.get_scalar(&[i])?;
81 let angle = pos * freq;
82 cos_vals.push(angle.cos());
83 cos_vals.push(angle.cos()); sin_vals.push(angle.sin());
85 sin_vals.push(angle.sin()); }
87
88 let cos_emb = Tensor::from_vec(cos_vals, &[self.dim])?.to_device_enum(&self.device)?;
89 let sin_emb = Tensor::from_vec(sin_vals, &[self.dim])?.to_device_enum(&self.device)?;
90
91 Ok((cos_emb, sin_emb))
92 }
93
94 fn rotate_half(&self, x: &Tensor) -> Result<Tensor> {
95 let shape = x.shape();
96 let last_dim = shape[shape.len() - 1];
97 let half_dim = last_dim / 2;
98
99 let x1_ranges: Vec<_> = (0..shape.len() - 1).map(|i| (0, shape[i])).collect();
101 let mut x1_ranges = x1_ranges;
102 x1_ranges.push((0, half_dim));
103
104 let mut x2_ranges: Vec<_> = (0..shape.len() - 1).map(|i| (0, shape[i])).collect();
105 x2_ranges.push((half_dim, last_dim));
106
107 let x1 = x.slice_ranges(&x1_ranges)?;
108 let x2 = x.slice_ranges(&x2_ranges)?;
109
110 let neg_x2 = x2.mul_scalar(-1.0)?;
112 self.concatenate_last_dim(&neg_x2, &x1)
113 }
114
115 fn concatenate_last_dim(&self, x1: &Tensor, x2: &Tensor) -> Result<Tensor> {
116 let shape1 = x1.shape();
117 let shape2 = x2.shape();
118
119 let mut result_shape = shape1.to_vec();
120 let last_idx = result_shape.len() - 1;
121 result_shape[last_idx] = shape1[shape1.len() - 1] + shape2[shape2.len() - 1];
122
123 let _result = Tensor::zeros(&result_shape)?;
124
125 Ok(x1.clone())
128 }
129}
130
/// Splits long sequences into overlapping windows for chunk-wise processing,
/// optionally routing each window through a checkpointing hook.
pub struct AdvancedChunkProcessor {
    // Window length in sequence positions.
    chunk_size: usize,
    // Positions shared between consecutive windows.
    overlap_size: usize,
    // When set, chunks go through `checkpoint_forward` (currently a
    // pass-through to the processor closure).
    use_gradient_checkpointing: bool,
}
137
138impl AdvancedChunkProcessor {
139 pub fn new(chunk_size: usize, overlap_size: usize, use_gradient_checkpointing: bool) -> Self {
140 Self {
141 chunk_size,
142 overlap_size,
143 use_gradient_checkpointing,
144 }
145 }
146
147 pub fn process_chunks<F>(&self, sequence: &Tensor, mut processor: F) -> Result<Tensor>
149 where
150 F: FnMut(&Tensor, Option<&Tensor>) -> Result<(Tensor, Tensor)>,
151 {
152 let seq_len = sequence.shape()[1];
153 let batch_size = sequence.shape()[0];
154 let hidden_size = sequence.shape()[2];
155
156 if seq_len <= self.chunk_size {
157 let (output, _) = processor(sequence, None)?;
158 return Ok(output);
159 }
160
161 let mut chunks = Vec::new();
162 let mut state = None;
163 let effective_step = self.chunk_size - self.overlap_size;
164
165 for start in (0..seq_len).step_by(effective_step) {
166 let end = std::cmp::min(start + self.chunk_size, seq_len);
167 let chunk =
168 sequence.slice_ranges(&[(0, batch_size), (start, end), (0, hidden_size)])?;
169
170 let (chunk_output, new_state) = if self.use_gradient_checkpointing {
171 self.checkpoint_forward(&chunk, state.as_ref(), &mut processor)?
172 } else {
173 processor(&chunk, state.as_ref())?
174 };
175
176 let output_start = if start == 0 { 0 } else { self.overlap_size };
178 let output_end = chunk_output.shape()[1];
179
180 if output_end > output_start {
181 let trimmed_output = chunk_output.slice_ranges(&[
182 (0, batch_size),
183 (output_start, output_end),
184 (0, hidden_size),
185 ])?;
186 chunks.push(trimmed_output);
187 }
188
189 state = Some(new_state);
190 }
191
192 self.concatenate_chunks(chunks)
193 }
194
195 fn checkpoint_forward<F>(
197 &self,
198 chunk: &Tensor,
199 state: Option<&Tensor>,
200 processor: &mut F,
201 ) -> Result<(Tensor, Tensor)>
202 where
203 F: FnMut(&Tensor, Option<&Tensor>) -> Result<(Tensor, Tensor)>,
204 {
205 processor(chunk, state)
208 }
209
210 fn concatenate_chunks(&self, chunks: Vec<Tensor>) -> Result<Tensor> {
211 if chunks.is_empty() {
212 return Err(tensor_op_error(
213 "tensor_operation",
214 "No chunks to concatenate".to_string(),
215 ));
216 }
217
218 let batch_size = chunks[0].shape()[0];
219 let hidden_size = chunks[0].shape()[2];
220 let total_seq_len: usize = chunks.iter().map(|c| c.shape()[1]).sum();
221
222 let device = chunks[0].device();
224 let mut result =
225 Tensor::zeros(&[batch_size, total_seq_len, hidden_size])?.to_device(&device)?;
226 let mut offset = 0;
227
228 for chunk in chunks {
229 let chunk_seq_len = chunk.shape()[1];
230
231 for b in 0..batch_size {
232 for s in 0..chunk_seq_len {
233 for h in 0..hidden_size {
234 let val = chunk.get_scalar(&[b, s, h])?;
235 result = result.set_scalar(&[b, offset + s, h], val)?;
236 }
237 }
238 }
239
240 offset += chunk_seq_len;
241 }
242
243 Ok(result)
244 }
245}
246
/// Bounded map from layer index to a cached recurrent-state tensor.
pub struct RetNetStateCache {
    states: std::collections::HashMap<usize, Tensor>,
    // Maximum number of entries before eviction kicks in.
    max_cache_size: usize,
    // Entry counter used for eviction decisions.
    current_size: usize,
}
253
254impl RetNetStateCache {
255 pub fn new(max_cache_size: usize) -> Self {
256 Self {
257 states: std::collections::HashMap::new(),
258 max_cache_size,
259 current_size: 0,
260 }
261 }
262
263 pub fn get_state(&self, layer_idx: usize) -> Option<&Tensor> {
264 self.states.get(&layer_idx)
265 }
266
267 pub fn set_state(&mut self, layer_idx: usize, state: Tensor) -> Result<()> {
268 while self.current_size >= self.max_cache_size && !self.states.is_empty() {
270 let oldest_key = *self.states.keys().next().expect("operation failed");
271 self.states.remove(&oldest_key);
272 self.current_size -= 1;
273 }
274
275 self.states.insert(layer_idx, state);
276 self.current_size += 1;
277 Ok(())
278 }
279
280 pub fn clear(&mut self) {
281 self.states.clear();
282 self.current_size = 0;
283 }
284
285 pub fn size(&self) -> usize {
286 self.current_size
287 }
288}
289
/// Multi-scale retention layer (RetNet): per-head exponentially decayed
/// attention-like mixing, with optional rotary embeddings, chunked
/// processing, and a recurrent-state cache.
pub struct MultiScaleRetention {
    num_heads: usize,
    head_dim: usize,
    #[allow(dead_code)]
    hidden_size: usize,

    // q/k project to the retention dim; v/g/out keep hidden_size.
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    g_proj: Linear,
    out_proj: Linear,

    // One fixed decay factor per retention head (not learned).
    gamma: Vec<f32>,
    #[allow(dead_code)]
    dropout: f32,
    #[allow(dead_code)]
    value_factor: f32,

    // Constructed but never applied in the forward paths below — TODO confirm
    // whether rotary embedding was meant to be wired into q/k.
    #[allow(dead_code)]
    pos_emb: Option<RotaryPositionEmbedding>,
    chunk_processor: Option<AdvancedChunkProcessor>,
    state_cache: Option<RetNetStateCache>,
    #[allow(dead_code)]
    use_memory_efficient_attention: bool,
    device: Device,
}
320
impl MultiScaleRetention {
    /// CPU-device constructor; see [`Self::new_with_device`].
    pub fn new(config: &RetNetConfig) -> Result<Self> {
        Self::new_with_device(config, Device::CPU)
    }

    /// Builds all projections, per-head decay factors, and the optional
    /// rotary/chunking/caching helpers from `config`.
    pub fn new_with_device(config: &RetNetConfig, device: Device) -> Result<Self> {
        let head_dim = config.retention_head_dim();
        let retention_dim = config.retention_dim();

        // q/k project into the retention dim; v/g/out stay at hidden_size.
        let q_proj =
            Linear::new_with_device(config.hidden_size, retention_dim, config.use_bias, device);
        let k_proj =
            Linear::new_with_device(config.hidden_size, retention_dim, config.use_bias, device);
        let v_proj = Linear::new_with_device(
            config.hidden_size,
            config.hidden_size,
            config.use_bias,
            device,
        );
        let g_proj = Linear::new_with_device(
            config.hidden_size,
            config.hidden_size,
            config.use_bias,
            device,
        );
        let out_proj = Linear::new_with_device(
            config.hidden_size,
            config.hidden_size,
            config.use_bias,
            device,
        );

        // Per-head decay gamma_i = 1 - 2^-(5 + i): later heads decay slower,
        // giving each head a different effective memory length.
        let mut gamma = Vec::new();
        for i in 0..config.retention_heads {
            let decay = 1.0 - 2.0_f32.powf(-(5.0 + i as f32));
            gamma.push(decay);
        }

        let pos_emb = if config.max_position_embeddings > 0 {
            Some(RotaryPositionEmbedding::new_with_device(
                head_dim,
                config.max_position_embeddings,
                10000.0,
                device,
            )?)
        } else {
            None
        };

        let chunk_processor = if config.uses_chunking() {
            // NOTE(review): the third argument of AdvancedChunkProcessor::new
            // is `use_gradient_checkpointing`, but `config.deepnorm` is
            // passed here — looks like the wrong flag; confirm intent.
            Some(AdvancedChunkProcessor::new(
                config.chunk_size,
                config.chunk_size / 4,
                config.deepnorm,
            ))
        } else {
            None
        };

        let state_cache = Some(RetNetStateCache::new(config.num_hidden_layers * 2));

        Ok(Self {
            num_heads: config.retention_heads,
            head_dim,
            hidden_size: config.hidden_size,
            q_proj,
            k_proj,
            v_proj,
            g_proj,
            out_proj,
            gamma,
            dropout: config.attention_dropout,
            value_factor: config.value_factor,
            pos_emb,
            chunk_processor,
            state_cache,
            use_memory_efficient_attention: config.sequence_parallel,
            device,
        })
    }

    pub fn device(&self) -> Device {
        self.device
    }

    /// Replaces the state cache with one of `cache_size`; no-op when `None`.
    pub fn set_inference_mode(&mut self, cache_size: Option<usize>) {
        if let Some(size) = cache_size {
            self.state_cache = Some(RetNetStateCache::new(size));
        }
    }

    /// Empties the recurrent-state cache, if present.
    pub fn clear_cache(&mut self) {
        if let Some(ref mut cache) = self.state_cache {
            cache.clear();
        }
    }

    /// Runs the layer chunk-by-chunk via the configured chunk processor,
    /// falling back to the plain `forward` when chunking is disabled.
    ///
    /// NOTE(review): the per-chunk state handed back to the processor is
    /// always a fresh zero tensor, so no recurrent state actually flows
    /// between chunks here; `_layer_idx` and the cache are unused.
    pub fn forward_chunked(&self, input: &Tensor, _layer_idx: usize) -> Result<Tensor> {
        if let Some(ref processor) = self.chunk_processor {
            let _cache_ref: Option<()> = None; // dead placeholder, kept as-is
            processor.process_chunks(input, |chunk, _state| {
                let q = self.q_proj.forward(chunk.clone())?;
                let k = self.k_proj.forward(chunk.clone())?;
                let v = self.v_proj.forward(chunk.clone())?;
                let g = self.g_proj.forward(chunk.clone())?;

                // SiLU-gated retention, same shape as the plain forward path.
                let g_activated = g.silu()?;
                let retention_output = self.parallel_retention(&q, &k, &v)?;
                let gated_output = retention_output.mul(&g_activated)?;
                let output = self.out_proj.forward(gated_output)?;

                let state = Tensor::zeros(&[1, self.num_heads, self.head_dim, self.head_dim])?
                    .to_device_enum(&self.device)?;
                Ok((output, state))
            })
        } else {
            self.forward(input.clone())
        }
    }

    /// Full-sequence retention: each head runs the decayed recurrence in
    /// `compute_retention`, and per-head results are scattered back into a
    /// single `[batch, heads, seq, dim]` tensor before being merged.
    fn parallel_retention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
        let batch_size = q.shape()[0];
        let seq_len = q.shape()[1];
        let num_heads = self.num_heads;
        let head_dim = self.head_dim;

        let q_heads = self.reshape_for_heads(q)?;
        let k_heads = self.reshape_for_heads(k)?;
        let v_heads = self.reshape_for_heads(v)?;

        let mut output = Tensor::zeros(&[batch_size, num_heads, seq_len, head_dim])?
            .to_device_enum(&self.device)?;

        for h in 0..num_heads {
            let gamma_h = self.gamma[h];
            let q_h = q_heads.slice_ranges(&[
                (0, batch_size),
                (h, h + 1),
                (0, seq_len),
                (0, head_dim),
            ])?;
            let k_h = k_heads.slice_ranges(&[
                (0, batch_size),
                (h, h + 1),
                (0, seq_len),
                (0, head_dim),
            ])?;
            // NOTE(review): v is sliced at (2*head_dim, 3*head_dim) while q/k
            // use (0, head_dim). After `reshape_for_heads` the per-head v dim
            // is hidden_size / num_heads, so this stays in bounds only if
            // that is >= 3*head_dim — possibly tied to `value_factor`; verify
            // against the config, otherwise this is an out-of-range slice.
            let v_h = v_heads.slice_ranges(&[
                (0, batch_size),
                (h, h + 1),
                (0, seq_len),
                (head_dim * 2, head_dim * 3),
            ])?;

            let retention_output = self.compute_retention(&q_h, &k_h, &v_h, gamma_h)?;

            // Scatter the per-head result into the combined output tensor.
            for b in 0..batch_size {
                for s in 0..seq_len {
                    for d in 0..head_dim {
                        let val = retention_output.get_scalar(&[b, 0, s, d])?;
                        output = output.set_scalar(&[b, h, s, d], val)?;
                    }
                }
            }
        }

        self.reshape_from_heads(&output)
    }

    /// Sequential retention for one head: maintains a `[dim, dim]` state
    /// `S_i = gamma * S_{i-1} + k_i^T v_i` and emits `o_i = q_i S_i`.
    fn compute_retention(&self, q: &Tensor, k: &Tensor, v: &Tensor, gamma: f32) -> Result<Tensor> {
        let batch_size = q.shape()[0];
        let seq_len = q.shape()[2];
        let head_dim = q.shape()[3];

        let mut output =
            Tensor::zeros(&[batch_size, 1, seq_len, head_dim])?.to_device_enum(&self.device)?;

        for b in 0..batch_size {
            let mut state = Tensor::zeros(&[head_dim, head_dim])?.to_device_enum(&self.device)?;

            for i in 0..seq_len {
                let q_i = q.slice_ranges(&[(b, b + 1), (0, 1), (i, i + 1), (0, head_dim)])?;
                let k_i = k.slice_ranges(&[(b, b + 1), (0, 1), (i, i + 1), (0, head_dim)])?;
                let v_i = v.slice_ranges(&[(b, b + 1), (0, 1), (i, i + 1), (0, head_dim)])?;

                // Decay the running state, then add the k⊗v outer product.
                state = state.mul_scalar(gamma)?;
                let k_i_flat = k_i.reshape(&[head_dim, 1])?;
                let v_i_flat = v_i.reshape(&[1, head_dim])?;
                let outer_product = k_i_flat.matmul(&v_i_flat)?;
                state = state.add(&outer_product)?;

                // Read out: o_i = q_i · S_i.
                let q_i_flat = q_i.reshape(&[1, head_dim])?;
                let o_i = q_i_flat.matmul(&state)?;
                let o_i_reshaped = o_i.reshape(&[1, 1, 1, head_dim])?;

                for d in 0..head_dim {
                    let val = o_i_reshaped.get_scalar(&[0, 0, 0, d])?;
                    output = output.set_scalar(&[b, 0, i, d], val)?;
                }
            }
        }

        Ok(output)
    }

    /// Single-token recurrent retention: O(1) per step using a cached
    /// `[dim, dim]` state per head. Falls back to `parallel_retention`
    /// when `seq_len != 1`.
    #[allow(dead_code)]
    fn recurrent_retention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        prev_state: Option<&Tensor>,
    ) -> Result<(Tensor, Tensor)> {
        let batch_size = q.shape()[0];
        let seq_len = q.shape()[1];

        if seq_len != 1 {
            // NOTE(review): the state returned on this path is all zeros, so
            // callers cannot resume recurrence after a multi-token call.
            return self.parallel_retention(q, k, v).map(|out| {
                let state =
                    Tensor::zeros(&[batch_size, self.num_heads, self.head_dim, self.head_dim])?
                        .to_device_enum(&self.device)?;
                Ok((out, state))
            })?;
        }

        let q_heads = self.reshape_for_heads(q)?;
        let k_heads = self.reshape_for_heads(k)?;
        let v_heads = self.reshape_for_heads(v)?;

        let mut output = Tensor::zeros(&[batch_size, self.num_heads, 1, self.head_dim])?
            .to_device_enum(&self.device)?;
        let mut new_states = Vec::new();

        for h in 0..self.num_heads {
            let gamma_h = self.gamma[h];

            let q_h =
                q_heads.slice_ranges(&[(0, batch_size), (h, h + 1), (0, 1), (0, self.head_dim)])?;
            let k_h =
                k_heads.slice_ranges(&[(0, batch_size), (h, h + 1), (0, 1), (0, self.head_dim)])?;
            // NOTE(review): same (2*head_dim, 3*head_dim) v-slice as in
            // parallel_retention — see the bounds concern there.
            let v_h = v_heads.slice_ranges(&[
                (0, batch_size),
                (h, h + 1),
                (0, 1),
                (self.head_dim * 2, self.head_dim * 3),
            ])?;

            let prev_state_h = if let Some(prev) = prev_state {
                prev.slice_ranges(&[
                    (0, batch_size),
                    (h, h + 1),
                    (0, self.head_dim),
                    (0, self.head_dim),
                ])?
            } else {
                Tensor::zeros(&[batch_size, 1, self.head_dim, self.head_dim])?
                    .to_device_enum(&self.device)?
            };

            // Decay first: S = gamma * S_prev, then add k⊗v per batch entry.
            let mut new_state_h = prev_state_h.mul_scalar(gamma_h)?;

            for b in 0..batch_size {
                let k_b = k_h
                    .slice_ranges(&[(b, b + 1), (0, 1), (0, 1), (0, self.head_dim)])?
                    .reshape(&[self.head_dim, 1])?;
                let v_b = v_h
                    .slice_ranges(&[(b, b + 1), (0, 1), (0, 1), (0, self.head_dim)])?
                    .reshape(&[1, self.head_dim])?;
                let outer = k_b.matmul(&v_b)?;

                let prev_state_b = new_state_h
                    .slice_ranges(&[(b, b + 1), (0, 1), (0, self.head_dim), (0, self.head_dim)])?
                    .reshape(&[self.head_dim, self.head_dim])?;
                let updated_state = prev_state_b.add(&outer)?;

                // Write the updated [dim, dim] state back element-wise.
                for i in 0..self.head_dim {
                    for j in 0..self.head_dim {
                        let val = updated_state.get_scalar(&[i, j])?;
                        new_state_h = new_state_h.set_scalar(&[b, 0, i, j], val)?;
                    }
                }

                // Read out: o = q · S.
                let q_b = q_h
                    .slice_ranges(&[(b, b + 1), (0, 1), (0, 1), (0, self.head_dim)])?
                    .reshape(&[1, self.head_dim])?;
                let out_b = q_b.matmul(&updated_state)?;

                for d in 0..self.head_dim {
                    let val = out_b.get_scalar(&[0, d])?;
                    output = output.set_scalar(&[b, h, 0, d], val)?;
                }
            }

            new_states.push(new_state_h);
        }

        let new_state = self.concatenate_states(new_states)?;
        let final_output = self.reshape_from_heads(&output)?;

        Ok((final_output, new_state))
    }

    /// Stacks per-head `[batch, 1, dim, dim]` states into a single
    /// `[batch, heads, dim, dim]` tensor, element by element.
    fn concatenate_states(&self, states: Vec<Tensor>) -> Result<Tensor> {
        let batch_size = states[0].shape()[0];
        let mut result =
            Tensor::zeros(&[batch_size, self.num_heads, self.head_dim, self.head_dim])?
                .to_device_enum(&self.device)?;

        for (h, state) in states.iter().enumerate() {
            for b in 0..batch_size {
                for i in 0..self.head_dim {
                    for j in 0..self.head_dim {
                        let val = state.get_scalar(&[b, 0, i, j])?;
                        result = result.set_scalar(&[b, h, i, j], val)?;
                    }
                }
            }
        }

        Ok(result)
    }

    /// Splits the sequence into fixed-size chunks and runs
    /// `parallel_retention` on each independently.
    ///
    /// NOTE(review): no state or decay crosses chunk boundaries, so this is
    /// an approximation of full retention, not an exact equivalent.
    fn chunk_retention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        chunk_size: usize,
    ) -> Result<Tensor> {
        let batch_size = q.shape()[0];
        let seq_len = q.shape()[1];
        let hidden_size = q.shape()[2];

        let mut outputs = Vec::new();

        for start in (0..seq_len).step_by(chunk_size) {
            let end = std::cmp::min(start + chunk_size, seq_len);

            let q_chunk = q.slice_ranges(&[(0, batch_size), (start, end), (0, hidden_size)])?;
            let k_chunk = k.slice_ranges(&[(0, batch_size), (start, end), (0, hidden_size)])?;
            let v_chunk = v.slice_ranges(&[(0, batch_size), (start, end), (0, hidden_size)])?;

            let chunk_output = self.parallel_retention(&q_chunk, &k_chunk, &v_chunk)?;
            outputs.push(chunk_output);
        }

        self.concatenate_chunks(outputs)
    }

    /// `[batch, seq, hidden]` → `[batch, heads, seq, hidden/heads]`.
    fn reshape_for_heads(&self, x: &Tensor) -> Result<Tensor> {
        let batch_size = x.shape()[0];
        let seq_len = x.shape()[1];
        let hidden_size = x.shape()[2];

        x.reshape(&[
            batch_size,
            seq_len,
            self.num_heads,
            hidden_size / self.num_heads,
        ])?
        .permute(&[0, 2, 1, 3])
    }

    /// `[batch, heads, seq, dim]` → `[batch, seq, heads*dim]`.
    fn reshape_from_heads(&self, x: &Tensor) -> Result<Tensor> {
        let batch_size = x.shape()[0];
        let num_heads = x.shape()[1];
        let seq_len = x.shape()[2];
        let head_dim = x.shape()[3];

        x.permute(&[0, 2, 1, 3])?.reshape(&[batch_size, seq_len, num_heads * head_dim])
    }

    /// Re-joins chunk outputs along the sequence axis via element-wise
    /// copies onto a zeroed result tensor.
    fn concatenate_chunks(&self, chunks: Vec<Tensor>) -> Result<Tensor> {
        if chunks.is_empty() {
            return Err(tensor_op_error(
                "tensor_operation",
                "No chunks to concatenate".to_string(),
            ));
        }

        let batch_size = chunks[0].shape()[0];
        let hidden_size = chunks[0].shape()[2];
        let total_seq_len: usize = chunks.iter().map(|c| c.shape()[1]).sum();

        let mut result = Tensor::zeros(&[batch_size, total_seq_len, hidden_size])?
            .to_device_enum(&self.device)?;
        let mut offset = 0;

        for chunk in chunks {
            let chunk_seq_len = chunk.shape()[1];

            for b in 0..batch_size {
                for s in 0..chunk_seq_len {
                    for h in 0..hidden_size {
                        let val = chunk.get_scalar(&[b, s, h])?;
                        result = result.set_scalar(&[b, offset + s, h], val)?;
                    }
                }
            }

            offset += chunk_seq_len;
        }

        Ok(result)
    }

    /// Parameters across the five projections (the gamma decays are fixed
    /// constants, not learned here).
    pub fn parameter_count(&self) -> usize {
        self.q_proj.parameter_count()
            + self.k_proj.parameter_count()
            + self.v_proj.parameter_count()
            + self.g_proj.parameter_count()
            + self.out_proj.parameter_count()
    }
}
769
770impl Layer for MultiScaleRetention {
771 type Input = Tensor;
772 type Output = Tensor;
773
774 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
775 let seq_len = input.shape()[1];
776
777 let q = self.q_proj.forward(input.clone())?;
779 let k = self.k_proj.forward(input.clone())?;
780 let v = self.v_proj.forward(input.clone())?;
781 let g = self.g_proj.forward(input)?;
782
783 let g_activated = g.silu()?;
785
786 let retention_output = if seq_len > 2048 {
788 self.chunk_retention(&q, &k, &v, 512)?
790 } else {
791 self.parallel_retention(&q, &k, &v)?
792 };
793
794 let gated_output = retention_output.mul(&g_activated)?;
796
797 self.out_proj.forward(gated_output)
799 }
800}
801
/// Feed-forward block with optional GLU gating and a configurable activation.
pub struct RetNetFFN {
    // Gate branch; only applied in `forward` when `use_glu` is true.
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    // Activation name: "swish"/"silu", "gelu", "relu"; anything else is identity.
    activation: String,
    use_glu: bool,
    #[allow(dead_code)]
    dropout: f32,
    device: Device,
}
813
814impl RetNetFFN {
815 pub fn new(config: &RetNetConfig) -> Result<Self> {
816 Self::new_with_device(config, Device::CPU)
817 }
818
819 pub fn new_with_device(config: &RetNetConfig, device: Device) -> Result<Self> {
820 let gate_proj = if config.use_glu {
821 Some(Linear::new_with_device(
822 config.hidden_size,
823 config.intermediate_size,
824 config.use_bias,
825 device,
826 ))
827 } else {
828 None
829 };
830
831 let up_proj = Linear::new_with_device(
832 config.hidden_size,
833 config.intermediate_size,
834 config.use_bias,
835 device,
836 );
837 let down_proj = Linear::new_with_device(
838 config.intermediate_size,
839 config.hidden_size,
840 config.use_bias,
841 device,
842 );
843
844 Ok(Self {
845 gate_proj: gate_proj.unwrap_or_else(|| {
846 Linear::new_with_device(
847 config.hidden_size,
848 config.intermediate_size,
849 config.use_bias,
850 device,
851 ) }),
853 up_proj,
854 down_proj,
855 activation: config.hidden_act.clone(),
856 use_glu: config.use_glu,
857 dropout: config.activation_dropout,
858 device,
859 })
860 }
861
862 pub fn device(&self) -> Device {
863 self.device
864 }
865
866 fn apply_activation(&self, x: &Tensor) -> Result<Tensor> {
867 match self.activation.as_str() {
868 "swish" | "silu" => x.silu(),
869 "gelu" => x.gelu(),
870 "relu" => x.relu(),
871 _ => Ok(x.clone()),
872 }
873 }
874
875 pub fn parameter_count(&self) -> usize {
876 self.gate_proj.parameter_count()
877 + self.up_proj.parameter_count()
878 + self.down_proj.parameter_count()
879 }
880}
881
882impl Layer for RetNetFFN {
883 type Input = Tensor;
884 type Output = Tensor;
885
886 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
887 if self.use_glu {
888 let gate = self.gate_proj.forward(input.clone())?;
890 let up = self.up_proj.forward(input)?;
891 let activated_up = self.apply_activation(&up)?;
892 let gated = gate.mul(&activated_up)?;
893 self.down_proj.forward(gated)
894 } else {
895 let up = self.up_proj.forward(input)?;
897 let activated = self.apply_activation(&up)?;
898 self.down_proj.forward(activated)
899 }
900 }
901}
902
/// One RetNet decoder block: pre-norm retention and pre-norm FFN with
/// (optionally DeepNorm-scaled) residual connections.
pub struct RetNetDecoderLayer {
    retention: MultiScaleRetention,
    ffn: RetNetFFN,
    retention_norm: LayerNorm,
    ffn_norm: LayerNorm,
    #[allow(dead_code)]
    dropout: f32,
    // When true, residuals are scaled by `alpha` and branches by `beta`.
    deepnorm: bool,
    alpha: f32,
    beta: f32,
    device: Device,
}
916
917impl RetNetDecoderLayer {
918 pub fn new(config: &RetNetConfig) -> Result<Self> {
919 Self::new_with_device(config, Device::CPU)
920 }
921
922 pub fn new_with_device(config: &RetNetConfig, device: Device) -> Result<Self> {
923 let retention = MultiScaleRetention::new_with_device(config, device)?;
924 let ffn = RetNetFFN::new_with_device(config, device)?;
925 let retention_norm =
926 LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
927 let ffn_norm =
928 LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
929
930 let (alpha, beta) = if config.deepnorm {
931 (config.deepnorm_alpha(), config.deepnorm_beta())
932 } else {
933 (1.0, 1.0)
934 };
935
936 Ok(Self {
937 retention,
938 ffn,
939 retention_norm,
940 ffn_norm,
941 dropout: config.hidden_dropout_prob,
942 deepnorm: config.deepnorm,
943 alpha,
944 beta,
945 device,
946 })
947 }
948
949 pub fn device(&self) -> Device {
950 self.device
951 }
952
953 pub fn parameter_count(&self) -> usize {
954 self.retention.parameter_count()
955 + self.ffn.parameter_count()
956 + self.retention_norm.parameter_count()
957 + self.ffn_norm.parameter_count()
958 }
959}
960
961impl Layer for RetNetDecoderLayer {
962 type Input = Tensor;
963 type Output = Tensor;
964
965 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
966 let norm1 = self.retention_norm.forward(input.clone())?;
968 let retention_out = self.retention.forward(norm1)?;
969
970 let residual1 = if self.deepnorm {
971 let scaled_input = input.mul_scalar(self.alpha)?;
973 let scaled_retention = retention_out.mul_scalar(self.beta)?;
974 scaled_input.add(&scaled_retention)?
975 } else {
976 input.add(&retention_out)?
977 };
978
979 let norm2 = self.ffn_norm.forward(residual1.clone())?;
981 let ffn_out = self.ffn.forward(norm2)?;
982
983 let residual2 = if self.deepnorm {
984 let scaled_residual1 = residual1.mul_scalar(self.alpha)?;
985 let scaled_ffn = ffn_out.mul_scalar(self.beta)?;
986 scaled_residual1.add(&scaled_ffn)?
987 } else {
988 residual1.add(&ffn_out)?
989 };
990
991 Ok(residual2)
992 }
993}
994
/// Token embeddings with an optional post-embedding LayerNorm.
pub struct RetNetEmbeddings {
    word_embeddings: Embedding,
    // Present only when `config.layernorm_embedding` is set.
    layer_norm: Option<LayerNorm>,
    #[allow(dead_code)]
    dropout: f32,
    device: Device,
}
1003
1004impl RetNetEmbeddings {
1005 pub fn new(config: &RetNetConfig) -> Result<Self> {
1006 Self::new_with_device(config, Device::CPU)
1007 }
1008
1009 pub fn new_with_device(config: &RetNetConfig, device: Device) -> Result<Self> {
1010 let word_embeddings = Embedding::new_with_device(
1011 config.vocab_size,
1012 config.hidden_size,
1013 Some(config.pad_token_id as usize),
1014 device,
1015 )?;
1016
1017 let layer_norm = if config.layernorm_embedding {
1018 Some(LayerNorm::new_with_device(
1019 vec![config.hidden_size],
1020 config.layer_norm_eps,
1021 device,
1022 )?)
1023 } else {
1024 None
1025 };
1026
1027 Ok(Self {
1028 word_embeddings,
1029 layer_norm,
1030 dropout: config.hidden_dropout_prob,
1031 device,
1032 })
1033 }
1034
1035 pub fn device(&self) -> Device {
1036 self.device
1037 }
1038
1039 pub fn parameter_count(&self) -> usize {
1040 let mut count = self.word_embeddings.parameter_count();
1041 if let Some(ln) = &self.layer_norm {
1042 count += ln.parameter_count();
1043 }
1044 count
1045 }
1046}
1047
1048impl Layer for RetNetEmbeddings {
1049 type Input = Vec<u32>;
1050 type Output = Tensor;
1051
1052 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
1053 let mut embeddings = self.word_embeddings.forward(input)?;
1054
1055 if let Some(ref ln) = self.layer_norm {
1057 embeddings = ln.forward(embeddings)?;
1058 }
1059
1060 Ok(embeddings)
1062 }
1063}
1064
/// Decoder-only RetNet backbone: embeddings → N decoder layers → final norm.
pub struct RetNetModel {
    config: RetNetConfig,
    embeddings: RetNetEmbeddings,
    layers: Vec<RetNetDecoderLayer>,
    final_norm: LayerNorm,
    device: Device,
}
1073
1074impl RetNetModel {
1075 pub fn new(config: RetNetConfig) -> Result<Self> {
1076 Self::new_with_device(config, Device::CPU)
1077 }
1078
1079 pub fn new_with_device(config: RetNetConfig, device: Device) -> Result<Self> {
1080 config.validate()?;
1081
1082 let embeddings = RetNetEmbeddings::new_with_device(&config, device)?;
1083
1084 let mut layers = Vec::new();
1085 for _ in 0..config.num_hidden_layers {
1086 layers.push(RetNetDecoderLayer::new_with_device(&config, device)?);
1087 }
1088
1089 let final_norm =
1090 LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
1091
1092 Ok(Self {
1093 config,
1094 embeddings,
1095 layers,
1096 final_norm,
1097 device,
1098 })
1099 }
1100
1101 pub fn device(&self) -> Device {
1102 self.device
1103 }
1104}
1105
1106impl Model for RetNetModel {
1107 type Config = RetNetConfig;
1108 type Input = Vec<u32>;
1109 type Output = Tensor;
1110
1111 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
1112 let mut hidden_states = self.embeddings.forward(input)?;
1113
1114 for layer in &self.layers {
1115 hidden_states = layer.forward(hidden_states)?;
1116 }
1117
1118 self.final_norm.forward(hidden_states)
1119 }
1120
1121 fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
1122 Ok(())
1123 }
1124
1125 fn get_config(&self) -> &Self::Config {
1126 &self.config
1127 }
1128
1129 fn num_parameters(&self) -> usize {
1130 let mut total = 0;
1131
1132 total += self.embeddings.parameter_count();
1134
1135 for layer in &self.layers {
1137 total += layer.parameter_count();
1138 }
1139
1140 total += self.final_norm.parameter_count();
1142
1143 total
1144 }
1145}
1146
/// Generation strategies a RetNet-backed model can offer.
pub trait RetNetGeneration {
    /// Autoregressive sampling with temperature, nucleus (`top_p`) and
    /// optional `top_k` filtering, intended for the recurrent decode path.
    fn generate_recurrent(
        &self,
        input_ids: Vec<u32>,
        max_length: usize,
        temperature: f32,
        top_p: f32,
        top_k: Option<u32>,
    ) -> Result<Vec<u32>>;

    /// Beam search returning one candidate sequence per beam.
    fn generate_beam_search(
        &self,
        input_ids: Vec<u32>,
        max_length: usize,
        num_beams: usize,
        early_stopping: bool,
    ) -> Result<Vec<Vec<u32>>>;

    /// Streaming generation. `callback` receives the tokens produced so far
    /// and returns a bool — presumably whether to keep generating; confirm
    /// with implementors, as no implementation is visible here.
    fn generate_stream<F>(&self, input_ids: Vec<u32>, max_length: usize, callback: F) -> Result<()>
    where
        F: Fn(&[u32]) -> bool;
}
1173
/// Chunked driver around [`RetNetModel`] for sequences longer than a single
/// forward pass comfortably handles.
pub struct RetNetLongSequence {
    model: RetNetModel,
    // Window length in tokens.
    chunk_size: usize,
    // Tokens shared between consecutive windows.
    overlap_size: usize,
    // NOTE(review): sized at construction but never consulted by
    // `process_long_sequence` — confirm whether it is still needed.
    state_cache: RetNetStateCache,
    device: Device,
}
1182
1183impl RetNetLongSequence {
1184 pub fn new(config: RetNetConfig, chunk_size: usize) -> Result<Self> {
1185 Self::new_with_device(config, chunk_size, Device::CPU)
1186 }
1187
1188 pub fn new_with_device(
1189 config: RetNetConfig,
1190 chunk_size: usize,
1191 device: Device,
1192 ) -> Result<Self> {
1193 let model = RetNetModel::new_with_device(config.clone(), device)?;
1194 let overlap_size = chunk_size / 4; let state_cache = RetNetStateCache::new(config.num_hidden_layers * 4);
1196
1197 Ok(Self {
1198 model,
1199 chunk_size,
1200 overlap_size,
1201 state_cache,
1202 device,
1203 })
1204 }
1205
1206 pub fn device(&self) -> Device {
1207 self.device
1208 }
1209
1210 pub fn process_long_sequence(&mut self, input: Vec<u32>) -> Result<Tensor> {
1212 let seq_len = input.len();
1213
1214 if seq_len <= self.chunk_size {
1215 return self.model.forward(input);
1216 }
1217
1218 let mut all_outputs = Vec::new();
1219 let effective_step = self.chunk_size - self.overlap_size;
1220
1221 for start in (0..seq_len).step_by(effective_step) {
1222 let end = std::cmp::min(start + self.chunk_size, seq_len);
1223 let chunk = input[start..end].to_vec();
1224
1225 let chunk_output = self.model.forward(chunk)?;
1226
1227 let output_start = if start == 0 { 0 } else { self.overlap_size };
1229 let chunk_seq_len = chunk_output.shape()[1];
1230
1231 if chunk_seq_len > output_start {
1232 let trimmed_output = chunk_output.slice_ranges(&[
1233 (0, chunk_output.shape()[0]),
1234 (output_start, chunk_seq_len),
1235 (0, chunk_output.shape()[2]),
1236 ])?;
1237 all_outputs.push(trimmed_output);
1238 }
1239 }
1240
1241 self.concatenate_outputs(all_outputs)
1242 }
1243
1244 fn concatenate_outputs(&self, outputs: Vec<Tensor>) -> Result<Tensor> {
1245 if outputs.is_empty() {
1246 return Err(tensor_op_error(
1247 "tensor_operation",
1248 "No outputs to concatenate".to_string(),
1249 ));
1250 }
1251
1252 let batch_size = outputs[0].shape()[0];
1253 let hidden_size = outputs[0].shape()[2];
1254 let total_seq_len: usize = outputs.iter().map(|o| o.shape()[1]).sum();
1255
1256 let mut result = Tensor::zeros(&[batch_size, total_seq_len, hidden_size])?
1257 .to_device_enum(&self.device)?;
1258 let mut offset = 0;
1259
1260 for output in outputs {
1261 let seq_len = output.shape()[1];
1262
1263 for b in 0..batch_size {
1264 for s in 0..seq_len {
1265 for h in 0..hidden_size {
1266 let val = output.get_scalar(&[b, s, h])?;
1267 result = result.set_scalar(&[b, offset + s, h], val)?;
1268 }
1269 }
1270 }
1271
1272 offset += seq_len;
1273 }
1274
1275 Ok(result)
1276 }
1277
1278 pub fn get_memory_stats(&self) -> RetNetMemoryStats {
1280 RetNetMemoryStats {
1281 cache_size: self.state_cache.size(),
1282 max_cache_size: self.state_cache.max_cache_size,
1283 chunk_size: self.chunk_size,
1284 overlap_size: self.overlap_size,
1285 estimated_memory_mb: self.estimate_memory_usage(),
1286 }
1287 }
1288
1289 fn estimate_memory_usage(&self) -> f64 {
1290 let config = self.model.get_config();
1291 let params = self.model.num_parameters() as f64;
1292 let state_memory =
1293 (self.state_cache.size() * config.hidden_size * config.hidden_size * 4) as f64; let chunk_memory = (self.chunk_size * config.hidden_size * 4) as f64;
1295
1296 (params * 4.0 + state_memory + chunk_memory) / (1024.0 * 1024.0) }
1298}
1299
/// Point-in-time memory/configuration report produced by `get_memory_stats`.
#[derive(Debug, Clone)]
pub struct RetNetMemoryStats {
    /// Number of recurrent states currently held in the state cache.
    pub cache_size: usize,
    /// Maximum number of states the cache will retain.
    pub max_cache_size: usize,
    /// Sequence length processed per chunk during chunked inference.
    pub chunk_size: usize,
    /// Overlap (in tokens) between consecutive chunks.
    pub overlap_size: usize,
    /// Estimated total memory footprint in megabytes (parameters + cached
    /// states + one chunk buffer).
    pub estimated_memory_mb: f64,
}
1309
/// RetNet backbone with an optional vocabulary-projection head for
/// language modeling.
pub struct RetNetForLanguageModeling {
    /// Underlying RetNet encoder.
    retnet: RetNetModel,
    /// Projection from hidden states to vocab logits; `None` when the
    /// config sets `no_output_layer` (forward then returns raw hidden states).
    lm_head: Option<Linear>,
    /// Device all sub-modules were constructed on.
    device: Device,
}
1316
1317impl RetNetForLanguageModeling {
1318 pub fn new(config: RetNetConfig) -> Result<Self> {
1319 Self::new_with_device(config, Device::CPU)
1320 }
1321
1322 pub fn new_with_device(config: RetNetConfig, device: Device) -> Result<Self> {
1323 let retnet = RetNetModel::new_with_device(config.clone(), device)?;
1324
1325 let lm_head = if !config.no_output_layer {
1326 Some(Linear::new_with_device(
1327 config.hidden_size,
1328 config.vocab_size,
1329 false,
1330 device,
1331 ))
1332 } else {
1333 None
1334 };
1335
1336 Ok(Self {
1337 retnet,
1338 lm_head,
1339 device,
1340 })
1341 }
1342
1343 pub fn device(&self) -> Device {
1344 self.device
1345 }
1346}
1347
1348impl Model for RetNetForLanguageModeling {
1349 type Config = RetNetConfig;
1350 type Input = Vec<u32>;
1351 type Output = Tensor;
1352
1353 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
1354 let hidden_states = self.retnet.forward(input)?;
1355
1356 if let Some(ref lm_head) = self.lm_head {
1357 lm_head.forward(hidden_states)
1358 } else {
1359 Ok(hidden_states)
1360 }
1361 }
1362
1363 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
1364 self.retnet.load_pretrained(reader)
1365 }
1366
1367 fn get_config(&self) -> &Self::Config {
1368 self.retnet.get_config()
1369 }
1370
1371 fn num_parameters(&self) -> usize {
1372 let mut total = self.retnet.num_parameters();
1373 if let Some(ref lm_head) = self.lm_head {
1374 total += lm_head.parameter_count();
1375 }
1376 total
1377 }
1378}
1379
/// RetNet backbone with a linear classification head that scores the
/// hidden state of the final sequence position.
pub struct RetNetForSequenceClassification {
    /// Underlying RetNet encoder.
    retnet: RetNetModel,
    /// Biased `hidden_size -> num_labels` projection applied to the pooled
    /// last-token representation.
    classifier: Linear,
    /// Number of output classes (retained for introspection; the classifier
    /// already encodes it).
    #[allow(dead_code)]
    num_labels: usize,
    /// Device all sub-modules were constructed on.
    device: Device,
}
1388
1389impl RetNetForSequenceClassification {
1390 pub fn new(config: RetNetConfig, num_labels: usize) -> Result<Self> {
1391 Self::new_with_device(config, num_labels, Device::CPU)
1392 }
1393
1394 pub fn new_with_device(
1395 config: RetNetConfig,
1396 num_labels: usize,
1397 device: Device,
1398 ) -> Result<Self> {
1399 let retnet = RetNetModel::new_with_device(config.clone(), device)?;
1400 let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
1401
1402 Ok(Self {
1403 retnet,
1404 classifier,
1405 num_labels,
1406 device,
1407 })
1408 }
1409
1410 pub fn device(&self) -> Device {
1411 self.device
1412 }
1413}
1414
1415impl Model for RetNetForSequenceClassification {
1416 type Config = RetNetConfig;
1417 type Input = Vec<u32>;
1418 type Output = Tensor;
1419
1420 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
1421 let sequence_output = self.retnet.forward(input)?;
1422
1423 let last_token = self.get_last_token(&sequence_output)?;
1425 self.classifier.forward(last_token)
1426 }
1427
1428 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
1429 self.retnet.load_pretrained(reader)
1430 }
1431
1432 fn get_config(&self) -> &Self::Config {
1433 self.retnet.get_config()
1434 }
1435
1436 fn num_parameters(&self) -> usize {
1437 self.retnet.num_parameters() + self.classifier.parameter_count()
1438 }
1439}
1440
1441impl RetNetForSequenceClassification {
1442 fn get_last_token(&self, x: &Tensor) -> Result<Tensor> {
1443 let batch_size = x.shape()[0];
1444 let seq_len = x.shape()[1];
1445 let hidden_size = x.shape()[2];
1446
1447 let mut last_tokens =
1449 Tensor::zeros(&[batch_size, hidden_size])?.to_device_enum(&self.device)?;
1450
1451 for b in 0..batch_size {
1452 for h in 0..hidden_size {
1453 let val = x.get_scalar(&[b, seq_len - 1, h])?;
1454 last_tokens = last_tokens.set_scalar(&[b, h], val)?;
1455 }
1456 }
1457
1458 Ok(last_tokens)
1459 }
1460}