use candle_core::{Result, Tensor, Device, DType};
use candle_nn::{VarBuilder, Module};
use crate::config::TRMConfig;
use crate::layers::{Attention, SwiGLU, CastedEmbedding, RMSNorm, RotaryEmbedding};
use crate::layers::normalization::rms_norm;
use crate::layers::activations::CastedLinear;

pub mod loader;

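/// Recurrent state carried between forward calls: the high-level state `z_h`
/// and the low-level state `z_l`, each of shape (batch, seq_len, hidden_size).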
#[derive(Debug, Clone)]
pub struct InnerCarry {
    pub z_h: Tensor,
    pub z_l: Tensor,
}

impl InnerCarry {
    pub fn new(z_h: Tensor, z_l: Tensor) -> Self {
        Self { z_h, z_l }
    }

    /// Build a zero-initialized carry with the given dimensions.
    pub fn empty(batch_size: usize, seq_len: usize, hidden_size: usize, dtype: DType, device: &Device) -> Result<Self> {
        let z_h = Tensor::zeros((batch_size, seq_len, hidden_size), dtype, device)?;
        let z_l = Tensor::zeros((batch_size, seq_len, hidden_size), dtype, device)?;
        Ok(Self { z_h, z_l })
    }
}

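/// A post-norm transformer block: optional self-attention followed by a SwiGLU
/// MLP, each wrapped in a residual connection and RMS normalization. When
/// `config.mlp_t` is set, the attention sublayer is omitted (MLP-only block).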
pub struct TransformerBlock {
    config: TRMConfig,
    self_attn: Option<Attention>,
    mlp: SwiGLU,
    norm_eps: f64,
}

impl TransformerBlock {
    pub fn new(config: TRMConfig, vb: VarBuilder) -> Result<Self> {
        let self_attn = if !config.mlp_t {
            Some(Attention::new(
                config.hidden_size,
                config.head_dim(),
                config.num_heads,
                config.num_heads,
                false,
                vb.pp("self_attn"),
            )?)
        } else {
            None
        };

        let mlp = SwiGLU::new(
            config.hidden_size,
            config.expansion,
            vb.pp("mlp"),
        )?;

        Ok(Self {
            config,
            self_attn,
            mlp,
            norm_eps: 1e-5,
        })
    }

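    /// Apply the block to `hidden_states`. `cos_sin` carries the rotary
    /// embedding tables when RoPE is enabled; pass `None` otherwise.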
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        cos_sin: Option<(&Tensor, &Tensor)>,
    ) -> Result<Tensor> {
        let mut hidden_states = hidden_states.clone();

        if let Some(ref attn) = self.self_attn {
            let attn_out = attn.forward(&hidden_states, cos_sin)?;
            hidden_states = rms_norm(&(hidden_states + attn_out)?, self.norm_eps)?;
        }

        let mlp_out = self.mlp.forward(&hidden_states)?;
        hidden_states = rms_norm(&(hidden_states + mlp_out)?, self.norm_eps)?;

        Ok(hidden_states)
    }
}

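/// A stack of `TransformerBlock`s that refines a hidden state given an
/// additive input injection.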
pub struct ReasoningModule {
    layers: Vec<TransformerBlock>,
}

impl ReasoningModule {
    pub fn new(num_layers: usize, config: TRMConfig, vb: VarBuilder) -> Result<Self> {
        let mut layers = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            layers.push(TransformerBlock::new(
                config.clone(),
                vb.pp(&format!("layer_{}", i)),
            )?);
        }

        Ok(Self { layers })
    }

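    /// Add `input_injection` to `hidden_states` and run the result through all
    /// layers.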
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        input_injection: &Tensor,
        cos_sin: Option<(&Tensor, &Tensor)>,
    ) -> Result<Tensor> {
        let mut hidden_states = (hidden_states + input_injection)?;

        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states, cos_sin)?;
        }

        Ok(hidden_states)
    }
}

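/// The Tiny Recursive Model: token embeddings, a single shared reasoning
/// module applied recursively over the high-level (`z_h`) and low-level
/// (`z_l`) states, and a linear output head.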
pub struct TinyRecursiveModel {
    config: TRMConfig,

    embed_tokens: CastedEmbedding,
    lm_head: CastedLinear,
    /// Embeddings are scaled by sqrt(hidden_size) before entering the network.
    embed_scale: f64,

    rotary_emb: Option<RotaryEmbedding>,

    l_level: ReasoningModule,

    /// Learned initial values for the high- and low-level states.
    h_init: Tensor,
    l_init: Tensor,

    device: Device,
}

impl TinyRecursiveModel {
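    /// Build the model from a `VarBuilder` rooted at the model's weight prefix.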
    pub fn new(config: TRMConfig, vb: VarBuilder) -> crate::Result<Self> {
        config.validate()?;

        let device = vb.device().clone();
        let dtype = vb.dtype();

        // Scale factor applied to token embeddings (sqrt of the hidden size).
        let embed_scale = (config.hidden_size as f64).sqrt();

        let embed_tokens = CastedEmbedding::new(
            config.vocab_size,
            config.hidden_size,
            vb.pp("embed_tokens"),
            dtype,
        )?;

        let lm_head = CastedLinear::new(
            config.hidden_size,
            config.num_outputs,
            false,
            vb.pp("lm_head"),
        )?;

        let rotary_emb = if config.pos_encodings == "rope" {
            Some(RotaryEmbedding::new(
                config.head_dim(),
                2048,
                10000.0,
                &device,
            )?)
        } else {
            None
        };

        let l_level = ReasoningModule::new(
            config.l_layers,
            config.clone(),
            vb.pp("l_level"),
        )?;

        // Learned initial states used when the carry is reset.
        let h_init = vb.get(config.hidden_size, "h_init")?;
        let l_init = vb.get(config.hidden_size, "l_init")?;

        Ok(Self {
            config,
            embed_tokens,
            lm_head,
            embed_scale,
            rotary_emb,
            l_level,
            h_init,
            l_init,
            device,
        })
    }

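    /// Allocate a fresh zero carry for a batch of the given size.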
    pub fn empty_carry(&self, batch_size: usize) -> Result<InnerCarry> {
        InnerCarry::empty(
            batch_size,
            self.config.seq_len,
            self.config.hidden_size,
            DType::F32,
            &self.device,
        )
    }

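    /// Where `reset_flag` is set for a sample, replace its carry with the
    /// learned initial states `h_init` / `l_init`; otherwise keep the previous
    /// carry unchanged.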
    pub fn reset_carry(&self, reset_flag: &Tensor, carry: &InnerCarry) -> Result<InnerCarry> {
        // Expand the per-sample flag to (batch, 1, 1) so it can gate full states.
        let reset_flag = reset_flag.unsqueeze(1)?.unsqueeze(1)?;

        let batch_size = carry.z_h.dim(0)?;
        let seq_len = carry.z_h.dim(1)?;

        let h_init = self.h_init
            .unsqueeze(0)?
            .unsqueeze(0)?
            .broadcast_as((batch_size, seq_len, self.config.hidden_size))?;

        let l_init = self.l_init
            .unsqueeze(0)?
            .unsqueeze(0)?
            .broadcast_as((batch_size, seq_len, self.config.hidden_size))?;

        let z_h = reset_flag.where_cond(&h_init, &carry.z_h)?;
        let z_l = reset_flag.where_cond(&l_init, &carry.z_l)?;

        Ok(InnerCarry::new(z_h, z_l))
    }

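    /// Embed input token ids and scale by `embed_scale` (sqrt of hidden size).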
    fn input_embeddings(&self, input: &Tensor) -> Result<Tensor> {
        let embedding = self.embed_tokens.forward(input)?;

        embedding.affine(self.embed_scale, 0.0)
    }

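    /// Run one outer forward pass: `h_cycles` high-level updates, each preceded
    /// by `l_cycles` low-level refinements, then project `z_h` to logits.
    /// Returns the updated carry together with the logits.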
    pub fn forward(&self, carry: &InnerCarry, input: &Tensor) -> Result<(InnerCarry, Tensor)> {
        let seq_len = input.dim(1)?;

        let cos_sin = if let Some(ref rope) = self.rotary_emb {
            let (cos, sin) = rope.forward_with_len(seq_len)?;
            Some((cos, sin))
        } else {
            None
        };

        let input_embeddings = self.input_embeddings(input)?;

        let mut z_h = carry.z_h.clone();
        let mut z_l = carry.z_l.clone();

        for _h_step in 0..self.config.h_cycles {
            // Low-level refinement: update z_l from z_h plus the input embeddings.
            for _l_step in 0..self.config.l_cycles {
                let injection = (&z_h + &input_embeddings)?;
                z_l = self.l_level.forward(
                    &z_l,
                    &injection,
                    cos_sin.as_ref().map(|(c, s)| (c, s)),
                )?;
            }

            // High-level update: refine z_h using the updated z_l as injection.
            z_h = self.l_level.forward(
                &z_h,
                &z_l,
                cos_sin.as_ref().map(|(c, s)| (c, s)),
            )?;
        }

        let logits = self.lm_head.forward(&z_h)?;

        let new_carry = InnerCarry::new(z_h, z_l);

        Ok((new_carry, logits))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_nn::VarMap;

    #[test]
    fn test_inner_carry_creation() -> Result<()> {
        let device = Device::Cpu;

        let carry = InnerCarry::empty(2, 16, 256, DType::F32, &device)?;

        assert_eq!(carry.z_h.dims(), &[2, 16, 256]);
        assert_eq!(carry.z_l.dims(), &[2, 16, 256]);

        Ok(())
    }

    #[test]
    fn test_transformer_block() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let config = TRMConfig {
            hidden_size: 256,
            num_heads: 8,
            ..TRMConfig::default()
        };

        let block = TransformerBlock::new(config, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let out = block.forward(&x, None)?;

        assert_eq!(out.dims(), &[2, 16, 256]);

        Ok(())
    }

    #[test]
    fn test_reasoning_module() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let config = TRMConfig {
            hidden_size: 256,
            num_heads: 8,
            l_layers: 2,
            ..TRMConfig::default()
        };

        let module = ReasoningModule::new(2, config, vb)?;

        let hidden = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let injection = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;

        let out = module.forward(&hidden, &injection, None)?;

        assert_eq!(out.dims(), &[2, 16, 256]);

        Ok(())
    }
}