tiny_recursive_rs/models/mod.rs

//! Tiny Recursive Model implementation
use candle_core::{Result, Tensor, Device, DType};
use candle_nn::{VarBuilder, Module};
use crate::config::TRMConfig;
use crate::layers::{Attention, SwiGLU, CastedEmbedding, RotaryEmbedding};
use crate::layers::normalization::rms_norm;
use crate::layers::activations::CastedLinear;

pub mod loader;

/// State carry for recursive computation
///
/// Holds the high-level and low-level states that are refined
/// across recursive cycles.
#[derive(Debug, Clone)]
pub struct InnerCarry {
    /// High-level state: [batch, seq_len, hidden_size]
    pub z_h: Tensor,
    /// Low-level state: [batch, seq_len, hidden_size]
    pub z_l: Tensor,
}

impl InnerCarry {
    /// Create new carry with given states
    pub fn new(z_h: Tensor, z_l: Tensor) -> Self {
        Self { z_h, z_l }
    }

    /// Create a zero-initialized carry
    pub fn empty(batch_size: usize, seq_len: usize, hidden_size: usize, dtype: DType, device: &Device) -> Result<Self> {
        let z_h = Tensor::zeros((batch_size, seq_len, hidden_size), dtype, device)?;
        let z_l = Tensor::zeros((batch_size, seq_len, hidden_size), dtype, device)?;
        Ok(Self { z_h, z_l })
    }
}

/// Transformer block for TRM
///
/// Each block consists of:
/// - Self-attention (optional, not used in MLP-T mode)
/// - Feed-forward network (SwiGLU)
/// - RMS normalization with residual connections (post-norm)
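///
/// In post-norm form, each sublayer computes (schematically)
/// `h = rms_norm(h + attn(h))` followed by `h = rms_norm(h + mlp(h))`,
/// matching the `forward` implementation below.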
pub struct TransformerBlock {
    self_attn: Option<Attention>,
    mlp: SwiGLU,
    norm_eps: f64,
}

impl TransformerBlock {
    /// Create new transformer block
    pub fn new(config: TRMConfig, vb: VarBuilder) -> Result<Self> {
        // Self-attention (only if not using MLP-T mode)
        let self_attn = if !config.mlp_t {
            Some(Attention::new(
                config.hidden_size,
                config.head_dim(),
                config.num_heads,
                config.num_heads, // num_key_value_heads = num_heads (no GQA by default)
                false, // not causal
                vb.pp("self_attn"),
            )?)
        } else {
            None
        };

        // Feed-forward network
        let mlp = SwiGLU::new(
            config.hidden_size,
            config.expansion,
            vb.pp("mlp"),
        )?;

        Ok(Self {
            self_attn,
            mlp,
            norm_eps: 1e-5,
        })
    }

    /// Forward pass
    ///
    /// # Arguments
    /// * `hidden_states` - Input tensor [batch, seq_len, hidden_size]
    /// * `cos_sin` - Optional RoPE embeddings
    ///
    /// # Returns
    /// Output tensor [batch, seq_len, hidden_size]
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        cos_sin: Option<(&Tensor, &Tensor)>,
    ) -> Result<Tensor> {
        let mut hidden_states = hidden_states.clone();

        // Self-attention sublayer (skipped in MLP-T mode)
        if let Some(ref attn) = self.self_attn {
            let attn_out = attn.forward(&hidden_states, cos_sin)?;
            hidden_states = rms_norm(&(hidden_states + attn_out)?, self.norm_eps)?;
        }

        // Feed-forward sublayer
        let mlp_out = self.mlp.forward(&hidden_states)?;
        hidden_states = rms_norm(&(hidden_states + mlp_out)?, self.norm_eps)?;

        Ok(hidden_states)
    }
}

/// Reasoning module: stack of transformer blocks with input injection
///
/// This is the L-level or H-level reasoning component.
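///
/// Each call adds an input injection to the current state and then passes the
/// result through every block, i.e. schematically `h = blocks(h + injection)`.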
pub struct ReasoningModule {
    layers: Vec<TransformerBlock>,
}

impl ReasoningModule {
    /// Create new reasoning module
    ///
    /// # Arguments
    /// * `num_layers` - Number of transformer blocks
    /// * `config` - Model configuration
    /// * `vb` - VarBuilder for parameter initialization
    pub fn new(num_layers: usize, config: TRMConfig, vb: VarBuilder) -> Result<Self> {
        let mut layers = Vec::new();
        for i in 0..num_layers {
            layers.push(TransformerBlock::new(
                config.clone(),
                vb.pp(&format!("layer_{}", i)),
            )?);
        }

        Ok(Self { layers })
    }

    /// Forward pass with input injection
    ///
    /// # Arguments
    /// * `hidden_states` - Current state
    /// * `input_injection` - Input to inject (added to hidden_states)
    /// * `cos_sin` - Optional RoPE embeddings
    ///
    /// # Returns
    /// Updated state after processing through all layers
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        input_injection: &Tensor,
        cos_sin: Option<(&Tensor, &Tensor)>,
    ) -> Result<Tensor> {
        // Add input injection
        let mut hidden_states = (hidden_states + input_injection)?;

        // Process through all layers
        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states, cos_sin)?;
        }

        Ok(hidden_states)
    }
}

/// Main Tiny Recursive Model
///
/// Implements the recursive reasoning architecture with H-cycles and L-cycles.
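///
/// Note that a single `ReasoningModule` (`l_level`) is shared between the
/// L-state and H-state updates; the two levels are distinguished only by what
/// gets injected (see `forward`).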
pub struct TinyRecursiveModel {
    config: TRMConfig,

    // I/O components
    embed_tokens: CastedEmbedding,
    lm_head: CastedLinear,
    embed_scale: f64,

    // Positional encodings
    rotary_emb: Option<RotaryEmbedding>,

    // Reasoning components
    l_level: ReasoningModule,

    // Initial states
    h_init: Tensor,
    l_init: Tensor,

    // Device
    device: Device,
}

impl TinyRecursiveModel {
    /// Create new TinyRecursiveModel
    pub fn new(config: TRMConfig, vb: VarBuilder) -> crate::Result<Self> {
        config.validate()?;

        let device = vb.device().clone();
        let dtype = vb.dtype();

        // Embedding scale: sqrt(hidden_size)
        let embed_scale = (config.hidden_size as f64).sqrt();

        // Token embeddings
        let embed_tokens = CastedEmbedding::new(
            config.vocab_size,
            config.hidden_size,
            vb.pp("embed_tokens"),
            dtype,
        )?;

        // Output head
        let lm_head = CastedLinear::new(
            config.hidden_size,
            config.num_outputs,
            false,
            vb.pp("lm_head"),
        )?;

        // Positional encodings
        let rotary_emb = if config.pos_encodings == "rope" {
            Some(RotaryEmbedding::new(
                config.head_dim(),
                2048, // max sequence length
                10000.0,
                &device,
            )?)
        } else {
            None
        };

        // L-level reasoning module
        let l_level = ReasoningModule::new(
            config.l_layers,
            config.clone(),
            vb.pp("l_level"),
        )?;

        // Initial states (learnable parameters)
        let h_init = vb.get(config.hidden_size, "h_init")?;
        let l_init = vb.get(config.hidden_size, "l_init")?;

        Ok(Self {
            config,
            embed_tokens,
            lm_head,
            embed_scale,
            rotary_emb,
            l_level,
            h_init,
            l_init,
            device,
        })
    }

    /// Create a zero-initialized carry for a batch with the given sequence length
    pub fn empty_carry(&self, batch_size: usize, seq_len: usize) -> Result<InnerCarry> {
        InnerCarry::empty(
            batch_size,
            seq_len,
            self.config.hidden_size,
            DType::F32,
            &self.device,
        )
    }

    /// Reset carry to initial states where reset_flag is true
    ///
    /// # Arguments
    /// * `reset_flag` - Boolean tensor [batch_size] indicating which sequences to reset
    /// * `carry` - Current carry state
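    ///
    /// A minimal sketch of how this might be called (assumes a `u8` mask where a
    /// nonzero entry means "reset this sequence"):
    ///
    /// ```ignore
    /// let reset = Tensor::new(&[1u8, 0u8], &device)?; // reset batch 0, keep batch 1
    /// let carry = model.reset_carry(&reset, &carry)?;
    /// ```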
    pub fn reset_carry(&self, reset_flag: &Tensor, carry: &InnerCarry) -> Result<InnerCarry> {
        let batch_size = carry.z_h.dim(0)?;
        let seq_len = carry.z_h.dim(1)?;
        let hidden_size = self.config.hidden_size;

        // Reshape reset_flag to [batch, 1, 1] and broadcast it to the full state
        // shape so that where_cond sees matching operands
        let reset_flag = reset_flag
            .unsqueeze(1)?
            .unsqueeze(1)?
            .broadcast_as((batch_size, seq_len, hidden_size))?;

        // Broadcast h_init and l_init to batch dimensions
        let h_init = self.h_init
            .unsqueeze(0)?
            .unsqueeze(0)?
            .broadcast_as((batch_size, seq_len, hidden_size))?;

        let l_init = self.l_init
            .unsqueeze(0)?
            .unsqueeze(0)?
            .broadcast_as((batch_size, seq_len, hidden_size))?;

        // Where reset_flag is set, use the init states; otherwise keep the carry states
        let z_h = reset_flag.where_cond(&h_init, &carry.z_h)?;
        let z_l = reset_flag.where_cond(&l_init, &carry.z_l)?;

        Ok(InnerCarry::new(z_h, z_l))
    }

    /// Encode input tokens to embeddings
    fn input_embeddings(&self, input: &Tensor) -> Result<Tensor> {
        // Token embedding
        let embedding = self.embed_tokens.forward(input)?;

        // Scale by sqrt(hidden_size)
        embedding.affine(self.embed_scale, 0.0)
    }

    /// Forward pass with recursive reasoning
    ///
    /// # Arguments
    /// * `carry` - Current state (z_H, z_L)
    /// * `input` - Input token IDs [batch, seq_len]
    ///
    /// # Returns
    /// Tuple of (new_carry, logits)
    /// - new_carry: Updated state for next iteration
    /// - logits: Output logits [batch, seq_len, num_outputs]
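    ///
    /// A minimal usage sketch (hypothetical variable names; assumes weights have
    /// already been loaded into the `VarBuilder` and `input` holds `u32` token ids):
    ///
    /// ```ignore
    /// let model = TinyRecursiveModel::new(config, vb)?;
    /// let mut carry = model.empty_carry(batch_size, seq_len)?;
    /// for _step in 0..max_steps {
    ///     let (next_carry, logits) = model.forward(&carry, &input)?;
    ///     carry = next_carry;
    ///     // ... inspect `logits` or decide whether to halt ...
    /// }
    /// ```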
    pub fn forward(&self, carry: &InnerCarry, input: &Tensor) -> Result<(InnerCarry, Tensor)> {
        let seq_len = input.dim(1)?;

        // Get RoPE embeddings if needed
        let cos_sin = if let Some(ref rope) = self.rotary_emb {
            let (cos, sin) = rope.forward_with_len(seq_len)?;
            Some((cos, sin))
        } else {
            None
        };

        // Input encoding
        let input_embeddings = self.input_embeddings(input)?;

        // Extract current states
        let mut z_h = carry.z_h.clone();
        let mut z_l = carry.z_l.clone();

        // Recursive forward iterations
        // Pattern from the reference Python implementation:
        // - (H_cycles - 1) iterations without gradients
        // - 1 final iteration with gradients
        //
        // Each H-cycle:
        //   - L_cycles iterations: z_L = L_level(z_L, z_H + input)
        //   - 1 iteration: z_H = L_level(z_H, z_L)

        // For inference in Rust we don't need the gradient control,
        // so we simply run all H_cycles the same way
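        //
        // In total this makes h_cycles * (l_cycles + 1) passes through the
        // reasoning module, e.g. 3 * (6 + 1) = 21 passes for h_cycles = 3 and
        // l_cycles = 6 (the cycle counts here are only illustrative).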
        for _h_step in 0..self.config.h_cycles {
            // L-cycles: refine z_L with z_H + input as injection
            for _l_step in 0..self.config.l_cycles {
                let injection = (&z_h + &input_embeddings)?;
                z_l = self.l_level.forward(
                    &z_l,
                    &injection,
                    cos_sin.as_ref().map(|(c, s)| (c.as_ref(), s.as_ref())),
                )?;
            }

            // Update z_H with z_L as injection
            z_h = self.l_level.forward(
                &z_h,
                &z_l,
                cos_sin.as_ref().map(|(c, s)| (c.as_ref(), s.as_ref())),
            )?;
        }

        // Output logits
        let logits = self.lm_head.forward(&z_h)?;

        // New carry for the next outer step
        let new_carry = InnerCarry::new(z_h, z_l);

        Ok((new_carry, logits))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_nn::VarMap;

    #[test]
    fn test_inner_carry_creation() -> Result<()> {
        let device = Device::Cpu;

        let carry = InnerCarry::empty(2, 16, 256, DType::F32, &device)?;

        assert_eq!(carry.z_h.dims(), &[2, 16, 256]);
        assert_eq!(carry.z_l.dims(), &[2, 16, 256]);

        Ok(())
    }

    #[test]
    fn test_transformer_block() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let mut config = TRMConfig::default();
        config.hidden_size = 256;
        config.num_heads = 8;

        let block = TransformerBlock::new(config, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let out = block.forward(&x, None)?;

        assert_eq!(out.dims(), &[2, 16, 256]);

        Ok(())
    }

    #[test]
    fn test_reasoning_module() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let mut config = TRMConfig::default();
        config.hidden_size = 256;
        config.num_heads = 8;
        config.l_layers = 2;

        let module = ReasoningModule::new(2, config, vb)?;

        let hidden = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let injection = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;

        let out = module.forward(&hidden, &injection, None)?;

        assert_eq!(out.dims(), &[2, 16, 256]);

        Ok(())
    }
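
    // Illustrative end-to-end sketch rather than a strict regression test: it
    // assumes the remaining TRMConfig defaults pass validate(), so it is ignored
    // by default (run with `cargo test -- --ignored` to try it).
    #[test]
    #[ignore]
    fn test_model_forward_sketch() {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let mut config = TRMConfig::default();
        config.hidden_size = 256;
        config.num_heads = 8;
        config.l_layers = 2;
        config.h_cycles = 2;
        config.l_cycles = 2;
        config.vocab_size = 32;
        config.num_outputs = 32;

        let model = TinyRecursiveModel::new(config, vb).expect("model construction");
        let carry = model.empty_carry(2, 16).expect("carry");
        let input = Tensor::zeros((2, 16), DType::U32, &device).expect("input ids");

        let (_new_carry, logits) = model.forward(&carry, &input).expect("forward");
        assert_eq!(logits.dims(), &[2, 16, 32]);
    }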
}