swarm_engine_core/learn/lora/trainer.rs

//! LoRA Trainer - runs LoRA fine-tuning
//!
//! ## Overview
//!
//! Episode → TrainingData → LoRA training → TrainedModel
//!
//! ## Design
//!
//! Currently calls `lora/train.py` as a subprocess; a later migration to a
//! native Rust implementation remains possible.
//!
//! ## Usage
//!
//! ```ignore
//! use swarm_engine_core::learn::lora::{LoraTrainer, LoraTrainerConfig};
//! use swarm_engine_core::learn::EpisodeStore;
//!
//! let config = LoraTrainerConfig::default()
//!     .base_model("LiquidAI/LFM2.5-1.2B-Instruct")
//!     .lora_rank(16);
//!
//! let trainer = LoraTrainer::new(config, episode_store);
//! let model = trainer.train(&learn_model, None).await?;
//! ```

use std::io::Write as IoWrite;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;

use tokio::process::Command;

use crate::learn::episode::{EpisodeId, Outcome};
use crate::learn::learn_model::LearnModel;
use crate::learn::store::{EpisodeDto, EpisodeFilter, EpisodeStore, StoreError};
use crate::learn::training::TrainingData;
use crate::util::{epoch_millis, epoch_millis_for_ordering};

// ============================================================================
// LoraTrainerConfig
// ============================================================================

/// Configuration for the LoRA trainer
#[derive(Debug, Clone)]
pub struct LoraTrainerConfig {
    /// Base model (HuggingFace ID or path)
    pub base_model: String,
    /// LoRA rank
    pub lora_rank: u32,
    /// LoRA alpha
    pub lora_alpha: f32,
    /// Dropout rate
    pub lora_dropout: f32,
    /// Number of training epochs
    pub epochs: u32,
    /// Batch size
    pub batch_size: u32,
    /// Gradient accumulation steps
    pub gradient_accumulation: u32,
    /// Learning rate
    pub learning_rate: f32,
    /// Maximum sequence length
    pub max_seq_length: u32,
    /// Path to train.py
    pub train_script: PathBuf,
    /// Output directory (where adapters are saved)
    pub output_dir: PathBuf,
    /// Directory for temporary training-data files
    pub data_dir: PathBuf,
    /// Python executable path
    pub python_path: PathBuf,
}

impl Default for LoraTrainerConfig {
    fn default() -> Self {
        Self {
            base_model: "LiquidAI/LFM2.5-1.2B-Instruct".to_string(),
            lora_rank: 16,
            lora_alpha: 32.0,
            lora_dropout: 0.05,
            epochs: 3,
            batch_size: 4,
            gradient_accumulation: 4,
            learning_rate: 2e-4,
            max_seq_length: 2048,
            train_script: PathBuf::from("lora/train.py"),
            output_dir: PathBuf::from("lora/adapters"),
            data_dir: PathBuf::from("lora/data"),
            python_path: PathBuf::from("python3"),
        }
    }
}

impl LoraTrainerConfig {
    /// Set the base model
    pub fn base_model(mut self, model: impl Into<String>) -> Self {
        self.base_model = model.into();
        self
    }

    /// Set the LoRA rank
    pub fn lora_rank(mut self, rank: u32) -> Self {
        self.lora_rank = rank;
        self
    }

    /// Set the LoRA alpha
    pub fn lora_alpha(mut self, alpha: f32) -> Self {
        self.lora_alpha = alpha;
        self
    }

    /// Set the number of epochs
    pub fn epochs(mut self, epochs: u32) -> Self {
        self.epochs = epochs;
        self
    }

    /// Set the batch size
    pub fn batch_size(mut self, size: u32) -> Self {
        self.batch_size = size;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the path to train.py
    pub fn train_script(mut self, path: impl Into<PathBuf>) -> Self {
        self.train_script = path.into();
        self
    }

    /// Set the output directory
    pub fn output_dir(mut self, path: impl Into<PathBuf>) -> Self {
        self.output_dir = path.into();
        self
    }

    /// Set the Python executable path
    pub fn python_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.python_path = path.into();
        self
    }
}
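
// A minimal sketch of overriding the subprocess-related defaults via the
// builder; the paths below are illustrative assumptions, not values the
// crate ships with:
//
//     let config = LoraTrainerConfig::default()
//         .train_script("scripts/train.py")   // hypothetical location
//         .output_dir("/tmp/lora/adapters")   // hypothetical location
//         .python_path(".venv/bin/python");   // hypothetical virtualenv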

// ============================================================================
// TrainedModel
// ============================================================================

/// A trained model
#[derive(Debug, Clone)]
pub struct TrainedModel {
    /// Model ID
    pub id: LoraModelId,
    /// Base model
    pub base_model: String,
    /// Path to the adapter
    pub adapter_path: PathBuf,
    /// Name of the LearnModel that was used
    pub learn_model_name: String,
    /// IDs of the Episodes used for training
    pub episode_ids: Vec<EpisodeId>,
    /// Number of training samples
    pub sample_count: usize,
    /// Creation time (Unix timestamp, ms)
    pub created_at: u64,
    /// Training metrics
    pub metrics: Option<TrainingMetrics>,
}

/// LoRA model ID
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LoraModelId(String);

impl LoraModelId {
    /// Generate a new ID
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU32, Ordering};
        static COUNTER: AtomicU32 = AtomicU32::new(0);

        let ts = epoch_millis_for_ordering();
        let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
        Self(format!("lora-{}-{:08x}", ts, counter))
    }

    /// Parse from a string
    pub fn parse(s: &str) -> Self {
        Self(s.to_string())
    }

    /// Get as a string slice
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl Default for LoraModelId {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for LoraModelId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

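// `LoraModelId::new()` yields IDs of the form `lora-<millis>-<counter>`, e.g.
// the illustrative `lora-1718000000000-00000003`: a millisecond timestamp
// followed by a process-local atomic counter rendered as eight hex digits, so
// IDs created within the same millisecond still differ.
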
/// Training metrics
#[derive(Debug, Clone, Default)]
pub struct TrainingMetrics {
    /// Final loss
    pub final_loss: Option<f64>,
    /// Training time (seconds)
    pub training_time_secs: Option<u64>,
    /// GPU memory used (MB)
    pub gpu_memory_mb: Option<u64>,
}

// ============================================================================
// LoraTrainerError
// ============================================================================

/// LoRA trainer errors
#[derive(Debug)]
pub enum LoraTrainerError {
    /// Store error
    Store(StoreError),
    /// No data available
    EmptyData(String),
    /// IO error
    Io(std::io::Error),
    /// Script not found
    ScriptNotFound(PathBuf),
    /// Training process failed
    ProcessFailed { exit_code: i32, stderr: String },
    /// Other
    Other(String),
}

impl std::fmt::Display for LoraTrainerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Store(e) => write!(f, "Store error: {}", e),
            Self::EmptyData(msg) => write!(f, "Empty data: {}", msg),
            Self::Io(e) => write!(f, "IO error: {}", e),
            Self::ScriptNotFound(p) => write!(f, "Script not found: {}", p.display()),
            Self::ProcessFailed { exit_code, stderr } => {
                write!(f, "Training failed (exit {}): {}", exit_code, stderr)
            }
            Self::Other(msg) => write!(f, "{}", msg),
        }
    }
}

impl std::error::Error for LoraTrainerError {}

impl From<StoreError> for LoraTrainerError {
    fn from(e: StoreError) -> Self {
        Self::Store(e)
    }
}

impl From<std::io::Error> for LoraTrainerError {
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}

// ============================================================================
// LoraTrainer
// ============================================================================

/// Runs LoRA training
pub struct LoraTrainer {
    /// Configuration
    config: LoraTrainerConfig,
    /// Episode store
    episode_store: Arc<dyn EpisodeStore>,
}

impl LoraTrainer {
    /// Create a new LoraTrainer
    pub fn new(config: LoraTrainerConfig, episode_store: Arc<dyn EpisodeStore>) -> Self {
        Self {
            config,
            episode_store,
        }
    }

    /// Get the configuration
    pub fn config(&self) -> &LoraTrainerConfig {
        &self.config
    }

    /// Get the episode store
    pub fn episode_store(&self) -> &Arc<dyn EpisodeStore> {
        &self.episode_store
    }

    /// Run LoRA training (async)
    ///
    /// Executes `train.py` as a subprocess, so long-running training does not
    /// block the runtime.
    ///
    /// # Arguments
    /// * `learn_model` - LearnModel that performs the Episode → TrainingData conversion
    /// * `filter` - Episode filter (None selects all episodes)
    ///
    /// # Returns
    /// The trained model
    pub async fn train(
        &self,
        learn_model: &dyn LearnModel,
        filter: Option<EpisodeFilter>,
    ) -> Result<TrainedModel, LoraTrainerError> {
        let started_at = std::time::Instant::now();

        // 1. Fetch episodes
        tracing::info!(
            learn_model = learn_model.name(),
            "Fetching episodes for training"
        );
        let filter = filter.unwrap_or_default();
        let episodes = self.episode_store.query(&filter)?;

        if episodes.is_empty() {
            return Err(LoraTrainerError::EmptyData(
                "No episodes found for training".into(),
            ));
        }

        let episode_ids: Vec<_> = episodes.iter().map(|e| e.id.clone()).collect();
        tracing::info!(episode_count = episodes.len(), "Episodes fetched");

        // 2. Convert to TrainingData
        tracing::info!("Converting episodes to training data");
        let training_data: Vec<TrainingData> = episodes
            .iter()
            .filter_map(|ep| episode_dto_to_training_data(ep, learn_model.name()).ok())
            .collect();

        if training_data.is_empty() {
            return Err(LoraTrainerError::EmptyData(
                "No training data generated from episodes".into(),
            ));
        }

        let sample_count = training_data.len();
        tracing::info!(sample_count, "Training data prepared");

        // 3. Write the data to a file
        let data_path = self.write_training_data(&training_data, learn_model.name())?;
        tracing::info!(path = %data_path.display(), "Training data written");

        // 4. Run LoRA training (async)
        let timestamp = epoch_millis() / 1000; // in seconds
        let adapter_name = format!("{}-{}", learn_model.name(), timestamp);
        let adapter_path = self.run_lora_training(&data_path, &adapter_name).await?;

        let elapsed = started_at.elapsed();
        tracing::info!(
            elapsed_secs = elapsed.as_secs(),
            adapter = %adapter_path.display(),
            "Training completed"
        );

        // 5. Build the TrainedModel
        let model = TrainedModel {
            id: LoraModelId::new(),
            base_model: self.config.base_model.clone(),
            adapter_path,
            learn_model_name: learn_model.name().to_string(),
            episode_ids,
            sample_count,
            created_at: epoch_millis(),
            metrics: Some(TrainingMetrics {
                final_loss: None, // TODO: parse from train.py output
                training_time_secs: Some(elapsed.as_secs()),
                gpu_memory_mb: None,
            }),
        };

        Ok(model)
    }

    /// Write the training data to a file
    fn write_training_data(
        &self,
        data: &[TrainingData],
        learn_model_name: &str,
    ) -> Result<PathBuf, LoraTrainerError> {
        // Create the output directory
        std::fs::create_dir_all(&self.config.data_dir)?;

        let filename = format!("{}.jsonl", learn_model_name);
        let path = self.config.data_dir.join(filename);

        let mut file = std::fs::File::create(&path)?;

        for td in data {
            // Convert each TrainingData into the training format
            let json_str = training_data_to_json(td)?;
            writeln!(file, "{}", json_str)?;
        }

        Ok(path)
    }

    /// Run train.py (async)
    async fn run_lora_training(
        &self,
        data_path: &Path,
        adapter_name: &str,
    ) -> Result<PathBuf, LoraTrainerError> {
        // Check that the script exists
        if !self.config.train_script.exists() {
            return Err(LoraTrainerError::ScriptNotFound(
                self.config.train_script.clone(),
            ));
        }

        let output_path = self.config.output_dir.join(adapter_name);

        // Build the command
        let mut cmd = Command::new(&self.config.python_path);
        cmd.arg(&self.config.train_script)
            .arg("--data")
            .arg(data_path)
            .arg("--output")
            .arg(&output_path)
            .arg("--model")
            .arg(&self.config.base_model)
            .arg("--rank")
            .arg(self.config.lora_rank.to_string())
            .arg("--alpha")
            .arg(self.config.lora_alpha.to_string())
            .arg("--dropout")
            .arg(self.config.lora_dropout.to_string())
            .arg("--epochs")
            .arg(self.config.epochs.to_string())
            .arg("--batch-size")
            .arg(self.config.batch_size.to_string())
            .arg("--grad-accum")
            .arg(self.config.gradient_accumulation.to_string())
            .arg("--lr")
            .arg(self.config.learning_rate.to_string())
            .arg("--max-seq-length")
            .arg(self.config.max_seq_length.to_string())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        tracing::info!(
            script = %self.config.train_script.display(),
            data = %data_path.display(),
            output = %output_path.display(),
            "Starting LoRA training"
        );

        // Run asynchronously
        let output = cmd.output().await?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            return Err(LoraTrainerError::ProcessFailed {
                exit_code: output.status.code().unwrap_or(-1),
                stderr: stderr.to_string(),
            });
        }

        // Log stdout
        let stdout = String::from_utf8_lossy(&output.stdout);
        for line in stdout.lines() {
            tracing::debug!(line, "train.py output");
        }

        Ok(output_path)
    }
}
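
// With the default configuration, the subprocess launched above is equivalent
// to the following shell command (the JSONL filename comes from the LearnModel
// name, the adapter name from the LearnModel name plus a Unix timestamp):
//
//     python3 lora/train.py \
//         --data lora/data/<learn_model>.jsonl \
//         --output lora/adapters/<learn_model>-<timestamp> \
//         --model LiquidAI/LFM2.5-1.2B-Instruct \
//         --rank 16 --alpha 32 --dropout 0.05 \
//         --epochs 3 --batch-size 4 --grad-accum 4 \
//         --lr 0.0002 --max-seq-length 2048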

// ============================================================================
// Helpers
// ============================================================================

/// Convert an EpisodeDto into TrainingData
fn episode_dto_to_training_data(
    dto: &EpisodeDto,
    learn_model_name: &str,
) -> Result<TrainingData, LoraTrainerError> {
    // System prompt
    let system_prompt = format!(
        "You are an intelligent agent using the {} strategy. Your task is to make optimal decisions.",
        learn_model_name
    );

    // User prompt (includes the Episode metadata)
    let user_prompt = format!(
        "Episode ID: {}\nLearn Model: {}\nMetadata: {:?}",
        dto.id, dto.learn_model, dto.metadata
    );

    // Assistant response (based on the Outcome)
    let response = match &dto.outcome {
        Outcome::Success { score } => {
            format!("Decision successful with score {:.2}", score)
        }
        Outcome::Failure { reason } => {
            format!("Decision failed: {}", reason)
        }
        Outcome::Timeout { partial_score } => match partial_score {
            Some(score) => format!("Timeout with partial score {:.2}", score),
            None => "Timeout without progress".to_string(),
        },
        Outcome::Unknown => "Outcome unknown".to_string(),
    };

    // Build SFT-style TrainingData
    let training_data = TrainingData::sft(&system_prompt, &user_prompt, &response)
        .with_episode_id(dto.id.to_string())
        .with_model(learn_model_name);

    // If the Outcome is Success, attach the score
    let training_data = if let Outcome::Success { score } = &dto.outcome {
        training_data.with_outcome_score(*score)
    } else {
        training_data
    };

    Ok(training_data)
}
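
// For a successful episode this yields a three-turn SFT sample. With a
// hypothetical LearnModel named "worker-task" and a score of 0.95, the turns
// would read:
//
//     system:    "You are an intelligent agent using the worker-task strategy. ..."
//     user:      "Episode ID: <id>\nLearn Model: worker-task\nMetadata: <debug output>"
//     assistant: "Decision successful with score 0.95"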

/// Convert TrainingData to a JSON string (training format)
fn training_data_to_json(td: &TrainingData) -> Result<String, LoraTrainerError> {
    // Format expected by train.py: {"conversations": [...]}
    let conversation = td.to_conversation();

    let turns: Vec<serde_json::Value> = conversation
        .conversations
        .iter()
        .map(|turn| {
            serde_json::json!({
                "role": match turn.role {
                    crate::learn::training::ConversationRole::System => "system",
                    crate::learn::training::ConversationRole::User => "user",
                    crate::learn::training::ConversationRole::Assistant => "assistant",
                },
                "content": turn.content,
            })
        })
        .collect();

    let json_value = serde_json::json!({
        "conversations": turns
    });

    serde_json::to_string(&json_value)
        .map_err(|e| LoraTrainerError::Other(format!("JSON serialization error: {}", e)))
}
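
// Each line of the emitted .jsonl file is one compact JSON object; expanded
// here for readability, a line looks like:
//
//     {"conversations": [
//         {"role": "system", "content": "..."},
//         {"role": "user", "content": "..."},
//         {"role": "assistant", "content": "..."}
//     ]}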

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::store::InMemoryEpisodeStore;

    #[test]
    fn test_trainer_config_builder() {
        let config = LoraTrainerConfig::default()
            .base_model("test-model")
            .lora_rank(32)
            .lora_alpha(64.0)
            .epochs(5)
            .batch_size(8)
            .learning_rate(1e-4);

        assert_eq!(config.base_model, "test-model");
        assert_eq!(config.lora_rank, 32);
        assert_eq!(config.lora_alpha, 64.0);
        assert_eq!(config.epochs, 5);
        assert_eq!(config.batch_size, 8);
        assert!((config.learning_rate - 1e-4).abs() < 1e-10);
    }

    #[test]
    fn test_model_id() {
        let id1 = LoraModelId::new();
        let id2 = LoraModelId::new();

        // IDs should be unique: even if both are created within the same
        // millisecond, the atomic counter component differs.
        assert_ne!(id1, id2);
        assert!(!id1.as_str().is_empty());
        assert!(!id2.as_str().is_empty());
    }

    #[test]
    fn test_trainer_creation() {
        let config = LoraTrainerConfig::default();
        let store = Arc::new(InMemoryEpisodeStore::new());
        let trainer = LoraTrainer::new(config, store);

        assert_eq!(trainer.config().base_model, "LiquidAI/LFM2.5-1.2B-Instruct");
        assert_eq!(trainer.config().lora_rank, 16);
    }

    #[tokio::test]
    async fn test_train_empty_store() {
        use crate::learn::learn_model::WorkerTaskLearn;

        let config = LoraTrainerConfig::default();
        let store = Arc::new(InMemoryEpisodeStore::new());
        let trainer = LoraTrainer::new(config, store);

        let learn_model = WorkerTaskLearn::new();
        let result = trainer.train(&learn_model, None).await;

        assert!(result.is_err());
        match result {
            Err(LoraTrainerError::EmptyData(_)) => {}
            _ => panic!("Expected EmptyData error"),
        }
    }

    #[test]
    fn test_episode_dto_to_training_data() {
        use crate::learn::episode::EpisodeMetadata;

        let dto = EpisodeDto {
            id: EpisodeId::new(),
            learn_model: "test".to_string(),
            outcome: Outcome::success(0.95),
            metadata: EpisodeMetadata::new(),
            record_ids: vec![],
        };

        let td = episode_dto_to_training_data(&dto, "test-model").unwrap();
        assert!(td.is_sft());
    }

    #[test]
    fn test_training_data_to_json() {
        let td = TrainingData::sft(
            "You are a helpful assistant.",
            "What is 2+2?",
            "2+2 equals 4.",
        );

        let json = training_data_to_json(&td).unwrap();
        assert!(json.contains("conversations"));
        assert!(json.contains("system"));
        assert!(json.contains("user"));
        assert!(json.contains("assistant"));
    }
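
    // A hedged addition: check that an emitted line parses back as a single
    // JSON object whose "conversations" field is an array, i.e. valid JSONL.
    #[test]
    fn test_training_data_json_parses_back() {
        let td = TrainingData::sft("sys", "user", "assistant");
        let json = training_data_to_json(&td).unwrap();

        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(value["conversations"].is_array());
        // Compact serialization escapes newlines inside strings, so one
        // sample always stays on one JSONL line.
        assert!(!json.contains('\n'));
    }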
}