trident/neural/
checkpoint.rs1use std::path::PathBuf;
7
8use burn::module::Module;
9use burn::prelude::*;
10use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
11
/// Relative path, joined onto the directory resolved by `checkpoint_dir`,
/// under which all checkpoint files are read and written.
const CHECKPOINT_DIR: &str = "model/general/v2";
14
/// Names one of the fixed checkpoint slots kept in the checkpoint directory.
///
/// Each tag maps to a single file whose stem is given by `stem()`; the
/// functions in this module append the `mpk` extension to that stem.
///
/// The type is `Copy` and comparable/hashable so tags can be matched with
/// `==` and used as keys in maps or sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CheckpointTag {
    /// Slot stored as `stage1_best.mpk`.
    Stage1Best,
    /// Slot stored as `stage2_latest.mpk`.
    Stage2Latest,
    /// Slot stored as `production.mpk`.
    Production,
}
22
23impl CheckpointTag {
24 fn stem(&self) -> &'static str {
25 match self {
26 Self::Stage1Best => "stage1_best",
27 Self::Stage2Latest => "stage2_latest",
28 Self::Production => "production",
29 }
30 }
31}
32
33fn checkpoint_dir() -> PathBuf {
35 let mut dir = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
36 loop {
38 if dir.join("Cargo.toml").exists() && dir.join("vm").is_dir() {
39 break;
40 }
41 if !dir.pop() {
42 dir = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
43 break;
44 }
45 }
46 dir.join(CHECKPOINT_DIR)
47}
48
49pub fn save_checkpoint<B: Backend, M: Module<B> + Clone>(
54 model: &M,
55 tag: CheckpointTag,
56 _device: &B::Device,
57) -> Result<PathBuf, String> {
58 let dir = checkpoint_dir();
59 std::fs::create_dir_all(&dir).map_err(|e| format!("mkdir {}: {}", dir.display(), e))?;
60
61 let path = dir.join(tag.stem());
62 let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
63 model
64 .clone()
65 .save_file(path.clone(), &recorder)
66 .map_err(|e| format!("save {}: {}", path.display(), e))?;
67
68 let full_path = path.with_extension("mpk");
70 Ok(full_path)
71}
72
73pub fn load_checkpoint<B: Backend, M: Module<B>>(
77 model: M,
78 tag: CheckpointTag,
79 device: &B::Device,
80) -> Result<Option<M>, String> {
81 let dir = checkpoint_dir();
82 let path = dir.join(tag.stem());
83
84 let full_path = path.with_extension("mpk");
86 if !full_path.exists() {
87 return Ok(None);
88 }
89
90 let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
91 let loaded = model
92 .load_file(path, &recorder, device)
93 .map_err(|e| format!("load {}: {}", full_path.display(), e))?;
94
95 Ok(Some(loaded))
96}
97
98pub fn available_checkpoints() -> Vec<(CheckpointTag, PathBuf)> {
100 let dir = checkpoint_dir();
101 let mut found = Vec::new();
102 for tag in [
103 CheckpointTag::Production,
104 CheckpointTag::Stage1Best,
105 CheckpointTag::Stage2Latest,
106 ] {
107 let path = dir.join(tag.stem()).with_extension("mpk");
108 if path.exists() {
109 found.push((tag, path));
110 }
111 }
112 found
113}
114
115pub fn detect_stage(replay_count: usize, replay_threshold: usize) -> TrainingStage {
121 let dir = checkpoint_dir();
122
123 let has_stage2 = dir.join("stage2_latest.mpk").exists();
124 let has_stage1 = dir.join("stage1_best.mpk").exists();
125 let has_production = dir.join("production.mpk").exists();
126
127 if has_stage2 && replay_count >= replay_threshold {
128 TrainingStage::Stage3Online
129 } else if has_stage1 || has_production {
130 TrainingStage::Stage2GFlowNet
131 } else {
132 TrainingStage::Stage1Supervised
133 }
134}
135
/// Training phase reported by `detect_stage`.
///
/// The `Display` implementation renders a short human-readable label for
/// each variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TrainingStage {
    /// Displayed as "Stage 1: supervised CE".
    Stage1Supervised,
    /// Displayed as "Stage 2: GFlowNet TB".
    Stage2GFlowNet,
    /// Displayed as "Stage 3: online learning".
    Stage3Online,
}
143
144impl std::fmt::Display for TrainingStage {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 match self {
147 Self::Stage1Supervised => write!(f, "Stage 1: supervised CE"),
148 Self::Stage2GFlowNet => write!(f, "Stage 2: GFlowNet TB"),
149 Self::Stage3Online => write!(f, "Stage 3: online learning"),
150 }
151 }
152}
153
154pub fn promote_to_production(source: CheckpointTag) -> Result<(), String> {
156 let dir = checkpoint_dir();
157 let src = dir.join(source.stem()).with_extension("mpk");
158 let dst = dir.join("production.mpk");
159
160 if !src.exists() {
161 return Err(format!("{} does not exist", src.display()));
162 }
163
164 std::fs::copy(&src, &dst)
165 .map_err(|e| format!("copy {} → {}: {}", src.display(), dst.display(), e))?;
166 Ok(())
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// With a replay count below the threshold, stage 3 can never be
    /// selected, whatever checkpoint files happen to exist on disk.
    #[test]
    fn detect_stage_returns_valid_stage() {
        let stage = detect_stage(0, 100);
        assert!(
            matches!(
                stage,
                TrainingStage::Stage1Supervised | TrainingStage::Stage2GFlowNet
            ),
            "unexpected stage below replay threshold: {:?}",
            stage
        );
    }

    /// Each tag maps to its expected file stem.
    #[test]
    fn checkpoint_tag_stems() {
        assert_eq!(CheckpointTag::Stage1Best.stem(), "stage1_best");
        assert_eq!(CheckpointTag::Stage2Latest.stem(), "stage2_latest");
        assert_eq!(CheckpointTag::Production.stem(), "production");
    }

    /// Each stage renders its fixed display label.
    #[test]
    fn training_stage_display_labels() {
        assert_eq!(
            TrainingStage::Stage1Supervised.to_string(),
            "Stage 1: supervised CE"
        );
        assert_eq!(
            TrainingStage::Stage2GFlowNet.to_string(),
            "Stage 2: GFlowNet TB"
        );
        assert_eq!(
            TrainingStage::Stage3Online.to_string(),
            "Stage 3: online learning"
        );
    }
}