use std::io::Write as IoWrite;

use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;

use tokio::process::Command;

use crate::learn::episode::{EpisodeId, Outcome};
use crate::learn::learn_model::LearnModel;
use crate::learn::store::{EpisodeDto, EpisodeFilter, EpisodeStore, StoreError};
use crate::learn::training::TrainingData;
use crate::util::{epoch_millis, epoch_millis_for_ordering};

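/// Configuration for a LoRA fine-tuning run.
///
/// Field meanings mirror the flags passed to the training script (see
/// `run_lora_training`). The defaults are a reasonable starting point, not a
/// tuned recipe; override individual fields with the builder methods, e.g.
/// `LoraTrainerConfig::default().epochs(1).batch_size(2)`.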
#[derive(Debug, Clone)]
pub struct LoraTrainerConfig {
    /// Identifier of the base model to fine-tune (e.g. a Hugging Face model id).
    pub base_model: String,
    /// LoRA rank (`r`).
    pub lora_rank: u32,
    /// LoRA scaling factor (alpha).
    pub lora_alpha: f32,
    /// Dropout applied to the LoRA layers.
    pub lora_dropout: f32,
    /// Number of training epochs.
    pub epochs: u32,
    /// Training batch size.
    pub batch_size: u32,
    /// Gradient accumulation steps.
    pub gradient_accumulation: u32,
    /// Optimizer learning rate.
    pub learning_rate: f32,
    /// Maximum tokenized sequence length.
    pub max_seq_length: u32,
    /// Path to the Python training script.
    pub train_script: PathBuf,
    /// Directory where trained adapters are written.
    pub output_dir: PathBuf,
    /// Directory where JSONL training data is written.
    pub data_dir: PathBuf,
    /// Python interpreter used to run the training script.
    pub python_path: PathBuf,
}

impl Default for LoraTrainerConfig {
    fn default() -> Self {
        Self {
            base_model: "LiquidAI/LFM2.5-1.2B-Instruct".to_string(),
            lora_rank: 16,
            lora_alpha: 32.0,
            lora_dropout: 0.05,
            epochs: 3,
            batch_size: 4,
            gradient_accumulation: 4,
            learning_rate: 2e-4,
            max_seq_length: 2048,
            train_script: PathBuf::from("lora/train.py"),
            output_dir: PathBuf::from("lora/adapters"),
            data_dir: PathBuf::from("lora/data"),
            python_path: PathBuf::from("python3"),
        }
    }
}

impl LoraTrainerConfig {
    /// Sets the base model identifier.
    pub fn base_model(mut self, model: impl Into<String>) -> Self {
        self.base_model = model.into();
        self
    }

    /// Sets the LoRA rank.
    pub fn lora_rank(mut self, rank: u32) -> Self {
        self.lora_rank = rank;
        self
    }

    /// Sets the LoRA scaling factor.
    pub fn lora_alpha(mut self, alpha: f32) -> Self {
        self.lora_alpha = alpha;
        self
    }

    /// Sets the number of training epochs.
    pub fn epochs(mut self, epochs: u32) -> Self {
        self.epochs = epochs;
        self
    }

    /// Sets the training batch size.
    pub fn batch_size(mut self, size: u32) -> Self {
        self.batch_size = size;
        self
    }

    /// Sets the learning rate.
    pub fn learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Sets the path to the training script.
    pub fn train_script(mut self, path: impl Into<PathBuf>) -> Self {
        self.train_script = path.into();
        self
    }

    /// Sets the adapter output directory.
    pub fn output_dir(mut self, path: impl Into<PathBuf>) -> Self {
        self.output_dir = path.into();
        self
    }

    /// Sets the Python interpreter path.
    pub fn python_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.python_path = path.into();
        self
    }
}

/// A completed training run: the adapter on disk plus its provenance.
#[derive(Debug, Clone)]
pub struct TrainedModel {
    /// Unique identifier for this trained adapter.
    pub id: LoraModelId,
    /// Base model the adapter was trained against.
    pub base_model: String,
    /// Filesystem path to the adapter weights.
    pub adapter_path: PathBuf,
    /// Name of the learn model that produced the training data.
    pub learn_model_name: String,
    /// Episodes used as training data.
    pub episode_ids: Vec<EpisodeId>,
    /// Number of training samples generated.
    pub sample_count: usize,
    /// Creation time in epoch milliseconds.
    pub created_at: u64,
    /// Metrics reported for the run, if any.
    pub metrics: Option<TrainingMetrics>,
}

/// Identifier for a trained LoRA adapter.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LoraModelId(String);

impl LoraModelId {
    /// Generates a new id from the current timestamp and a process-local counter.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU32, Ordering};
        static COUNTER: AtomicU32 = AtomicU32::new(0);

        let ts = epoch_millis_for_ordering();
        let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
        Self(format!("lora-{}-{:08x}", ts, counter))
    }

    /// Wraps an existing id string.
    pub fn parse(s: &str) -> Self {
        Self(s.to_string())
    }

    /// Returns the id as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl Default for LoraModelId {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for LoraModelId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Metrics reported by the training script, when available.
#[derive(Debug, Clone, Default)]
pub struct TrainingMetrics {
    /// Final training loss.
    pub final_loss: Option<f64>,
    /// Wall-clock training time in seconds.
    pub training_time_secs: Option<u64>,
    /// Peak GPU memory usage in megabytes.
    pub gpu_memory_mb: Option<u64>,
}

/// Errors that can occur while preparing data or running training.
#[derive(Debug)]
pub enum LoraTrainerError {
    /// The episode store returned an error.
    Store(StoreError),
    /// No episodes or training samples were available.
    EmptyData(String),
    /// An I/O operation failed.
    Io(std::io::Error),
    /// The configured training script does not exist.
    ScriptNotFound(PathBuf),
    /// The training process exited with a non-zero status.
    ProcessFailed { exit_code: i32, stderr: String },
    /// Any other error.
    Other(String),
}

impl std::fmt::Display for LoraTrainerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Store(e) => write!(f, "Store error: {}", e),
            Self::EmptyData(msg) => write!(f, "Empty data: {}", msg),
            Self::Io(e) => write!(f, "IO error: {}", e),
            Self::ScriptNotFound(p) => write!(f, "Script not found: {}", p.display()),
            Self::ProcessFailed { exit_code, stderr } => {
                write!(f, "Training failed (exit {}): {}", exit_code, stderr)
            }
            Self::Other(msg) => write!(f, "{}", msg),
        }
    }
}

impl std::error::Error for LoraTrainerError {}

impl From<StoreError> for LoraTrainerError {
    fn from(e: StoreError) -> Self {
        Self::Store(e)
    }
}

impl From<std::io::Error> for LoraTrainerError {
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}

/// Drives LoRA fine-tuning: fetches episodes, writes JSONL training data,
/// and invokes the external training script.
pub struct LoraTrainer {
    config: LoraTrainerConfig,
    episode_store: Arc<dyn EpisodeStore>,
}

impl LoraTrainer {
    /// Creates a trainer over the given config and episode store.
    pub fn new(config: LoraTrainerConfig, episode_store: Arc<dyn EpisodeStore>) -> Self {
        Self {
            config,
            episode_store,
        }
    }

    /// Returns the trainer configuration.
    pub fn config(&self) -> &LoraTrainerConfig {
        &self.config
    }

    /// Returns the underlying episode store.
    pub fn episode_store(&self) -> &Arc<dyn EpisodeStore> {
        &self.episode_store
    }

    /// Runs the full pipeline: query episodes, convert them to training
    /// data, write a JSONL file, and launch the training script.
    pub async fn train(
        &self,
        learn_model: &dyn LearnModel,
        filter: Option<EpisodeFilter>,
    ) -> Result<TrainedModel, LoraTrainerError> {
        let started_at = std::time::Instant::now();

        tracing::info!(
            learn_model = learn_model.name(),
            "Fetching episodes for training"
        );
        let filter = filter.unwrap_or_default();
        let episodes = self.episode_store.query(&filter)?;

        if episodes.is_empty() {
            return Err(LoraTrainerError::EmptyData(
                "No episodes found for training".into(),
            ));
        }

        let episode_ids: Vec<_> = episodes.iter().map(|e| e.id.clone()).collect();
        tracing::info!(episode_count = episodes.len(), "Episodes fetched");

        tracing::info!("Converting episodes to training data");
        // Episodes that fail conversion are skipped rather than aborting the run.
        let training_data: Vec<TrainingData> = episodes
            .iter()
            .filter_map(|ep| episode_dto_to_training_data(ep, learn_model.name()).ok())
            .collect();

        if training_data.is_empty() {
            return Err(LoraTrainerError::EmptyData(
                "No training data generated from episodes".into(),
            ));
        }

        let sample_count = training_data.len();
        tracing::info!(sample_count, "Training data prepared");

        let data_path = self.write_training_data(&training_data, learn_model.name())?;
        tracing::info!(path = %data_path.display(), "Training data written");

        // A seconds-resolution timestamp keeps the adapter name short.
        let timestamp = epoch_millis() / 1000;
        let adapter_name = format!("{}-{}", learn_model.name(), timestamp);
        let adapter_path = self.run_lora_training(&data_path, &adapter_name).await?;

        let elapsed = started_at.elapsed();
        tracing::info!(
            elapsed_secs = elapsed.as_secs(),
            adapter = %adapter_path.display(),
            "Training completed"
        );

        let model = TrainedModel {
            id: LoraModelId::new(),
            base_model: self.config.base_model.clone(),
            adapter_path,
            learn_model_name: learn_model.name().to_string(),
            episode_ids,
            sample_count,
            created_at: epoch_millis(),
            metrics: Some(TrainingMetrics {
                final_loss: None,
                training_time_secs: Some(elapsed.as_secs()),
                gpu_memory_mb: None,
            }),
        };

        Ok(model)
    }

    /// Writes one JSON object per line (JSONL) to `<data_dir>/<model>.jsonl`.
    fn write_training_data(
        &self,
        data: &[TrainingData],
        learn_model_name: &str,
    ) -> Result<PathBuf, LoraTrainerError> {
        std::fs::create_dir_all(&self.config.data_dir)?;

        let filename = format!("{}.jsonl", learn_model_name);
        let path = self.config.data_dir.join(filename);

        let mut file = std::fs::File::create(&path)?;

        for td in data {
            let json_str = training_data_to_json(td)?;
            writeln!(file, "{}", json_str)?;
        }

        Ok(path)
    }

    /// Spawns the Python training script and waits for it to finish,
    /// returning the adapter output directory on success.
    async fn run_lora_training(
        &self,
        data_path: &Path,
        adapter_name: &str,
    ) -> Result<PathBuf, LoraTrainerError> {
        if !self.config.train_script.exists() {
            return Err(LoraTrainerError::ScriptNotFound(
                self.config.train_script.clone(),
            ));
        }

        let output_path = self.config.output_dir.join(adapter_name);

        let mut cmd = Command::new(&self.config.python_path);
        cmd.arg(&self.config.train_script)
            .arg("--data")
            .arg(data_path)
            .arg("--output")
            .arg(&output_path)
            .arg("--model")
            .arg(&self.config.base_model)
            .arg("--rank")
            .arg(self.config.lora_rank.to_string())
            .arg("--alpha")
            .arg(self.config.lora_alpha.to_string())
            .arg("--dropout")
            .arg(self.config.lora_dropout.to_string())
            .arg("--epochs")
            .arg(self.config.epochs.to_string())
            .arg("--batch-size")
            .arg(self.config.batch_size.to_string())
            .arg("--grad-accum")
            .arg(self.config.gradient_accumulation.to_string())
            .arg("--lr")
            .arg(self.config.learning_rate.to_string())
            .arg("--max-seq-length")
            .arg(self.config.max_seq_length.to_string())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
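
        // With the default config the spawned command looks roughly like this
        // (illustrative; actual paths depend on the configuration):
        //
        //   python3 lora/train.py --data lora/data/<model>.jsonl \
        //       --output lora/adapters/<adapter-name> \
        //       --model LiquidAI/LFM2.5-1.2B-Instruct --rank 16 --alpha 32 \
        //       --dropout 0.05 --epochs 3 --batch-size 4 --grad-accum 4 \
        //       --lr 0.0002 --max-seq-length 2048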

        tracing::info!(
            script = %self.config.train_script.display(),
            data = %data_path.display(),
            output = %output_path.display(),
            "Starting LoRA training"
        );

        let output = cmd.output().await?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            return Err(LoraTrainerError::ProcessFailed {
                exit_code: output.status.code().unwrap_or(-1),
                stderr: stderr.to_string(),
            });
        }

        let stdout = String::from_utf8_lossy(&output.stdout);
        for line in stdout.lines() {
            tracing::debug!(line, "train.py output");
        }

        Ok(output_path)
    }
}
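
// Example wiring (a sketch; `InMemoryEpisodeStore` and `WorkerTaskLearn` are
// the same types used in the tests below):
//
//     let trainer = LoraTrainer::new(
//         LoraTrainerConfig::default().epochs(1),
//         Arc::new(InMemoryEpisodeStore::new()),
//     );
//     let trained = trainer.train(&WorkerTaskLearn::new(), None).await?;
//     println!("adapter at {}", trained.adapter_path.display());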

/// Converts a stored episode into a single SFT training sample.
fn episode_dto_to_training_data(
    dto: &EpisodeDto,
    learn_model_name: &str,
) -> Result<TrainingData, LoraTrainerError> {
    let system_prompt = format!(
        "You are an intelligent agent using the {} strategy. Your task is to make optimal decisions.",
        learn_model_name
    );

    let user_prompt = format!(
        "Episode ID: {}\nLearn Model: {}\nMetadata: {:?}",
        dto.id, dto.learn_model, dto.metadata
    );

    let response = match &dto.outcome {
        Outcome::Success { score } => {
            format!("Decision successful with score {:.2}", score)
        }
        Outcome::Failure { reason } => {
            format!("Decision failed: {}", reason)
        }
        Outcome::Timeout { partial_score } => match partial_score {
            Some(score) => format!("Timeout with partial score {:.2}", score),
            None => "Timeout without progress".to_string(),
        },
        Outcome::Unknown => "Outcome unknown".to_string(),
    };

    let training_data = TrainingData::sft(&system_prompt, &user_prompt, &response)
        .with_episode_id(dto.id.to_string())
        .with_model(learn_model_name);

    // Only successful episodes carry an outcome score.
    let training_data = if let Outcome::Success { score } = &dto.outcome {
        training_data.with_outcome_score(*score)
    } else {
        training_data
    };

    Ok(training_data)
}

/// Serializes a training sample as a single JSON line in conversation format.
fn training_data_to_json(td: &TrainingData) -> Result<String, LoraTrainerError> {
    let conversation = td.to_conversation();

    let turns: Vec<serde_json::Value> = conversation
        .conversations
        .iter()
        .map(|turn| {
            serde_json::json!({
                "role": match turn.role {
                    crate::learn::training::ConversationRole::System => "system",
                    crate::learn::training::ConversationRole::User => "user",
                    crate::learn::training::ConversationRole::Assistant => "assistant",
                },
                "content": turn.content,
            })
        })
        .collect();

    let json_value = serde_json::json!({
        "conversations": turns
    });

    serde_json::to_string(&json_value)
        .map_err(|e| LoraTrainerError::Other(format!("JSON serialization error: {}", e)))
}
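
// A serialized line looks like this (sketch; whitespace added for readability):
//
//   {"conversations": [
//       {"role": "system", "content": "..."},
//       {"role": "user", "content": "..."},
//       {"role": "assistant", "content": "..."}]}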

#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::store::InMemoryEpisodeStore;

    #[test]
    fn test_trainer_config_builder() {
        let config = LoraTrainerConfig::default()
            .base_model("test-model")
            .lora_rank(32)
            .lora_alpha(64.0)
            .epochs(5)
            .batch_size(8)
            .learning_rate(1e-4);

        assert_eq!(config.base_model, "test-model");
        assert_eq!(config.lora_rank, 32);
        assert_eq!(config.lora_alpha, 64.0);
        assert_eq!(config.epochs, 5);
        assert_eq!(config.batch_size, 8);
        assert!((config.learning_rate - 1e-4).abs() < 1e-10);
    }

    #[test]
    fn test_model_id() {
        let id1 = LoraModelId::new();
        let id2 = LoraModelId::new();

        // The embedded counter guarantees distinct ids even within one millisecond.
        assert_ne!(id1, id2);
        assert!(!id1.as_str().is_empty());
        assert!(!id2.as_str().is_empty());
    }
600
601 #[test]
602 fn test_trainer_creation() {
603 let config = LoraTrainerConfig::default();
604 let store = Arc::new(InMemoryEpisodeStore::new());
605 let trainer = LoraTrainer::new(config, store);
606
607 assert_eq!(trainer.config().base_model, "LiquidAI/LFM2.5-1.2B-Instruct");
608 assert_eq!(trainer.config().lora_rank, 16);
609 }
610
611 #[tokio::test]
612 async fn test_train_empty_store() {
613 use crate::learn::learn_model::WorkerTaskLearn;
614
615 let config = LoraTrainerConfig::default();
616 let store = Arc::new(InMemoryEpisodeStore::new());
617 let trainer = LoraTrainer::new(config, store);
618
619 let learn_model = WorkerTaskLearn::new();
620 let result = trainer.train(&learn_model, None).await;
621
622 assert!(result.is_err());
623 match result {
624 Err(LoraTrainerError::EmptyData(_)) => {}
625 _ => panic!("Expected EmptyData error"),
626 }
627 }
628
629 #[test]
630 fn test_episode_dto_to_training_data() {
631 use crate::learn::episode::EpisodeMetadata;
632
633 let dto = EpisodeDto {
634 id: EpisodeId::new(),
635 learn_model: "test".to_string(),
636 outcome: Outcome::success(0.95),
637 metadata: EpisodeMetadata::new(),
638 record_ids: vec![],
639 };
640
641 let td = episode_dto_to_training_data(&dto, "test-model").unwrap();
642 assert!(td.is_sft());
643 }
644
645 #[test]
646 fn test_training_data_to_json() {
647 let td = TrainingData::sft(
648 "You are a helpful assistant.",
649 "What is 2+2?",
650 "2+2 equals 4.",
651 );
652
653 let json = training_data_to_json(&td).unwrap();
654 assert!(json.contains("conversations"));
655 assert!(json.contains("system"));
656 assert!(json.contains("user"));
657 assert!(json.contains("assistant"));
658 }
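
    // A minimal sketch of the JSONL round trip, assuming the system temp
    // directory is writable; it exercises write_training_data directly.
    #[test]
    fn test_write_training_data_jsonl() {
        let mut config = LoraTrainerConfig::default();
        config.data_dir = std::env::temp_dir().join("lora-trainer-test-data");
        let store = Arc::new(InMemoryEpisodeStore::new());
        let trainer = LoraTrainer::new(config, store);

        let td = TrainingData::sft("system", "user", "assistant");
        let path = trainer.write_training_data(&[td], "test-model").unwrap();

        // One sample in, so the file should hold exactly one JSON line.
        let contents = std::fs::read_to_string(&path).unwrap();
        assert_eq!(contents.lines().count(), 1);
        assert!(contents.contains("conversations"));
    }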
}