tiny_recursive_rs/
config.rs

/// Configuration for Tiny Recursive Model.
///
/// Based on `TinyRecursiveReasoningModel_ACTV1Config` from the Python
/// implementation. Construct via [`Default`] and override fields, then
/// call `validate()` before building a model from it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TRMConfig {
    /// Embedding/hidden dimension. Must be > 0 and divisible by `num_heads`.
    pub hidden_size: usize,

    /// Number of high-level reasoning cycles. Must be > 0.
    pub h_cycles: usize,

    /// Number of low-level update cycles per H-cycle. Must be > 0.
    pub l_cycles: usize,

    /// Number of layers in L-level (low-level) blocks.
    pub l_layers: usize,

    /// Number of attention heads. Must evenly divide `hidden_size`.
    pub num_heads: usize,

    /// FFN expansion factor; the FFN inner size is `hidden_size * expansion`.
    pub expansion: f32,

    /// Positional encoding type: `"rope"`, `"learned"`, or `"none"`
    /// (enforced by `validate()`).
    pub pos_encodings: String,

    /// Use MLP instead of transformer (smaller, faster).
    pub mlp_t: bool,

    /// Maximum steps for ACT (adaptive computation time) halting.
    pub halt_max_steps: usize,

    /// Dropout probability.
    pub dropout: f32,

    /// Vocabulary size (for embeddings).
    pub vocab_size: usize,

    /// Number of output classes/tokens.
    pub num_outputs: usize,
}
42
43impl Default for TRMConfig {
44    fn default() -> Self {
45        Self {
46            hidden_size: 256,
47            h_cycles: 3,
48            l_cycles: 6,
49            l_layers: 2,
50            num_heads: 8,
51            expansion: 4.0,
52            pos_encodings: "rope".to_string(),
53            mlp_t: false,
54            halt_max_steps: 10,
55            dropout: 0.0,
56            vocab_size: 50257, // GPT-2 vocab size as default
57            num_outputs: 50257,
58        }
59    }
60}
61
62impl TRMConfig {
63    /// Validate configuration
64    pub fn validate(&self) -> crate::Result<()> {
65        if self.hidden_size == 0 {
66            return Err(crate::TRMError::Config(
67                "hidden_size must be > 0".to_string(),
68            ));
69        }
70
71        if self.hidden_size % self.num_heads != 0 {
72            return Err(crate::TRMError::Config(
73                "hidden_size must be divisible by num_heads".to_string(),
74            ));
75        }
76
77        if self.h_cycles == 0 || self.l_cycles == 0 {
78            return Err(crate::TRMError::Config(
79                "h_cycles and l_cycles must be > 0".to_string(),
80            ));
81        }
82
83        if !["rope", "learned", "none"].contains(&self.pos_encodings.as_str()) {
84            return Err(crate::TRMError::Config(format!(
85                "Invalid pos_encodings: {}. Must be 'rope', 'learned', or 'none'",
86                self.pos_encodings
87            )));
88        }
89
90        Ok(())
91    }
92
93    /// Get head dimension
94    pub fn head_dim(&self) -> usize {
95        self.hidden_size / self.num_heads
96    }
97
98    /// Get FFN hidden size
99    pub fn ffn_hidden_size(&self) -> usize {
100        (self.hidden_size as f32 * self.expansion) as usize
101    }
102}