trident/neural/model/
grammar.rs1use super::grammar_tables::{build_min_stack_depths, build_stack_effects, StackEffect};
9use super::vocab::VOCAB_SIZE;
/// Number of top-of-stack slots included in `StackStateMachine::type_encoding`.
/// Each slot contributes three one-hot floats, so an encoding is
/// `3 * TYPE_WINDOW` floats long.
pub const TYPE_WINDOW: usize = 8;

/// Upper bound on how many stack slots have their element types tracked;
/// `StackStateMachine::step` trims the type vector to this length.
const MAX_TRACKED_DEPTH: usize = 64;
/// Coarse type tag attached to a single stack slot.
///
/// The discriminant is used directly as the hot index inside each 3-float
/// group of `StackStateMachine::type_encoding`, so the values must remain
/// `0`, `1` and `2`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ElemType {
    /// Base-field element (going by the name — NOTE(review): confirm).
    BFE = 0,
    /// Extension-field element; token ids 136/137 tag three slots with it.
    XFE = 1,
    /// Type not known to the tracker (the default for new/pushed slots).
    Unknown = 2,
}
28pub struct StackStateMachine {
34 depth: i32,
35 types: Vec<ElemType>,
37 effects: Vec<StackEffect>,
39 min_depths: Vec<i32>,
41}
42
43impl StackStateMachine {
44 pub fn new(initial_depth: i32) -> Self {
46 let types = vec![ElemType::Unknown; initial_depth.max(0) as usize];
47 Self {
48 depth: initial_depth,
49 types,
50 effects: build_stack_effects(),
51 min_depths: build_min_stack_depths(),
52 }
53 }
54
55 pub fn stack_depth(&self) -> i32 {
57 self.depth
58 }
59
60 pub fn step(&mut self, token: u32) {
62 if token == 0 || token as usize >= VOCAB_SIZE {
63 return; }
65 let idx = token as usize;
66 let effect = self.effects[idx];
67
68 for _ in 0..effect.pops {
70 if !self.types.is_empty() {
71 self.types.pop();
72 }
73 }
74 for _ in 0..effect.pushes {
75 self.types.push(ElemType::Unknown);
76 }
77
78 self.update_types_for_op(token);
80
81 self.depth += effect.net();
83 if self.depth < 0 {
84 self.depth = 0;
85 }
86
87 if self.types.len() > MAX_TRACKED_DEPTH {
89 self.types.truncate(MAX_TRACKED_DEPTH);
90 }
91 }
92
93 fn update_types_for_op(&mut self, token: u32) {
95 let idx = token as usize;
96 match idx {
97 1..=14 => {
99 if let Some(last) = self.types.last_mut() {
100 *last = ElemType::BFE;
101 }
102 }
103 20..=35 => {
107 let n = (idx - 20) as usize;
108 let len = self.types.len();
109 if len >= 2 + n {
110 let src_type = self.types[len - 2 - n];
111 self.types[len - 1] = src_type;
112 }
113 }
114 136 => {
116 let len = self.types.len();
118 if len >= 3 {
119 for i in 0..3 {
120 self.types[len - 1 - i] = ElemType::XFE;
121 }
122 }
123 }
124 137 => {
125 let len = self.types.len();
127 if len >= 3 {
128 for i in 0..3 {
129 self.types[len - 1 - i] = ElemType::XFE;
130 }
131 }
132 }
133 83..=95 => {
135 if let Some(last) = self.types.last_mut() {
137 *last = ElemType::BFE;
138 }
139 }
140 _ => {}
141 }
142 }
143
144 pub fn valid_mask(&self) -> Vec<f32> {
147 let mut mask = vec![0.0f32; VOCAB_SIZE];
148
149 for token_id in 1..VOCAB_SIZE {
150 let min_depth = self.min_depths[token_id];
151 if self.depth < min_depth {
152 mask[token_id] = -1e9;
153 }
154 }
155
156 mask[0] = 0.0;
158
159 mask
160 }
161
162 pub fn type_encoding(&self) -> Vec<f32> {
165 let mut encoding = vec![0.0f32; 3 * TYPE_WINDOW];
166 for i in 0..TYPE_WINDOW {
167 let elem_type = if i < self.types.len() {
168 self.types[self.types.len() - 1 - i]
169 } else {
170 ElemType::Unknown };
172 let base = i * 3;
173 encoding[base + elem_type as usize] = 1.0;
174 }
175 encoding
176 }
177
178 pub fn depth_for_embedding(&self, max_depth: usize) -> u32 {
180 (self.depth.max(0) as usize).min(max_depth - 1) as u32
181 }
182}
183
184pub fn precompute_sequence_state(target_tokens: &[u32], initial_depth: i32) -> SequenceState {
192 let seq_len = target_tokens.len();
193 let mut masks = Vec::with_capacity(seq_len);
194 let mut depths = Vec::with_capacity(seq_len);
195 let mut type_states = Vec::with_capacity(seq_len);
196
197 let mut sm = StackStateMachine::new(initial_depth);
198
199 for &token in target_tokens {
200 masks.push(sm.valid_mask());
202 depths.push(sm.depth_for_embedding(65));
203 type_states.push(sm.type_encoding());
204
205 sm.step(token);
207 }
208
209 SequenceState {
210 masks,
211 depths,
212 type_states,
213 }
214}
215
/// Per-position grammar state for a target sequence, as produced by
/// `precompute_sequence_state`.
///
/// Entry `t` of every vector describes the state *before* token `t` is
/// consumed. All three vectors have the same length.
#[derive(Debug, Clone)]
pub struct SequenceState {
    /// One `VOCAB_SIZE`-long mask per position: `0.0` = allowed, `-1e9` =
    /// would need more stack depth than available.
    pub masks: Vec<Vec<f32>>,
    /// Clamped stack depth per position (via `depth_for_embedding`).
    pub depths: Vec<u32>,
    /// One-hot top-of-stack type encodings, `3 * TYPE_WINDOW` floats each.
    pub type_states: Vec<Vec<f32>>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a machine at `initial_depth` and feeds it `tokens` in order.
    fn run(initial_depth: i32, tokens: &[u32]) -> StackStateMachine {
        let mut machine = StackStateMachine::new(initial_depth);
        for &tok in tokens {
            machine.step(tok);
        }
        machine
    }

    #[test]
    fn initial_state_empty_stack() {
        assert_eq!(StackStateMachine::new(0).stack_depth(), 0);
    }

    #[test]
    fn push_increases_depth() {
        let mut machine = StackStateMachine::new(0);
        machine.step(3);
        assert_eq!(machine.stack_depth(), 1);
        machine.step(4);
        assert_eq!(machine.stack_depth(), 2);
    }

    #[test]
    fn add_decreases_depth() {
        assert_eq!(run(0, &[3, 4, 83]).stack_depth(), 1);
    }

    #[test]
    fn mask_prevents_underflow() {
        let mask = StackStateMachine::new(0).valid_mask();
        assert_eq!(mask[83], -1e9);
        assert_eq!(mask[3], 0.0);
        assert_eq!(mask[0], 0.0);
    }

    #[test]
    fn mask_allows_valid_ops() {
        let mask = run(0, &[3, 4]).valid_mask();
        assert_eq!(mask[83], 0.0);
        assert_eq!(mask[20], 0.0);
        assert_eq!(mask[35], -1e9);
    }

    #[test]
    fn dup_preserves_type() {
        let machine = run(0, &[3, 20]);
        assert_eq!(machine.stack_depth(), 2);
        let encoding = machine.type_encoding();
        assert_eq!(encoding[0], 1.0);
        assert_eq!(encoding[1], 0.0);
        assert_eq!(encoding[2], 0.0);
    }

    #[test]
    fn type_encoding_shape() {
        let encoding = StackStateMachine::new(5).type_encoding();
        assert_eq!(encoding.len(), 3 * TYPE_WINDOW);
    }

    #[test]
    fn precompute_sequence_lengths() {
        let state = precompute_sequence_state(&[3, 4, 83], 0);
        assert_eq!(state.masks.len(), 3);
        assert_eq!(state.depths.len(), 3);
        assert_eq!(state.type_states.len(), 3);
        assert_eq!(state.masks[0].len(), VOCAB_SIZE);
        assert_eq!(state.type_states[0].len(), 3 * TYPE_WINDOW);
    }

    #[test]
    fn precompute_masks_reflect_state() {
        let state = precompute_sequence_state(&[3, 83], 0);
        assert_eq!(state.masks[0][83], -1e9);
        assert_eq!(state.masks[1][83], -1e9);
    }

    #[test]
    fn precompute_depths_advance() {
        let state = precompute_sequence_state(&[3, 4, 83], 0);
        assert_eq!(state.depths[0], 0);
        assert_eq!(state.depths[1], 1);
        assert_eq!(state.depths[2], 2);
    }

    #[test]
    fn pop_reduces_depth() {
        assert_eq!(run(0, &[3, 4, 3, 16]).stack_depth(), 1);
    }

    #[test]
    fn depth_clamps_at_zero() {
        assert_eq!(run(1, &[15]).stack_depth(), 0);
    }

    #[test]
    fn write_mem_5_needs_six_on_stack() {
        assert_eq!(StackStateMachine::new(5).valid_mask()[128], -1e9);
        assert_eq!(StackStateMachine::new(6).valid_mask()[128], 0.0);
    }

    #[test]
    fn hash_needs_ten() {
        assert_eq!(StackStateMachine::new(9).valid_mask()[129], -1e9);
        assert_eq!(StackStateMachine::new(10).valid_mask()[129], 0.0);
    }
}