Skip to main content

trident/neural/
checkpoint.rs

1//! Checkpoint management for neural compiler v2.
2//!
3//! Uses burn's native record format (NamedMpk) for model weights.
4//! Supports stage-tagged checkpoints: stage1_best, stage2_latest, production.
5
6use std::path::PathBuf;
7
8use burn::module::Module;
9use burn::prelude::*;
10use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
11
/// Checkpoint directory, relative to the repo root.
///
/// `checkpoint_dir` joins this onto the directory it resolves as the root.
const CHECKPOINT_DIR: &str = "model/general/v2";
14
/// Checkpoint tag for naming saved files.
///
/// Each tag maps to a fixed file stem (see [`CheckpointTag::stem`]); the
/// save/load helpers in this module add the `.mpk` extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CheckpointTag {
    /// Stored under the stem `stage1_best`.
    Stage1Best,
    /// Stored under the stem `stage2_latest`.
    Stage2Latest,
    /// Stored under the stem `production`.
    Production,
}

impl CheckpointTag {
    /// File stem (without extension) used for this tag's checkpoint file.
    fn stem(&self) -> &'static str {
        match self {
            Self::Stage1Best => "stage1_best",
            Self::Stage2Latest => "stage2_latest",
            Self::Production => "production",
        }
    }
}
32
33/// Resolve the checkpoint directory, creating it if needed.
34fn checkpoint_dir() -> PathBuf {
35    let mut dir = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
36    // Walk up to repo root (has Cargo.toml + vm/)
37    loop {
38        if dir.join("Cargo.toml").exists() && dir.join("vm").is_dir() {
39            break;
40        }
41        if !dir.pop() {
42            dir = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
43            break;
44        }
45    }
46    dir.join(CHECKPOINT_DIR)
47}
48
49/// Save a model checkpoint to disk.
50///
51/// Uses NamedMpk format with full precision (lossless).
52/// File will be at `model/general/v2/{tag}.mpk`.
53pub fn save_checkpoint<B: Backend, M: Module<B> + Clone>(
54    model: &M,
55    tag: CheckpointTag,
56    _device: &B::Device,
57) -> Result<PathBuf, String> {
58    let dir = checkpoint_dir();
59    std::fs::create_dir_all(&dir).map_err(|e| format!("mkdir {}: {}", dir.display(), e))?;
60
61    let path = dir.join(tag.stem());
62    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
63    model
64        .clone()
65        .save_file(path.clone(), &recorder)
66        .map_err(|e| format!("save {}: {}", path.display(), e))?;
67
68    // burn appends .mpk extension
69    let full_path = path.with_extension("mpk");
70    Ok(full_path)
71}
72
73/// Load a model checkpoint from disk.
74///
75/// Returns the model with loaded weights, or None if checkpoint doesn't exist.
76pub fn load_checkpoint<B: Backend, M: Module<B>>(
77    model: M,
78    tag: CheckpointTag,
79    device: &B::Device,
80) -> Result<Option<M>, String> {
81    let dir = checkpoint_dir();
82    let path = dir.join(tag.stem());
83
84    // burn's NamedMpkFileRecorder appends .mpk
85    let full_path = path.with_extension("mpk");
86    if !full_path.exists() {
87        return Ok(None);
88    }
89
90    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
91    let loaded = model
92        .load_file(path, &recorder, device)
93        .map_err(|e| format!("load {}: {}", full_path.display(), e))?;
94
95    Ok(Some(loaded))
96}
97
98/// Check which checkpoints exist on disk.
99pub fn available_checkpoints() -> Vec<(CheckpointTag, PathBuf)> {
100    let dir = checkpoint_dir();
101    let mut found = Vec::new();
102    for tag in [
103        CheckpointTag::Production,
104        CheckpointTag::Stage1Best,
105        CheckpointTag::Stage2Latest,
106    ] {
107        let path = dir.join(tag.stem()).with_extension("mpk");
108        if path.exists() {
109            found.push((tag, path));
110        }
111    }
112    found
113}
114
115/// Detect which training stage to run based on existing checkpoints.
116///
117/// - No checkpoints → Stage 1 (supervised)
118/// - Stage1Best exists → Stage 2 (GFlowNet)
119/// - Stage2Latest exists + replay ≥ threshold → Stage 3 (online)
120pub fn detect_stage(replay_count: usize, replay_threshold: usize) -> TrainingStage {
121    let dir = checkpoint_dir();
122
123    let has_stage2 = dir.join("stage2_latest.mpk").exists();
124    let has_stage1 = dir.join("stage1_best.mpk").exists();
125    let has_production = dir.join("production.mpk").exists();
126
127    if has_stage2 && replay_count >= replay_threshold {
128        TrainingStage::Stage3Online
129    } else if has_stage1 || has_production {
130        TrainingStage::Stage2GFlowNet
131    } else {
132        TrainingStage::Stage1Supervised
133    }
134}
135
/// The training stage the system should run next.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainingStage {
    /// Displayed as "Stage 1: supervised CE".
    Stage1Supervised,
    /// Displayed as "Stage 2: GFlowNet TB".
    Stage2GFlowNet,
    /// Displayed as "Stage 3: online learning".
    Stage3Online,
}

impl std::fmt::Display for TrainingStage {
    /// Writes the fixed human-readable label for the stage.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Stage1Supervised => "Stage 1: supervised CE",
            Self::Stage2GFlowNet => "Stage 2: GFlowNet TB",
            Self::Stage3Online => "Stage 3: online learning",
        };
        f.write_str(label)
    }
}
153
154/// Promote a checkpoint to production (copy file).
155pub fn promote_to_production(source: CheckpointTag) -> Result<(), String> {
156    let dir = checkpoint_dir();
157    let src = dir.join(source.stem()).with_extension("mpk");
158    let dst = dir.join("production.mpk");
159
160    if !src.exists() {
161        return Err(format!("{} does not exist", src.display()));
162    }
163
164    std::fs::copy(&src, &dst)
165        .map_err(|e| format!("copy {} → {}: {}", src.display(), dst.display(), e))?;
166    Ok(())
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// `detect_stage` reads the real filesystem, so the exact stage depends
    /// on which checkpoints are present. What must hold regardless: with the
    /// replay count (0) below the threshold (100), Stage 3 is never chosen.
    #[test]
    fn detect_stage_returns_valid_stage() {
        let stage = detect_stage(0, 100);
        assert_ne!(stage, TrainingStage::Stage3Online);
    }

    /// Each tag maps to the file stem used on disk.
    #[test]
    fn checkpoint_tag_stems() {
        let expected = [
            (CheckpointTag::Stage1Best, "stage1_best"),
            (CheckpointTag::Stage2Latest, "stage2_latest"),
            (CheckpointTag::Production, "production"),
        ];
        for (tag, stem) in expected {
            assert_eq!(tag.stem(), stem);
        }
    }

    /// The user-facing stage labels are fixed strings.
    #[test]
    fn training_stage_display() {
        assert_eq!(
            TrainingStage::Stage1Supervised.to_string(),
            "Stage 1: supervised CE"
        );
        assert_eq!(
            TrainingStage::Stage2GFlowNet.to_string(),
            "Stage 2: GFlowNet TB"
        );
        assert_eq!(
            TrainingStage::Stage3Online.to_string(),
            "Stage 3: online learning"
        );
    }
}