Skip to main content

trident/neural/model/
grammar.rs

1//! CPU Grammar Mask — stack state machine for TASM validity.
2//!
3//! Tracks abstract stack state (depth + element types) and produces
4//! a validity mask over the vocabulary at each decoding step.
5//! Used during training (teacher forcing) to precompute masks for the
6//! entire target sequence, and during inference as a CPU fallback.
7
8use super::grammar_tables::{build_min_stack_depths, build_stack_effects, StackEffect};
9use super::vocab::VOCAB_SIZE;
10
/// Upper bound on the number of stack slots that carry type information.
/// Slots deeper than this are not typed, but they still count toward the
/// integer depth.
const MAX_TRACKED_DEPTH: usize = 64;

/// Number of top-of-stack slots whose element types are exposed through the
/// one-hot type encoding.
pub const TYPE_WINDOW: usize = 8;
18
/// Abstract element type used when tracking stack contents.
///
/// The explicit `u8` discriminants are also used directly as one-hot offsets
/// when encoding a stack slot's type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ElemType {
    /// Base-field element (constant pushes and arithmetic results).
    BFE = 0,
    /// Extension-field element (results of extension-field ops).
    XFE = 1,
    /// Type not known (generic op results, or slots below the tracked window).
    Unknown = 2,
}
27
28/// Stack state machine for grammar masking.
29///
30/// Tracks stack depth and the types of the top `TYPE_WINDOW` elements.
31/// At each step, can produce a validity mask indicating which VOCAB
32/// tokens are legal given the current stack state.
33pub struct StackStateMachine {
34    depth: i32,
35    /// Types of the top elements (index 0 = TOS).
36    types: Vec<ElemType>,
37    /// Precomputed stack effects table.
38    effects: Vec<StackEffect>,
39    /// Precomputed minimum depth requirements.
40    min_depths: Vec<i32>,
41}
42
43impl StackStateMachine {
44    /// Create a new state machine with the given initial stack depth.
45    pub fn new(initial_depth: i32) -> Self {
46        let types = vec![ElemType::Unknown; initial_depth.max(0) as usize];
47        Self {
48            depth: initial_depth,
49            types,
50            effects: build_stack_effects(),
51            min_depths: build_min_stack_depths(),
52        }
53    }
54
55    /// Current stack depth.
56    pub fn stack_depth(&self) -> i32 {
57        self.depth
58    }
59
60    /// Advance the state machine by executing a token.
61    pub fn step(&mut self, token: u32) {
62        if token == 0 || token as usize >= VOCAB_SIZE {
63            return; // EOS or invalid — no state change
64        }
65        let idx = token as usize;
66        let effect = self.effects[idx];
67
68        // Update types: pop, then push Unknown
69        for _ in 0..effect.pops {
70            if !self.types.is_empty() {
71                self.types.pop();
72            }
73        }
74        for _ in 0..effect.pushes {
75            self.types.push(ElemType::Unknown);
76        }
77
78        // Handle special cases for type tracking
79        self.update_types_for_op(token);
80
81        // Update depth
82        self.depth += effect.net();
83        if self.depth < 0 {
84            self.depth = 0;
85        }
86
87        // Cap tracked types
88        if self.types.len() > MAX_TRACKED_DEPTH {
89            self.types.truncate(MAX_TRACKED_DEPTH);
90        }
91    }
92
93    /// Update type annotations for specific operations.
94    fn update_types_for_op(&mut self, token: u32) {
95        let idx = token as usize;
96        match idx {
97            // push constants: always BFE
98            1..=14 => {
99                if let Some(last) = self.types.last_mut() {
100                    *last = ElemType::BFE;
101                }
102            }
103            // dup 0-15: the pushed element has same type as source.
104            // After the generic push (Unknown already appended to types),
105            // the source is at len - 2 - n (one deeper because of the push).
106            20..=35 => {
107                let n = (idx - 20) as usize;
108                let len = self.types.len();
109                if len >= 2 + n {
110                    let src_type = self.types[len - 2 - n];
111                    self.types[len - 1] = src_type;
112                }
113            }
114            // Extension field ops push XFE
115            136 => {
116                // x_invert: pops 3 XFE, pushes 3 XFE
117                let len = self.types.len();
118                if len >= 3 {
119                    for i in 0..3 {
120                        self.types[len - 1 - i] = ElemType::XFE;
121                    }
122                }
123            }
124            137 => {
125                // xb_mul: pushes 3 XFE
126                let len = self.types.len();
127                if len >= 3 {
128                    for i in 0..3 {
129                        self.types[len - 1 - i] = ElemType::XFE;
130                    }
131                }
132            }
133            // Most ops produce BFE results
134            83..=95 => {
135                // arithmetic, comparison, bitwise: result is BFE
136                if let Some(last) = self.types.last_mut() {
137                    *last = ElemType::BFE;
138                }
139            }
140            _ => {}
141        }
142    }
143
144    /// Produce a validity mask over the vocabulary.
145    /// Returns VOCAB_SIZE floats: 0.0 = valid, -1e9 = masked (invalid).
146    pub fn valid_mask(&self) -> Vec<f32> {
147        let mut mask = vec![0.0f32; VOCAB_SIZE];
148
149        for token_id in 1..VOCAB_SIZE {
150            let min_depth = self.min_depths[token_id];
151            if self.depth < min_depth {
152                mask[token_id] = -1e9;
153            }
154        }
155
156        // EOS is always valid (token 0)
157        mask[0] = 0.0;
158
159        mask
160    }
161
162    /// Encode the type state of the top TYPE_WINDOW stack slots
163    /// as a flat vector of 3*TYPE_WINDOW floats (one-hot per slot).
164    pub fn type_encoding(&self) -> Vec<f32> {
165        let mut encoding = vec![0.0f32; 3 * TYPE_WINDOW];
166        for i in 0..TYPE_WINDOW {
167            let elem_type = if i < self.types.len() {
168                self.types[self.types.len() - 1 - i]
169            } else {
170                ElemType::Unknown // below tracked depth
171            };
172            let base = i * 3;
173            encoding[base + elem_type as usize] = 1.0;
174        }
175        encoding
176    }
177
178    /// Clamped stack depth for embedding lookup (0..max_stack_depth-1).
179    pub fn depth_for_embedding(&self, max_depth: usize) -> u32 {
180        (self.depth.max(0) as usize).min(max_depth - 1) as u32
181    }
182}
183
184/// Precompute masks for an entire target sequence (teacher forcing).
185///
186/// Given a sequence of ground-truth tokens, simulates the state machine
187/// and returns the validity mask at each step. Used during training to
188/// apply grammar constraints without GPU-side state tracking.
189///
190/// Also returns stack depths and type encodings for decoder input.
191pub fn precompute_sequence_state(target_tokens: &[u32], initial_depth: i32) -> SequenceState {
192    let seq_len = target_tokens.len();
193    let mut masks = Vec::with_capacity(seq_len);
194    let mut depths = Vec::with_capacity(seq_len);
195    let mut type_states = Vec::with_capacity(seq_len);
196
197    let mut sm = StackStateMachine::new(initial_depth);
198
199    for &token in target_tokens {
200        // Record state BEFORE executing this token
201        masks.push(sm.valid_mask());
202        depths.push(sm.depth_for_embedding(65));
203        type_states.push(sm.type_encoding());
204
205        // Execute the token to advance state
206        sm.step(token);
207    }
208
209    SequenceState {
210        masks,
211        depths,
212        type_states,
213    }
214}
215
/// Precomputed per-step decoder state for one target sequence (training).
///
/// Each vector holds one entry per target token.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SequenceState {
    /// Validity masks: `[seq_len][VOCAB_SIZE]`, each entry `0.0` or `-1e9`.
    pub masks: Vec<Vec<f32>>,
    /// Stack depths: `[seq_len]`, clamped for embedding lookup.
    pub depths: Vec<u32>,
    /// Type encodings: `[seq_len][3 * TYPE_WINDOW]`, one-hot per stack slot.
    pub type_states: Vec<Vec<f32>>,
}
225
#[cfg(test)]
mod tests {
    use super::*;

    // Token ids used by these tests.
    const EOS: u32 = 0;
    const PUSH_1: u32 = 3;
    const PUSH_2: u32 = 4;
    const POP_1: u32 = 15;
    const POP_2: u32 = 16;
    const DUP_0: u32 = 20;
    const DUP_15: u32 = 35;
    const ADD: u32 = 83;
    const WRITE_MEM_5: u32 = 128;
    const HASH: u32 = 129;

    /// Mask value expected for invalid tokens.
    const MASKED: f32 = -1e9;

    /// Build a machine at `initial_depth` and feed it `tokens` in order.
    fn run(initial_depth: i32, tokens: &[u32]) -> StackStateMachine {
        let mut machine = StackStateMachine::new(initial_depth);
        for &token in tokens {
            machine.step(token);
        }
        machine
    }

    /// Value of `token`'s entry in `machine`'s current validity mask.
    fn mask_of(machine: &StackStateMachine, token: u32) -> f32 {
        machine.valid_mask()[token as usize]
    }

    #[test]
    fn initial_state_empty_stack() {
        assert_eq!(StackStateMachine::new(0).stack_depth(), 0);
    }

    #[test]
    fn push_increases_depth() {
        let mut machine = StackStateMachine::new(0);
        for (token, expected_depth) in [(PUSH_1, 1), (PUSH_2, 2)] {
            machine.step(token);
            assert_eq!(machine.stack_depth(), expected_depth);
        }
    }

    #[test]
    fn add_decreases_depth() {
        let machine = run(0, &[PUSH_1, PUSH_2, ADD]);
        assert_eq!(machine.stack_depth(), 1);
    }

    #[test]
    fn mask_prevents_underflow() {
        let machine = StackStateMachine::new(0);
        // add needs two operands, so it is masked on an empty stack
        assert_eq!(mask_of(&machine, ADD), MASKED);
        // push 1 needs nothing on the stack
        assert_eq!(mask_of(&machine, PUSH_1), 0.0);
        // EOS is never masked
        assert_eq!(mask_of(&machine, EOS), 0.0);
    }

    #[test]
    fn mask_allows_valid_ops() {
        let mask = run(0, &[PUSH_1, PUSH_2]).valid_mask();
        // depth 2: add (needs 2) and dup 0 (needs 1) are allowed
        assert_eq!(mask[ADD as usize], 0.0);
        assert_eq!(mask[DUP_0 as usize], 0.0);
        // dup 15 needs 16 elements
        assert_eq!(mask[DUP_15 as usize], MASKED);
    }

    #[test]
    fn dup_preserves_type() {
        let machine = run(0, &[PUSH_1, DUP_0]);
        assert_eq!(machine.stack_depth(), 2);
        // One-hot TOS slot, laid out as [BFE, XFE, Unknown]
        let encoding = machine.type_encoding();
        let tos = &encoding[..3];
        assert_eq!(tos[0], 1.0);
        assert_eq!(tos[1], 0.0);
        assert_eq!(tos[2], 0.0);
    }

    #[test]
    fn type_encoding_shape() {
        let encoding = StackStateMachine::new(5).type_encoding();
        assert_eq!(encoding.len(), 3 * TYPE_WINDOW);
    }

    #[test]
    fn precompute_sequence_lengths() {
        let state = precompute_sequence_state(&[PUSH_1, PUSH_2, ADD], 0);
        assert_eq!(state.masks.len(), 3);
        assert_eq!(state.depths.len(), 3);
        assert_eq!(state.type_states.len(), 3);
        assert_eq!(state.masks[0].len(), VOCAB_SIZE);
        assert_eq!(state.type_states[0].len(), 3 * TYPE_WINDOW);
    }

    #[test]
    fn precompute_masks_reflect_state() {
        let state = precompute_sequence_state(&[PUSH_1, ADD], 0);
        let add = ADD as usize;
        // depth 0 before the push: add is masked
        assert_eq!(state.masks[0][add], MASKED);
        // depth 1 after the push: add is still masked (needs 2)
        assert_eq!(state.masks[1][add], MASKED);
    }

    #[test]
    fn precompute_depths_advance() {
        let state = precompute_sequence_state(&[PUSH_1, PUSH_2, ADD], 0);
        // depth recorded before each token: 0, then 1, then 2
        assert_eq!(state.depths, vec![0, 1, 2]);
    }

    #[test]
    fn pop_reduces_depth() {
        let machine = run(0, &[PUSH_1, PUSH_2, PUSH_1, POP_2]);
        assert_eq!(machine.stack_depth(), 1);
    }

    #[test]
    fn depth_clamps_at_zero() {
        let machine = run(1, &[POP_1]);
        assert_eq!(machine.stack_depth(), 0);
    }

    #[test]
    fn write_mem_5_needs_six_on_stack() {
        assert_eq!(mask_of(&StackStateMachine::new(5), WRITE_MEM_5), MASKED);
        assert_eq!(mask_of(&StackStateMachine::new(6), WRITE_MEM_5), 0.0);
    }

    #[test]
    fn hash_needs_ten() {
        assert_eq!(mask_of(&StackStateMachine::new(9), HASH), MASKED);
        assert_eq!(mask_of(&StackStateMachine::new(10), HASH), 0.0);
    }
}