tiny_recursive_rs/
config.rs1#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
5pub struct TRMConfig {
6 pub hidden_size: usize,
8
9 pub h_cycles: usize,
11
12 pub l_cycles: usize,
14
15 pub l_layers: usize,
17
18 pub num_heads: usize,
20
21 pub expansion: f32,
23
24 pub pos_encodings: String,
26
27 pub mlp_t: bool,
29
30 pub halt_max_steps: usize,
32
33 pub dropout: f32,
35
36 pub vocab_size: usize,
38
39 pub num_outputs: usize,
41}
42
43impl Default for TRMConfig {
44 fn default() -> Self {
45 Self {
46 hidden_size: 256,
47 h_cycles: 3,
48 l_cycles: 6,
49 l_layers: 2,
50 num_heads: 8,
51 expansion: 4.0,
52 pos_encodings: "rope".to_string(),
53 mlp_t: false,
54 halt_max_steps: 10,
55 dropout: 0.0,
56 vocab_size: 50257, num_outputs: 50257,
58 }
59 }
60}
61
62impl TRMConfig {
63 pub fn validate(&self) -> crate::Result<()> {
65 if self.hidden_size == 0 {
66 return Err(crate::TRMError::Config(
67 "hidden_size must be > 0".to_string(),
68 ));
69 }
70
71 if self.hidden_size % self.num_heads != 0 {
72 return Err(crate::TRMError::Config(
73 "hidden_size must be divisible by num_heads".to_string(),
74 ));
75 }
76
77 if self.h_cycles == 0 || self.l_cycles == 0 {
78 return Err(crate::TRMError::Config(
79 "h_cycles and l_cycles must be > 0".to_string(),
80 ));
81 }
82
83 if !["rope", "learned", "none"].contains(&self.pos_encodings.as_str()) {
84 return Err(crate::TRMError::Config(format!(
85 "Invalid pos_encodings: {}. Must be 'rope', 'learned', or 'none'",
86 self.pos_encodings
87 )));
88 }
89
90 Ok(())
91 }
92
93 pub fn head_dim(&self) -> usize {
95 self.hidden_size / self.num_heads
96 }
97
98 pub fn ffn_hidden_size(&self) -> usize {
100 (self.hidden_size as f32 * self.expansion) as usize
101 }
102}