// trustformers_models/retnet/model.rs
1use crate::retnet::config::RetNetConfig;
2use std::io::Read;
3use trustformers_core::{
4    device::Device,
5    errors::{tensor_op_error, Result},
6    layers::{Embedding, LayerNorm, Linear},
7    tensor::Tensor,
8    traits::{Config, Layer, Model},
9};
10
/// Rotary Position Embedding for RetNet
///
/// Precomputes per-frequency inverse frequencies and builds cos/sin
/// vectors for a given position to rotate query/key tensors.
pub struct RotaryPositionEmbedding {
    // Rotation dimension; assumed even (inv_freq holds dim/2 entries).
    dim: usize,
    #[allow(dead_code)]
    max_seq_len: usize,
    #[allow(dead_code)]
    base: f32,
    // Shape [dim / 2]; inv_freq[i] = 1 / base^(2i / dim).
    inv_freq: Tensor,
    // Device all derived tensors are placed on.
    device: Device,
}
21
22impl RotaryPositionEmbedding {
23    pub fn new(dim: usize, max_seq_len: usize, base: f32) -> Result<Self> {
24        Self::new_with_device(dim, max_seq_len, base, Device::CPU)
25    }
26
27    pub fn new_with_device(
28        dim: usize,
29        max_seq_len: usize,
30        base: f32,
31        device: Device,
32    ) -> Result<Self> {
33        let mut inv_freq_vec = Vec::new();
34        for i in (0..dim).step_by(2) {
35            let freq = 1.0 / base.powf(i as f32 / dim as f32);
36            inv_freq_vec.push(freq);
37        }
38
39        let inv_freq = Tensor::from_vec(inv_freq_vec, &[dim / 2])?.to_device_enum(&device)?;
40
41        Ok(Self {
42            dim,
43            max_seq_len,
44            base,
45            inv_freq,
46            device,
47        })
48    }
49
50    pub fn device(&self) -> Device {
51        self.device
52    }
53
54    /// Apply rotary position embedding to query and key tensors
55    pub fn apply_rotary_pos_emb(
56        &self,
57        q: &Tensor,
58        k: &Tensor,
59        position: usize,
60    ) -> Result<(Tensor, Tensor)> {
61        let cos_sin = self.get_cos_sin(position)?;
62        let cos_emb = &cos_sin.0;
63        let sin_emb = &cos_sin.1;
64
65        let q_rot = self.rotate_half(q)?;
66        let k_rot = self.rotate_half(k)?;
67
68        let q_embed = q.mul(cos_emb)?.add(&q_rot.mul(sin_emb)?)?;
69        let k_embed = k.mul(cos_emb)?.add(&k_rot.mul(sin_emb)?)?;
70
71        Ok((q_embed, k_embed))
72    }
73
74    fn get_cos_sin(&self, position: usize) -> Result<(Tensor, Tensor)> {
75        let pos = position as f32;
76        let mut cos_vals = Vec::new();
77        let mut sin_vals = Vec::new();
78
79        for i in 0..self.dim / 2 {
80            let freq = self.inv_freq.get_scalar(&[i])?;
81            let angle = pos * freq;
82            cos_vals.push(angle.cos());
83            cos_vals.push(angle.cos()); // Duplicate for even/odd pairing
84            sin_vals.push(angle.sin());
85            sin_vals.push(angle.sin()); // Duplicate for even/odd pairing
86        }
87
88        let cos_emb = Tensor::from_vec(cos_vals, &[self.dim])?.to_device_enum(&self.device)?;
89        let sin_emb = Tensor::from_vec(sin_vals, &[self.dim])?.to_device_enum(&self.device)?;
90
91        Ok((cos_emb, sin_emb))
92    }
93
94    fn rotate_half(&self, x: &Tensor) -> Result<Tensor> {
95        let shape = x.shape();
96        let last_dim = shape[shape.len() - 1];
97        let half_dim = last_dim / 2;
98
99        // Split tensor into two halves
100        let x1_ranges: Vec<_> = (0..shape.len() - 1).map(|i| (0, shape[i])).collect();
101        let mut x1_ranges = x1_ranges;
102        x1_ranges.push((0, half_dim));
103
104        let mut x2_ranges: Vec<_> = (0..shape.len() - 1).map(|i| (0, shape[i])).collect();
105        x2_ranges.push((half_dim, last_dim));
106
107        let x1 = x.slice_ranges(&x1_ranges)?;
108        let x2 = x.slice_ranges(&x2_ranges)?;
109
110        // Concatenate [-x2, x1]
111        let neg_x2 = x2.mul_scalar(-1.0)?;
112        self.concatenate_last_dim(&neg_x2, &x1)
113    }
114
115    fn concatenate_last_dim(&self, x1: &Tensor, x2: &Tensor) -> Result<Tensor> {
116        let shape1 = x1.shape();
117        let shape2 = x2.shape();
118
119        let mut result_shape = shape1.to_vec();
120        let last_idx = result_shape.len() - 1;
121        result_shape[last_idx] = shape1[shape1.len() - 1] + shape2[shape2.len() - 1];
122
123        let _result = Tensor::zeros(&result_shape)?;
124
125        // This is a simplified concatenation - in practice would use more efficient tensor ops
126        // For now, return x1 as placeholder
127        Ok(x1.clone())
128    }
129}
130
/// Advanced chunk processor for long sequences
///
/// Splits a [batch, seq, hidden] sequence into overlapping windows and
/// threads a per-chunk state through a caller-supplied processor closure.
pub struct AdvancedChunkProcessor {
    // Window length in timesteps.
    chunk_size: usize,
    // Timesteps shared between consecutive windows (trimmed from output).
    overlap_size: usize,
    // When true, chunks are routed through `checkpoint_forward`.
    use_gradient_checkpointing: bool,
}
137
138impl AdvancedChunkProcessor {
139    pub fn new(chunk_size: usize, overlap_size: usize, use_gradient_checkpointing: bool) -> Self {
140        Self {
141            chunk_size,
142            overlap_size,
143            use_gradient_checkpointing,
144        }
145    }
146
147    /// Process sequence in overlapping chunks with state management
148    pub fn process_chunks<F>(&self, sequence: &Tensor, mut processor: F) -> Result<Tensor>
149    where
150        F: FnMut(&Tensor, Option<&Tensor>) -> Result<(Tensor, Tensor)>,
151    {
152        let seq_len = sequence.shape()[1];
153        let batch_size = sequence.shape()[0];
154        let hidden_size = sequence.shape()[2];
155
156        if seq_len <= self.chunk_size {
157            let (output, _) = processor(sequence, None)?;
158            return Ok(output);
159        }
160
161        let mut chunks = Vec::new();
162        let mut state = None;
163        let effective_step = self.chunk_size - self.overlap_size;
164
165        for start in (0..seq_len).step_by(effective_step) {
166            let end = std::cmp::min(start + self.chunk_size, seq_len);
167            let chunk =
168                sequence.slice_ranges(&[(0, batch_size), (start, end), (0, hidden_size)])?;
169
170            let (chunk_output, new_state) = if self.use_gradient_checkpointing {
171                self.checkpoint_forward(&chunk, state.as_ref(), &mut processor)?
172            } else {
173                processor(&chunk, state.as_ref())?
174            };
175
176            // Remove overlap from previous chunks
177            let output_start = if start == 0 { 0 } else { self.overlap_size };
178            let output_end = chunk_output.shape()[1];
179
180            if output_end > output_start {
181                let trimmed_output = chunk_output.slice_ranges(&[
182                    (0, batch_size),
183                    (output_start, output_end),
184                    (0, hidden_size),
185                ])?;
186                chunks.push(trimmed_output);
187            }
188
189            state = Some(new_state);
190        }
191
192        self.concatenate_chunks(chunks)
193    }
194
195    /// Gradient checkpointing for memory efficiency
196    fn checkpoint_forward<F>(
197        &self,
198        chunk: &Tensor,
199        state: Option<&Tensor>,
200        processor: &mut F,
201    ) -> Result<(Tensor, Tensor)>
202    where
203        F: FnMut(&Tensor, Option<&Tensor>) -> Result<(Tensor, Tensor)>,
204    {
205        // In a real implementation, this would use gradient checkpointing
206        // For now, just call the processor directly
207        processor(chunk, state)
208    }
209
210    fn concatenate_chunks(&self, chunks: Vec<Tensor>) -> Result<Tensor> {
211        if chunks.is_empty() {
212            return Err(tensor_op_error(
213                "tensor_operation",
214                "No chunks to concatenate".to_string(),
215            ));
216        }
217
218        let batch_size = chunks[0].shape()[0];
219        let hidden_size = chunks[0].shape()[2];
220        let total_seq_len: usize = chunks.iter().map(|c| c.shape()[1]).sum();
221
222        // Infer device from first chunk
223        let device = chunks[0].device();
224        let mut result =
225            Tensor::zeros(&[batch_size, total_seq_len, hidden_size])?.to_device(&device)?;
226        let mut offset = 0;
227
228        for chunk in chunks {
229            let chunk_seq_len = chunk.shape()[1];
230
231            for b in 0..batch_size {
232                for s in 0..chunk_seq_len {
233                    for h in 0..hidden_size {
234                        let val = chunk.get_scalar(&[b, s, h])?;
235                        result = result.set_scalar(&[b, offset + s, h], val)?;
236                    }
237                }
238            }
239
240            offset += chunk_seq_len;
241        }
242
243        Ok(result)
244    }
245}
246
/// Memory-efficient RetNet state cache
///
/// Maps a layer index to its recurrent retention state, evicting entries
/// when capacity is exceeded.
pub struct RetNetStateCache {
    // layer_idx -> cached state tensor.
    states: std::collections::HashMap<usize, Tensor>,
    // Maximum number of cached entries before eviction kicks in.
    max_cache_size: usize,
    // Bookkeeping count of entries (mirrors states.len()).
    current_size: usize,
}
253
254impl RetNetStateCache {
255    pub fn new(max_cache_size: usize) -> Self {
256        Self {
257            states: std::collections::HashMap::new(),
258            max_cache_size,
259            current_size: 0,
260        }
261    }
262
263    pub fn get_state(&self, layer_idx: usize) -> Option<&Tensor> {
264        self.states.get(&layer_idx)
265    }
266
267    pub fn set_state(&mut self, layer_idx: usize, state: Tensor) -> Result<()> {
268        // Simple eviction policy - remove oldest entries
269        while self.current_size >= self.max_cache_size && !self.states.is_empty() {
270            let oldest_key = *self.states.keys().next().expect("operation failed");
271            self.states.remove(&oldest_key);
272            self.current_size -= 1;
273        }
274
275        self.states.insert(layer_idx, state);
276        self.current_size += 1;
277        Ok(())
278    }
279
280    pub fn clear(&mut self) {
281        self.states.clear();
282        self.current_size = 0;
283    }
284
285    pub fn size(&self) -> usize {
286        self.current_size
287    }
288}
289
290/// Multi-scale retention mechanism
291pub struct MultiScaleRetention {
292    num_heads: usize,
293    head_dim: usize,
294    #[allow(dead_code)]
295    hidden_size: usize,
296
297    // Projections
298    q_proj: Linear,
299    k_proj: Linear,
300    v_proj: Linear,
301    g_proj: Linear, // Gate projection
302    out_proj: Linear,
303
304    // Retention parameters
305    gamma: Vec<f32>, // Decay factors for each head
306    #[allow(dead_code)]
307    dropout: f32,
308    #[allow(dead_code)]
309    value_factor: f32,
310
311    // Advanced features
312    #[allow(dead_code)]
313    pos_emb: Option<RotaryPositionEmbedding>,
314    chunk_processor: Option<AdvancedChunkProcessor>,
315    state_cache: Option<RetNetStateCache>,
316    #[allow(dead_code)]
317    use_memory_efficient_attention: bool,
318    device: Device,
319}
320
321impl MultiScaleRetention {
    /// Build retention on the CPU device.
    pub fn new(config: &RetNetConfig) -> Result<Self> {
        Self::new_with_device(config, Device::CPU)
    }

    /// Build retention with all projections placed on `device`.
    ///
    /// Q/K project to `retention_dim`, while V, the gate, and the output
    /// projection stay at `hidden_size`.
    pub fn new_with_device(config: &RetNetConfig, device: Device) -> Result<Self> {
        let head_dim = config.retention_head_dim();
        let retention_dim = config.retention_dim();

        let q_proj =
            Linear::new_with_device(config.hidden_size, retention_dim, config.use_bias, device);
        let k_proj =
            Linear::new_with_device(config.hidden_size, retention_dim, config.use_bias, device);
        let v_proj = Linear::new_with_device(
            config.hidden_size,
            config.hidden_size,
            config.use_bias,
            device,
        );
        let g_proj = Linear::new_with_device(
            config.hidden_size,
            config.hidden_size,
            config.use_bias,
            device,
        );
        let out_proj = Linear::new_with_device(
            config.hidden_size,
            config.hidden_size,
            config.use_bias,
            device,
        );

        // Initialize decay factors for multi-scale retention:
        // gamma[i] = 1 - 2^-(5 + i), giving each head a different
        // effective memory length (the RetNet paper's schedule).
        let mut gamma = Vec::new();
        for i in 0..config.retention_heads {
            // Different decay rates for different heads
            let decay = 1.0 - 2.0_f32.powf(-(5.0 + i as f32));
            gamma.push(decay);
        }

        // Rotary embeddings are only built when a positive position budget
        // is configured; base 10000.0 is the conventional RoPE default.
        let pos_emb = if config.max_position_embeddings > 0 {
            Some(RotaryPositionEmbedding::new_with_device(
                head_dim,
                config.max_position_embeddings,
                10000.0,
                device,
            )?)
        } else {
            None
        };

        let chunk_processor = if config.uses_chunking() {
            Some(AdvancedChunkProcessor::new(
                config.chunk_size,
                config.chunk_size / 4, // 25% overlap
                config.deepnorm,       // Use gradient checkpointing with deepnorm
            ))
        } else {
            None
        };

        // Cache sized at two entries per hidden layer.
        let state_cache = Some(RetNetStateCache::new(config.num_hidden_layers * 2));

        Ok(Self {
            num_heads: config.retention_heads,
            head_dim,
            hidden_size: config.hidden_size,
            q_proj,
            k_proj,
            v_proj,
            g_proj,
            out_proj,
            gamma,
            dropout: config.attention_dropout,
            value_factor: config.value_factor,
            pos_emb,
            chunk_processor,
            state_cache,
            use_memory_efficient_attention: config.sequence_parallel,
            device,
        })
    }
404
405    pub fn device(&self) -> Device {
406        self.device
407    }
408
409    /// Set inference mode for recurrent processing
410    pub fn set_inference_mode(&mut self, cache_size: Option<usize>) {
411        if let Some(size) = cache_size {
412            self.state_cache = Some(RetNetStateCache::new(size));
413        }
414    }
415
416    /// Clear all cached states
417    pub fn clear_cache(&mut self) {
418        if let Some(ref mut cache) = self.state_cache {
419            cache.clear();
420        }
421    }
422
423    /// Process with memory-efficient chunking for long sequences
    /// Process with memory-efficient chunking for long sequences.
    ///
    /// When a chunk processor is configured, runs the full retention path
    /// (Q/K/V/G projections, parallel retention, SiLU gating, output
    /// projection) independently per chunk; otherwise falls back to the
    /// plain `forward`. The returned state is a zeros placeholder — no
    /// retention state actually crosses chunk boundaries here, and
    /// `_layer_idx` is currently unused (a real cache would need `&mut self`).
    pub fn forward_chunked(&self, input: &Tensor, _layer_idx: usize) -> Result<Tensor> {
        if let Some(ref processor) = self.chunk_processor {
            let _cache_ref: Option<()> = None; // Would need mutable access to self for real cache

            processor.process_chunks(input, |chunk, _state| {
                let q = self.q_proj.forward(chunk.clone())?;
                let k = self.k_proj.forward(chunk.clone())?;
                let v = self.v_proj.forward(chunk.clone())?;
                let g = self.g_proj.forward(chunk.clone())?;

                // Gate uses SiLU, matching the standard forward path.
                let g_activated = g.silu()?;
                let retention_output = self.parallel_retention(&q, &k, &v)?;
                let gated_output = retention_output.mul(&g_activated)?;
                let output = self.out_proj.forward(gated_output)?;

                // Create dummy state for compatibility
                let state = Tensor::zeros(&[1, self.num_heads, self.head_dim, self.head_dim])?
                    .to_device_enum(&self.device)?;
                Ok((output, state))
            })
        } else {
            // Fallback to standard forward
            self.forward(input.clone())
        }
    }
449
450    /// Parallel retention computation
    /// Parallel retention computation.
    ///
    /// Splits Q/K/V into heads, runs `compute_retention` per head with that
    /// head's decay factor, and scatters the results back into a
    /// `[batch, heads, seq, head_dim]` buffer before merging heads.
    fn parallel_retention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
        let batch_size = q.shape()[0];
        let seq_len = q.shape()[1];
        let num_heads = self.num_heads;
        let head_dim = self.head_dim;

        // Reshape [batch, seq, hidden] -> [batch, heads, seq, per-head dim].
        let q_heads = self.reshape_for_heads(q)?;
        let k_heads = self.reshape_for_heads(k)?;
        let v_heads = self.reshape_for_heads(v)?;

        let mut output = Tensor::zeros(&[batch_size, num_heads, seq_len, head_dim])?
            .to_device_enum(&self.device)?;

        // Apply retention for each head
        for h in 0..num_heads {
            let gamma_h = self.gamma[h];
            let q_h = q_heads.slice_ranges(&[
                (0, batch_size),
                (h, h + 1),
                (0, seq_len),
                (0, head_dim),
            ])?;
            let k_h = k_heads.slice_ranges(&[
                (0, batch_size),
                (h, h + 1),
                (0, seq_len),
                (0, head_dim),
            ])?;
            // NOTE(review): q/k slice (0, head_dim) but v slices
            // (head_dim*2, head_dim*3) — this is out of range unless the
            // per-head value dimension is at least 3*head_dim. Possibly
            // related to value_factor; confirm against reshape_for_heads
            // output shape — looks like a bug.
            let v_h = v_heads.slice_ranges(&[
                (0, batch_size),
                (h, h + 1),
                (0, seq_len),
                (head_dim * 2, head_dim * 3),
            ])?;

            let retention_output = self.compute_retention(&q_h, &k_h, &v_h, gamma_h)?;

            // Scalar-wise copy of the head result into the shared output
            // buffer (O(batch * seq * head_dim) accessor calls per head).
            for b in 0..batch_size {
                for s in 0..seq_len {
                    for d in 0..head_dim {
                        let val = retention_output.get_scalar(&[b, 0, s, d])?;
                        output = output.set_scalar(&[b, h, s, d], val)?;
                    }
                }
            }
        }

        // Merge heads back to [batch, seq, heads * head_dim].
        self.reshape_from_heads(&output)
    }
503
504    /// Compute retention for a single head
    /// Compute retention for a single head.
    ///
    /// Sequential scan over the sequence with the recurrence
    /// `S_i = gamma * S_{i-1} + k_i^T v_i` and output `o_i = q_i S_i`,
    /// giving O(seq_len) state updates per batch element. Inputs are
    /// `[batch, 1, seq, head_dim]` slices; output has the same shape.
    fn compute_retention(&self, q: &Tensor, k: &Tensor, v: &Tensor, gamma: f32) -> Result<Tensor> {
        let batch_size = q.shape()[0];
        let seq_len = q.shape()[2];
        let head_dim = q.shape()[3];

        let mut output =
            Tensor::zeros(&[batch_size, 1, seq_len, head_dim])?.to_device_enum(&self.device)?;

        // Retention computation: O(n) complexity
        for b in 0..batch_size {
            // Fresh [head_dim, head_dim] state per batch element.
            let mut state = Tensor::zeros(&[head_dim, head_dim])?.to_device_enum(&self.device)?;

            for i in 0..seq_len {
                // Get query, key, value for position i
                let q_i = q.slice_ranges(&[(b, b + 1), (0, 1), (i, i + 1), (0, head_dim)])?;
                let k_i = k.slice_ranges(&[(b, b + 1), (0, 1), (i, i + 1), (0, head_dim)])?;
                let v_i = v.slice_ranges(&[(b, b + 1), (0, 1), (i, i + 1), (0, head_dim)])?;

                // Update state: S_i = gamma * S_{i-1} + k_i^T @ v_i
                state = state.mul_scalar(gamma)?;
                let k_i_flat = k_i.reshape(&[head_dim, 1])?;
                let v_i_flat = v_i.reshape(&[1, head_dim])?;
                let outer_product = k_i_flat.matmul(&v_i_flat)?;
                state = state.add(&outer_product)?;

                // Compute output: o_i = q_i @ S_i
                let q_i_flat = q_i.reshape(&[1, head_dim])?;
                let o_i = q_i_flat.matmul(&state)?;
                let o_i_reshaped = o_i.reshape(&[1, 1, 1, head_dim])?;

                // Scatter position i's result into the output buffer.
                for d in 0..head_dim {
                    let val = o_i_reshaped.get_scalar(&[0, 0, 0, d])?;
                    output = output.set_scalar(&[b, 0, i, d], val)?;
                }
            }
        }

        Ok(output)
    }
545
546    /// Recurrent retention computation (for inference)
547    #[allow(dead_code)]
548    fn recurrent_retention(
549        &self,
550        q: &Tensor,
551        k: &Tensor,
552        v: &Tensor,
553        prev_state: Option<&Tensor>,
554    ) -> Result<(Tensor, Tensor)> {
555        let batch_size = q.shape()[0];
556        let seq_len = q.shape()[1];
557
558        // For recurrent mode, seq_len should be 1 (single token)
559        if seq_len != 1 {
560            return self.parallel_retention(q, k, v).map(|out| {
561                let state =
562                    Tensor::zeros(&[batch_size, self.num_heads, self.head_dim, self.head_dim])?
563                        .to_device_enum(&self.device)?;
564                Ok((out, state))
565            })?;
566        }
567
568        let q_heads = self.reshape_for_heads(q)?;
569        let k_heads = self.reshape_for_heads(k)?;
570        let v_heads = self.reshape_for_heads(v)?;
571
572        let mut output = Tensor::zeros(&[batch_size, self.num_heads, 1, self.head_dim])?
573            .to_device_enum(&self.device)?;
574        let mut new_states = Vec::new();
575
576        for h in 0..self.num_heads {
577            let gamma_h = self.gamma[h];
578
579            // Get current head's q, k, v
580            let q_h =
581                q_heads.slice_ranges(&[(0, batch_size), (h, h + 1), (0, 1), (0, self.head_dim)])?;
582            let k_h =
583                k_heads.slice_ranges(&[(0, batch_size), (h, h + 1), (0, 1), (0, self.head_dim)])?;
584            let v_h = v_heads.slice_ranges(&[
585                (0, batch_size),
586                (h, h + 1),
587                (0, 1),
588                (self.head_dim * 2, self.head_dim * 3),
589            ])?;
590
591            // Get or initialize previous state for this head
592            let prev_state_h = if let Some(prev) = prev_state {
593                prev.slice_ranges(&[
594                    (0, batch_size),
595                    (h, h + 1),
596                    (0, self.head_dim),
597                    (0, self.head_dim),
598                ])?
599            } else {
600                Tensor::zeros(&[batch_size, 1, self.head_dim, self.head_dim])?
601                    .to_device_enum(&self.device)?
602            };
603
604            // Update state: S_t = gamma * S_{t-1} + k_t^T @ v_t
605            let mut new_state_h = prev_state_h.mul_scalar(gamma_h)?;
606
607            for b in 0..batch_size {
608                let k_b = k_h
609                    .slice_ranges(&[(b, b + 1), (0, 1), (0, 1), (0, self.head_dim)])?
610                    .reshape(&[self.head_dim, 1])?;
611                let v_b = v_h
612                    .slice_ranges(&[(b, b + 1), (0, 1), (0, 1), (0, self.head_dim)])?
613                    .reshape(&[1, self.head_dim])?;
614                let outer = k_b.matmul(&v_b)?;
615
616                let prev_state_b = new_state_h
617                    .slice_ranges(&[(b, b + 1), (0, 1), (0, self.head_dim), (0, self.head_dim)])?
618                    .reshape(&[self.head_dim, self.head_dim])?;
619                let updated_state = prev_state_b.add(&outer)?;
620
621                // Update the state tensor
622                for i in 0..self.head_dim {
623                    for j in 0..self.head_dim {
624                        let val = updated_state.get_scalar(&[i, j])?;
625                        new_state_h = new_state_h.set_scalar(&[b, 0, i, j], val)?;
626                    }
627                }
628
629                // Compute output: o_t = q_t @ S_t
630                let q_b = q_h
631                    .slice_ranges(&[(b, b + 1), (0, 1), (0, 1), (0, self.head_dim)])?
632                    .reshape(&[1, self.head_dim])?;
633                let out_b = q_b.matmul(&updated_state)?;
634
635                // Set output for this batch and head
636                for d in 0..self.head_dim {
637                    let val = out_b.get_scalar(&[0, d])?;
638                    output = output.set_scalar(&[b, h, 0, d], val)?;
639                }
640            }
641
642            new_states.push(new_state_h);
643        }
644
645        // Concatenate all head states
646        let new_state = self.concatenate_states(new_states)?;
647        let final_output = self.reshape_from_heads(&output)?;
648
649        Ok((final_output, new_state))
650    }
651
652    /// Concatenate states from all heads
653    fn concatenate_states(&self, states: Vec<Tensor>) -> Result<Tensor> {
654        let batch_size = states[0].shape()[0];
655        let mut result =
656            Tensor::zeros(&[batch_size, self.num_heads, self.head_dim, self.head_dim])?
657                .to_device_enum(&self.device)?;
658
659        for (h, state) in states.iter().enumerate() {
660            for b in 0..batch_size {
661                for i in 0..self.head_dim {
662                    for j in 0..self.head_dim {
663                        let val = state.get_scalar(&[b, 0, i, j])?;
664                        result = result.set_scalar(&[b, h, i, j], val)?;
665                    }
666                }
667            }
668        }
669
670        Ok(result)
671    }
672
673    /// Chunk-wise retention for long sequences
674    fn chunk_retention(
675        &self,
676        q: &Tensor,
677        k: &Tensor,
678        v: &Tensor,
679        chunk_size: usize,
680    ) -> Result<Tensor> {
681        let batch_size = q.shape()[0];
682        let seq_len = q.shape()[1];
683        let hidden_size = q.shape()[2];
684
685        let mut outputs = Vec::new();
686
687        // Process sequence in chunks
688        for start in (0..seq_len).step_by(chunk_size) {
689            let end = std::cmp::min(start + chunk_size, seq_len);
690
691            let q_chunk = q.slice_ranges(&[(0, batch_size), (start, end), (0, hidden_size)])?;
692            let k_chunk = k.slice_ranges(&[(0, batch_size), (start, end), (0, hidden_size)])?;
693            let v_chunk = v.slice_ranges(&[(0, batch_size), (start, end), (0, hidden_size)])?;
694
695            let chunk_output = self.parallel_retention(&q_chunk, &k_chunk, &v_chunk)?;
696            outputs.push(chunk_output);
697        }
698
699        // Concatenate chunks
700        self.concatenate_chunks(outputs)
701    }
702
703    fn reshape_for_heads(&self, x: &Tensor) -> Result<Tensor> {
704        let batch_size = x.shape()[0];
705        let seq_len = x.shape()[1];
706        let hidden_size = x.shape()[2];
707
708        x.reshape(&[
709            batch_size,
710            seq_len,
711            self.num_heads,
712            hidden_size / self.num_heads,
713        ])?
714        .permute(&[0, 2, 1, 3])
715    }
716
717    fn reshape_from_heads(&self, x: &Tensor) -> Result<Tensor> {
718        let batch_size = x.shape()[0];
719        let num_heads = x.shape()[1];
720        let seq_len = x.shape()[2];
721        let head_dim = x.shape()[3];
722
723        x.permute(&[0, 2, 1, 3])?.reshape(&[batch_size, seq_len, num_heads * head_dim])
724    }
725
    /// Concatenate `[batch, len_i, hidden]` chunks along the sequence axis.
    ///
    /// NOTE(review): near-duplicate of
    /// `AdvancedChunkProcessor::concatenate_chunks` (this one uses
    /// `self.device` instead of inferring from the first chunk) — a shared
    /// helper would avoid the drift. Scalar-wise copy: O(batch * total_seq
    /// * hidden) accessor calls.
    fn concatenate_chunks(&self, chunks: Vec<Tensor>) -> Result<Tensor> {
        // Concatenate along sequence dimension
        if chunks.is_empty() {
            return Err(tensor_op_error(
                "tensor_operation",
                "No chunks to concatenate".to_string(),
            ));
        }

        let batch_size = chunks[0].shape()[0];
        let hidden_size = chunks[0].shape()[2];
        let total_seq_len: usize = chunks.iter().map(|c| c.shape()[1]).sum();

        let mut result = Tensor::zeros(&[batch_size, total_seq_len, hidden_size])?
            .to_device_enum(&self.device)?;
        let mut offset = 0;

        for chunk in chunks {
            let chunk_seq_len = chunk.shape()[1];

            for b in 0..batch_size {
                for s in 0..chunk_seq_len {
                    for h in 0..hidden_size {
                        let val = chunk.get_scalar(&[b, s, h])?;
                        result = result.set_scalar(&[b, offset + s, h], val)?;
                    }
                }
            }

            offset += chunk_seq_len;
        }

        Ok(result)
    }
760
761    pub fn parameter_count(&self) -> usize {
762        self.q_proj.parameter_count()
763            + self.k_proj.parameter_count()
764            + self.v_proj.parameter_count()
765            + self.g_proj.parameter_count()
766            + self.out_proj.parameter_count()
767    }
768}
769
impl Layer for MultiScaleRetention {
    type Input = Tensor;
    type Output = Tensor;

    /// Full retention forward pass:
    /// project to Q/K/V and a gate, run retention (chunked for long
    /// sequences), apply SiLU gating, then the output projection.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let seq_len = input.shape()[1];

        // Project to Q, K, V, G
        let q = self.q_proj.forward(input.clone())?;
        let k = self.k_proj.forward(input.clone())?;
        let v = self.v_proj.forward(input.clone())?;
        let g = self.g_proj.forward(input)?;

        // Apply gate activation (usually Swish/SiLU)
        let g_activated = g.silu()?;

        // Compute retention. The 2048-token threshold and 512-token chunk
        // size are hard-coded heuristics, independent of the configured
        // chunk_processor.
        let retention_output = if seq_len > 2048 {
            // Use chunked processing for long sequences
            self.chunk_retention(&q, &k, &v, 512)?
        } else {
            self.parallel_retention(&q, &k, &v)?
        };

        // Apply gating
        let gated_output = retention_output.mul(&g_activated)?;

        // Final projection
        self.out_proj.forward(gated_output)
    }
}
801
802/// Feed-forward network with GLU activation
/// Feed-forward network with GLU activation
///
/// Supports either a GLU form (gate * activation(up)) or a plain
/// two-layer MLP, selected by `use_glu`.
pub struct RetNetFFN {
    // hidden -> intermediate; only used on the forward path when use_glu.
    gate_proj: Linear,
    // hidden -> intermediate.
    up_proj: Linear,
    // intermediate -> hidden.
    down_proj: Linear,
    // Activation name: "swish"/"silu", "gelu", "relu"; anything else is identity.
    activation: String,
    use_glu: bool,
    #[allow(dead_code)]
    dropout: f32,
    device: Device,
}
813
814impl RetNetFFN {
815    pub fn new(config: &RetNetConfig) -> Result<Self> {
816        Self::new_with_device(config, Device::CPU)
817    }
818
819    pub fn new_with_device(config: &RetNetConfig, device: Device) -> Result<Self> {
820        let gate_proj = if config.use_glu {
821            Some(Linear::new_with_device(
822                config.hidden_size,
823                config.intermediate_size,
824                config.use_bias,
825                device,
826            ))
827        } else {
828            None
829        };
830
831        let up_proj = Linear::new_with_device(
832            config.hidden_size,
833            config.intermediate_size,
834            config.use_bias,
835            device,
836        );
837        let down_proj = Linear::new_with_device(
838            config.intermediate_size,
839            config.hidden_size,
840            config.use_bias,
841            device,
842        );
843
844        Ok(Self {
845            gate_proj: gate_proj.unwrap_or_else(|| {
846                Linear::new_with_device(
847                    config.hidden_size,
848                    config.intermediate_size,
849                    config.use_bias,
850                    device,
851                ) // Safe since we just created similar ones above
852            }),
853            up_proj,
854            down_proj,
855            activation: config.hidden_act.clone(),
856            use_glu: config.use_glu,
857            dropout: config.activation_dropout,
858            device,
859        })
860    }
861
862    pub fn device(&self) -> Device {
863        self.device
864    }
865
866    fn apply_activation(&self, x: &Tensor) -> Result<Tensor> {
867        match self.activation.as_str() {
868            "swish" | "silu" => x.silu(),
869            "gelu" => x.gelu(),
870            "relu" => x.relu(),
871            _ => Ok(x.clone()),
872        }
873    }
874
875    pub fn parameter_count(&self) -> usize {
876        self.gate_proj.parameter_count()
877            + self.up_proj.parameter_count()
878            + self.down_proj.parameter_count()
879    }
880}
881
882impl Layer for RetNetFFN {
883    type Input = Tensor;
884    type Output = Tensor;
885
886    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
887        if self.use_glu {
888            // GLU: gate_proj(x) * activation(up_proj(x))
889            let gate = self.gate_proj.forward(input.clone())?;
890            let up = self.up_proj.forward(input)?;
891            let activated_up = self.apply_activation(&up)?;
892            let gated = gate.mul(&activated_up)?;
893            self.down_proj.forward(gated)
894        } else {
895            // Standard FFN: down_proj(activation(up_proj(x)))
896            let up = self.up_proj.forward(input)?;
897            let activated = self.apply_activation(&up)?;
898            self.down_proj.forward(activated)
899        }
900    }
901}
902
/// RetNet decoder layer
///
/// Pre-norm block: multi-scale retention followed by an FFN, each wrapped in
/// a residual connection. When `deepnorm` is enabled the residual and
/// sub-layer terms are scaled by `alpha` and `beta` before summation.
pub struct RetNetDecoderLayer {
    /// Multi-scale retention sub-layer (the attention replacement).
    retention: MultiScaleRetention,
    /// Position-wise feed-forward sub-layer.
    ffn: RetNetFFN,
    /// Pre-norm applied before the retention sub-layer.
    retention_norm: LayerNorm,
    /// Pre-norm applied before the FFN sub-layer.
    ffn_norm: LayerNorm,
    // Stored from config but not currently applied in forward().
    #[allow(dead_code)]
    dropout: f32,
    /// Whether DeepNorm residual scaling is active.
    deepnorm: bool,
    /// DeepNorm residual-stream scale (1.0 when deepnorm is off).
    alpha: f32,
    /// DeepNorm sub-layer-output scale (1.0 when deepnorm is off).
    beta: f32,
    device: Device,
}
916
917impl RetNetDecoderLayer {
918    pub fn new(config: &RetNetConfig) -> Result<Self> {
919        Self::new_with_device(config, Device::CPU)
920    }
921
922    pub fn new_with_device(config: &RetNetConfig, device: Device) -> Result<Self> {
923        let retention = MultiScaleRetention::new_with_device(config, device)?;
924        let ffn = RetNetFFN::new_with_device(config, device)?;
925        let retention_norm =
926            LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
927        let ffn_norm =
928            LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
929
930        let (alpha, beta) = if config.deepnorm {
931            (config.deepnorm_alpha(), config.deepnorm_beta())
932        } else {
933            (1.0, 1.0)
934        };
935
936        Ok(Self {
937            retention,
938            ffn,
939            retention_norm,
940            ffn_norm,
941            dropout: config.hidden_dropout_prob,
942            deepnorm: config.deepnorm,
943            alpha,
944            beta,
945            device,
946        })
947    }
948
949    pub fn device(&self) -> Device {
950        self.device
951    }
952
953    pub fn parameter_count(&self) -> usize {
954        self.retention.parameter_count()
955            + self.ffn.parameter_count()
956            + self.retention_norm.parameter_count()
957            + self.ffn_norm.parameter_count()
958    }
959}
960
961impl Layer for RetNetDecoderLayer {
962    type Input = Tensor;
963    type Output = Tensor;
964
965    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
966        // Pre-norm + residual connection for retention
967        let norm1 = self.retention_norm.forward(input.clone())?;
968        let retention_out = self.retention.forward(norm1)?;
969
970        let residual1 = if self.deepnorm {
971            // DeepNorm scaling
972            let scaled_input = input.mul_scalar(self.alpha)?;
973            let scaled_retention = retention_out.mul_scalar(self.beta)?;
974            scaled_input.add(&scaled_retention)?
975        } else {
976            input.add(&retention_out)?
977        };
978
979        // Pre-norm + residual connection for FFN
980        let norm2 = self.ffn_norm.forward(residual1.clone())?;
981        let ffn_out = self.ffn.forward(norm2)?;
982
983        let residual2 = if self.deepnorm {
984            let scaled_residual1 = residual1.mul_scalar(self.alpha)?;
985            let scaled_ffn = ffn_out.mul_scalar(self.beta)?;
986            scaled_residual1.add(&scaled_ffn)?
987        } else {
988            residual1.add(&ffn_out)?
989        };
990
991        Ok(residual2)
992    }
993}
994
/// RetNet embeddings
///
/// Token embedding table with an optional LayerNorm applied after lookup.
pub struct RetNetEmbeddings {
    /// Vocabulary embedding table.
    word_embeddings: Embedding,
    /// Present when `config.layernorm_embedding` is set.
    layer_norm: Option<LayerNorm>,
    // Stored from config but not currently applied in forward().
    #[allow(dead_code)]
    dropout: f32,
    device: Device,
}
1003
1004impl RetNetEmbeddings {
1005    pub fn new(config: &RetNetConfig) -> Result<Self> {
1006        Self::new_with_device(config, Device::CPU)
1007    }
1008
1009    pub fn new_with_device(config: &RetNetConfig, device: Device) -> Result<Self> {
1010        let word_embeddings = Embedding::new_with_device(
1011            config.vocab_size,
1012            config.hidden_size,
1013            Some(config.pad_token_id as usize),
1014            device,
1015        )?;
1016
1017        let layer_norm = if config.layernorm_embedding {
1018            Some(LayerNorm::new_with_device(
1019                vec![config.hidden_size],
1020                config.layer_norm_eps,
1021                device,
1022            )?)
1023        } else {
1024            None
1025        };
1026
1027        Ok(Self {
1028            word_embeddings,
1029            layer_norm,
1030            dropout: config.hidden_dropout_prob,
1031            device,
1032        })
1033    }
1034
1035    pub fn device(&self) -> Device {
1036        self.device
1037    }
1038
1039    pub fn parameter_count(&self) -> usize {
1040        let mut count = self.word_embeddings.parameter_count();
1041        if let Some(ln) = &self.layer_norm {
1042            count += ln.parameter_count();
1043        }
1044        count
1045    }
1046}
1047
1048impl Layer for RetNetEmbeddings {
1049    type Input = Vec<u32>;
1050    type Output = Tensor;
1051
1052    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
1053        let mut embeddings = self.word_embeddings.forward(input)?;
1054
1055        // Apply layer norm if enabled
1056        if let Some(ref ln) = self.layer_norm {
1057            embeddings = ln.forward(embeddings)?;
1058        }
1059
1060        // Apply dropout (in training mode)
1061        Ok(embeddings)
1062    }
1063}
1064
/// Main RetNet model
///
/// Pipeline: embeddings -> N decoder layers -> final LayerNorm.
pub struct RetNetModel {
    /// Validated model hyper-parameters.
    config: RetNetConfig,
    /// Token embedding module (with optional embedding LayerNorm).
    embeddings: RetNetEmbeddings,
    /// `config.num_hidden_layers` decoder layers, applied in order.
    layers: Vec<RetNetDecoderLayer>,
    /// LayerNorm applied after the last decoder layer.
    final_norm: LayerNorm,
    device: Device,
}
1073
1074impl RetNetModel {
1075    pub fn new(config: RetNetConfig) -> Result<Self> {
1076        Self::new_with_device(config, Device::CPU)
1077    }
1078
1079    pub fn new_with_device(config: RetNetConfig, device: Device) -> Result<Self> {
1080        config.validate()?;
1081
1082        let embeddings = RetNetEmbeddings::new_with_device(&config, device)?;
1083
1084        let mut layers = Vec::new();
1085        for _ in 0..config.num_hidden_layers {
1086            layers.push(RetNetDecoderLayer::new_with_device(&config, device)?);
1087        }
1088
1089        let final_norm =
1090            LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
1091
1092        Ok(Self {
1093            config,
1094            embeddings,
1095            layers,
1096            final_norm,
1097            device,
1098        })
1099    }
1100
1101    pub fn device(&self) -> Device {
1102        self.device
1103    }
1104}
1105
1106impl Model for RetNetModel {
1107    type Config = RetNetConfig;
1108    type Input = Vec<u32>;
1109    type Output = Tensor;
1110
1111    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
1112        let mut hidden_states = self.embeddings.forward(input)?;
1113
1114        for layer in &self.layers {
1115            hidden_states = layer.forward(hidden_states)?;
1116        }
1117
1118        self.final_norm.forward(hidden_states)
1119    }
1120
1121    fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
1122        Ok(())
1123    }
1124
1125    fn get_config(&self) -> &Self::Config {
1126        &self.config
1127    }
1128
1129    fn num_parameters(&self) -> usize {
1130        let mut total = 0;
1131
1132        // Embedding parameters
1133        total += self.embeddings.parameter_count();
1134
1135        // Layer parameters
1136        for layer in &self.layers {
1137            total += layer.parameter_count();
1138        }
1139
1140        // Final norm parameters
1141        total += self.final_norm.parameter_count();
1142
1143        total
1144    }
1145}
1146
/// Advanced RetNet generation capabilities
///
/// NOTE(review): this trait only declares the generation API; no implementor
/// is visible in this file — confirm where implementations live.
pub trait RetNetGeneration {
    /// Generate text using recurrent mode for efficient autoregressive generation
    ///
    /// * `input_ids` - prompt token ids to continue from
    /// * `max_length` - cap on the total generated length
    /// * `temperature` - sampling temperature (presumably > 0 — confirm with implementors)
    /// * `top_p` - nucleus-sampling probability-mass cutoff
    /// * `top_k` - optional top-k truncation of the sampling distribution
    fn generate_recurrent(
        &self,
        input_ids: Vec<u32>,
        max_length: usize,
        temperature: f32,
        top_p: f32,
        top_k: Option<u32>,
    ) -> Result<Vec<u32>>;

    /// Generate with beam search for better quality
    ///
    /// Returns candidate sequences (presumably up to `num_beams` of them, best
    /// first — confirm with implementors). `early_stopping` presumably halts
    /// finished beams before `max_length`.
    fn generate_beam_search(
        &self,
        input_ids: Vec<u32>,
        max_length: usize,
        num_beams: usize,
        early_stopping: bool,
    ) -> Result<Vec<Vec<u32>>>;

    /// Stream generation for real-time applications
    ///
    /// `callback` is invoked with the tokens generated so far; returning
    /// `false` stops generation.
    fn generate_stream<F>(&self, input_ids: Vec<u32>, max_length: usize, callback: F) -> Result<()>
    where
        F: Fn(&[u32]) -> bool; // Returns false to stop generation
}
1173
/// Optimized RetNet for long sequence processing
///
/// Splits inputs longer than `chunk_size` into chunks that overlap by
/// `overlap_size` tokens, then stitches the per-chunk outputs back together.
pub struct RetNetLongSequence {
    /// Underlying model run on each chunk.
    model: RetNetModel,
    /// Maximum tokens handled per forward pass.
    chunk_size: usize,
    /// Tokens shared between consecutive chunks (chunk_size / 4).
    overlap_size: usize,
    // NOTE(review): sized at construction but never consulted in
    // process_long_sequence — confirm whether state reuse is still planned.
    state_cache: RetNetStateCache,
    device: Device,
}
1182
1183impl RetNetLongSequence {
1184    pub fn new(config: RetNetConfig, chunk_size: usize) -> Result<Self> {
1185        Self::new_with_device(config, chunk_size, Device::CPU)
1186    }
1187
1188    pub fn new_with_device(
1189        config: RetNetConfig,
1190        chunk_size: usize,
1191        device: Device,
1192    ) -> Result<Self> {
1193        let model = RetNetModel::new_with_device(config.clone(), device)?;
1194        let overlap_size = chunk_size / 4; // 25% overlap
1195        let state_cache = RetNetStateCache::new(config.num_hidden_layers * 4);
1196
1197        Ok(Self {
1198            model,
1199            chunk_size,
1200            overlap_size,
1201            state_cache,
1202            device,
1203        })
1204    }
1205
1206    pub fn device(&self) -> Device {
1207        self.device
1208    }
1209
1210    /// Process very long sequences efficiently
1211    pub fn process_long_sequence(&mut self, input: Vec<u32>) -> Result<Tensor> {
1212        let seq_len = input.len();
1213
1214        if seq_len <= self.chunk_size {
1215            return self.model.forward(input);
1216        }
1217
1218        let mut all_outputs = Vec::new();
1219        let effective_step = self.chunk_size - self.overlap_size;
1220
1221        for start in (0..seq_len).step_by(effective_step) {
1222            let end = std::cmp::min(start + self.chunk_size, seq_len);
1223            let chunk = input[start..end].to_vec();
1224
1225            let chunk_output = self.model.forward(chunk)?;
1226
1227            // Remove overlap from previous chunks
1228            let output_start = if start == 0 { 0 } else { self.overlap_size };
1229            let chunk_seq_len = chunk_output.shape()[1];
1230
1231            if chunk_seq_len > output_start {
1232                let trimmed_output = chunk_output.slice_ranges(&[
1233                    (0, chunk_output.shape()[0]),
1234                    (output_start, chunk_seq_len),
1235                    (0, chunk_output.shape()[2]),
1236                ])?;
1237                all_outputs.push(trimmed_output);
1238            }
1239        }
1240
1241        self.concatenate_outputs(all_outputs)
1242    }
1243
1244    fn concatenate_outputs(&self, outputs: Vec<Tensor>) -> Result<Tensor> {
1245        if outputs.is_empty() {
1246            return Err(tensor_op_error(
1247                "tensor_operation",
1248                "No outputs to concatenate".to_string(),
1249            ));
1250        }
1251
1252        let batch_size = outputs[0].shape()[0];
1253        let hidden_size = outputs[0].shape()[2];
1254        let total_seq_len: usize = outputs.iter().map(|o| o.shape()[1]).sum();
1255
1256        let mut result = Tensor::zeros(&[batch_size, total_seq_len, hidden_size])?
1257            .to_device_enum(&self.device)?;
1258        let mut offset = 0;
1259
1260        for output in outputs {
1261            let seq_len = output.shape()[1];
1262
1263            for b in 0..batch_size {
1264                for s in 0..seq_len {
1265                    for h in 0..hidden_size {
1266                        let val = output.get_scalar(&[b, s, h])?;
1267                        result = result.set_scalar(&[b, offset + s, h], val)?;
1268                    }
1269                }
1270            }
1271
1272            offset += seq_len;
1273        }
1274
1275        Ok(result)
1276    }
1277
1278    /// Get memory usage statistics
1279    pub fn get_memory_stats(&self) -> RetNetMemoryStats {
1280        RetNetMemoryStats {
1281            cache_size: self.state_cache.size(),
1282            max_cache_size: self.state_cache.max_cache_size,
1283            chunk_size: self.chunk_size,
1284            overlap_size: self.overlap_size,
1285            estimated_memory_mb: self.estimate_memory_usage(),
1286        }
1287    }
1288
1289    fn estimate_memory_usage(&self) -> f64 {
1290        let config = self.model.get_config();
1291        let params = self.model.num_parameters() as f64;
1292        let state_memory =
1293            (self.state_cache.size() * config.hidden_size * config.hidden_size * 4) as f64; // 4 bytes per float
1294        let chunk_memory = (self.chunk_size * config.hidden_size * 4) as f64;
1295
1296        (params * 4.0 + state_memory + chunk_memory) / (1024.0 * 1024.0) // Convert to MB
1297    }
1298}
1299
/// Memory usage statistics for RetNet
#[derive(Debug, Clone)]
pub struct RetNetMemoryStats {
    /// Entries currently held in the state cache.
    pub cache_size: usize,
    /// Maximum entries the state cache may hold.
    pub max_cache_size: usize,
    /// Tokens per processing chunk.
    pub chunk_size: usize,
    /// Tokens shared between consecutive chunks.
    pub overlap_size: usize,
    /// Rough estimate of resident memory in megabytes.
    pub estimated_memory_mb: f64,
}
1309
/// RetNet for language modeling
///
/// Backbone plus an optional vocabulary-projection head.
pub struct RetNetForLanguageModeling {
    /// Shared RetNet backbone.
    retnet: RetNetModel,
    /// Absent when `config.no_output_layer` is set; `forward` then returns
    /// raw hidden states instead of logits.
    lm_head: Option<Linear>,
    device: Device,
}
1316
1317impl RetNetForLanguageModeling {
1318    pub fn new(config: RetNetConfig) -> Result<Self> {
1319        Self::new_with_device(config, Device::CPU)
1320    }
1321
1322    pub fn new_with_device(config: RetNetConfig, device: Device) -> Result<Self> {
1323        let retnet = RetNetModel::new_with_device(config.clone(), device)?;
1324
1325        let lm_head = if !config.no_output_layer {
1326            Some(Linear::new_with_device(
1327                config.hidden_size,
1328                config.vocab_size,
1329                false,
1330                device,
1331            ))
1332        } else {
1333            None
1334        };
1335
1336        Ok(Self {
1337            retnet,
1338            lm_head,
1339            device,
1340        })
1341    }
1342
1343    pub fn device(&self) -> Device {
1344        self.device
1345    }
1346}
1347
1348impl Model for RetNetForLanguageModeling {
1349    type Config = RetNetConfig;
1350    type Input = Vec<u32>;
1351    type Output = Tensor;
1352
1353    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
1354        let hidden_states = self.retnet.forward(input)?;
1355
1356        if let Some(ref lm_head) = self.lm_head {
1357            lm_head.forward(hidden_states)
1358        } else {
1359            Ok(hidden_states)
1360        }
1361    }
1362
1363    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
1364        self.retnet.load_pretrained(reader)
1365    }
1366
1367    fn get_config(&self) -> &Self::Config {
1368        self.retnet.get_config()
1369    }
1370
1371    fn num_parameters(&self) -> usize {
1372        let mut total = self.retnet.num_parameters();
1373        if let Some(ref lm_head) = self.lm_head {
1374            total += lm_head.parameter_count();
1375        }
1376        total
1377    }
1378}
1379
/// RetNet for sequence classification
///
/// Backbone plus a linear head applied to the last token's hidden state.
pub struct RetNetForSequenceClassification {
    /// Shared RetNet backbone.
    retnet: RetNetModel,
    /// Maps the pooled hidden state to `num_labels` logits.
    classifier: Linear,
    // Kept for introspection; the classifier already encodes the label count.
    #[allow(dead_code)]
    num_labels: usize,
    device: Device,
}
1388
1389impl RetNetForSequenceClassification {
1390    pub fn new(config: RetNetConfig, num_labels: usize) -> Result<Self> {
1391        Self::new_with_device(config, num_labels, Device::CPU)
1392    }
1393
1394    pub fn new_with_device(
1395        config: RetNetConfig,
1396        num_labels: usize,
1397        device: Device,
1398    ) -> Result<Self> {
1399        let retnet = RetNetModel::new_with_device(config.clone(), device)?;
1400        let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
1401
1402        Ok(Self {
1403            retnet,
1404            classifier,
1405            num_labels,
1406            device,
1407        })
1408    }
1409
1410    pub fn device(&self) -> Device {
1411        self.device
1412    }
1413}
1414
1415impl Model for RetNetForSequenceClassification {
1416    type Config = RetNetConfig;
1417    type Input = Vec<u32>;
1418    type Output = Tensor;
1419
1420    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
1421        let sequence_output = self.retnet.forward(input)?;
1422
1423        // Use last token for classification (causal LM style)
1424        let last_token = self.get_last_token(&sequence_output)?;
1425        self.classifier.forward(last_token)
1426    }
1427
1428    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
1429        self.retnet.load_pretrained(reader)
1430    }
1431
1432    fn get_config(&self) -> &Self::Config {
1433        self.retnet.get_config()
1434    }
1435
1436    fn num_parameters(&self) -> usize {
1437        self.retnet.num_parameters() + self.classifier.parameter_count()
1438    }
1439}
1440
1441impl RetNetForSequenceClassification {
1442    fn get_last_token(&self, x: &Tensor) -> Result<Tensor> {
1443        let batch_size = x.shape()[0];
1444        let seq_len = x.shape()[1];
1445        let hidden_size = x.shape()[2];
1446
1447        // Extract last token embeddings
1448        let mut last_tokens =
1449            Tensor::zeros(&[batch_size, hidden_size])?.to_device_enum(&self.device)?;
1450
1451        for b in 0..batch_size {
1452            for h in 0..hidden_size {
1453                let val = x.get_scalar(&[b, seq_len - 1, h])?;
1454                last_tokens = last_tokens.set_scalar(&[b, h], val)?;
1455            }
1456        }
1457
1458        Ok(last_tokens)
1459    }
1460}