Skip to main content

swarm_engine_core/learn/
training.rs

//! TrainingData - data format for LoRA training
//!
//! ## Design
//!
//! - **Supports both DPO and SFT**: designed with future DPO support in mind
//! - **Carries metadata**: keeps tracking information such as model name, LoRA, and Episode ID
//! - **Builder pattern**: combines flexible construction with readability
//!
//! ## Usage
//!
//! ```rust
//! use swarm_engine_core::learn::{TrainingData, TrainingFormat};
//!
//! // SFT format (simple)
//! let sft = TrainingData::sft_simple("What action?", "CheckStatus");
//!
//! // SFT format (with a system prompt)
//! let sft_with_system = TrainingData::sft(
//!     "You are an agent.",
//!     "What should I do?",
//!     "CheckStatus"
//! );
//!
//! // Attaching metadata
//! let with_meta = TrainingData::sft_simple("prompt", "response")
//!     .with_episode_id("ep_001".to_string())
//!     .with_model("qwen2.5")
//!     .with_outcome_score(1.0);
//! ```
30
31use serde::{Deserialize, Serialize};
32
33// ============================================================================
34// TrainingFormat
35// ============================================================================
36
37/// 学習形式
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
39#[derive(Default)]
40pub enum TrainingFormat {
41    /// Supervised Fine-Tuning(chosen のみ使用)
42    #[default]
43    Sft,
44    /// Direct Preference Optimization(chosen + rejected)
45    Dpo,
46}
47
48
49// ============================================================================
50// TrainingMetadata
51// ============================================================================
52
/// Metadata attached to a training record.
///
/// Every field is optional: the derived `Default` leaves all `Option` fields
/// as `None` and `custom` as an empty map.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingMetadata {
    /// ID of the Episode this record originated from.
    pub episode_id: Option<String>,

    /// Outcome score (documented range: 0.0-1.0; the type does not enforce it).
    pub outcome_score: Option<f64>,

    /// Name of the model that was used.
    pub model: Option<String>,

    /// Name of the LoRA that was used.
    pub lora: Option<String>,

    /// Strategy name.
    pub strategy_name: Option<String>,

    /// Scenario name.
    pub scenario_name: Option<String>,

    /// Additional custom metadata.
    ///
    /// Left out of the serialized output when the map is empty, and filled
    /// with an empty map when the key is missing during deserialization.
    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub custom: std::collections::HashMap<String, String>,
}
78
79impl TrainingMetadata {
80    pub fn new() -> Self {
81        Self::default()
82    }
83}
84
85// ============================================================================
86// TrainingData
87// ============================================================================
88
/// Data format for LoRA training.
///
/// Designed to support both DPO and SFT.
/// The initial implementation starts with SFT only; a move to DPO is planned
/// once enough data has accumulated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingData {
    /// System prompt (optional). Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,

    /// Input prompt (user input).
    pub prompt: String,

    /// Chosen response (the successful case).
    pub chosen: String,

    /// Rejected response (the failed case, used for DPO).
    /// `None` for SFT. Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rejected: Option<String>,

    /// Training format.
    pub format: TrainingFormat,

    /// Metadata. Falls back to `TrainingMetadata::default()` when the key is
    /// missing during deserialization.
    #[serde(default)]
    pub metadata: TrainingMetadata,
}
117
118impl TrainingData {
119    // ========================================================================
120    // Constructors
121    // ========================================================================
122
123    /// SFT 形式のデータを作成(システムプロンプト付き)
124    ///
125    /// # Arguments
126    /// * `system` - システムプロンプト
127    /// * `prompt` - ユーザープロンプト
128    /// * `response` - アシスタントレスポンス(chosen)
129    pub fn sft(system: &str, prompt: &str, response: &str) -> Self {
130        Self {
131            system: Some(system.to_string()),
132            prompt: prompt.to_string(),
133            chosen: response.to_string(),
134            rejected: None,
135            format: TrainingFormat::Sft,
136            metadata: TrainingMetadata::default(),
137        }
138    }
139
140    /// SFT 形式のデータを作成(シンプル)
141    ///
142    /// # Arguments
143    /// * `prompt` - プロンプト
144    /// * `response` - レスポンス(chosen)
145    pub fn sft_simple(prompt: &str, response: &str) -> Self {
146        Self {
147            system: None,
148            prompt: prompt.to_string(),
149            chosen: response.to_string(),
150            rejected: None,
151            format: TrainingFormat::Sft,
152            metadata: TrainingMetadata::default(),
153        }
154    }
155
156    /// DPO 形式のデータを作成
157    ///
158    /// # Arguments
159    /// * `prompt` - プロンプト
160    /// * `chosen` - 選択された応答(成功ケース)
161    /// * `rejected` - 拒否された応答(失敗ケース)
162    pub fn dpo(prompt: &str, chosen: &str, rejected: &str) -> Self {
163        Self {
164            system: None,
165            prompt: prompt.to_string(),
166            chosen: chosen.to_string(),
167            rejected: Some(rejected.to_string()),
168            format: TrainingFormat::Dpo,
169            metadata: TrainingMetadata::default(),
170        }
171    }
172
173    /// DPO 形式のデータを作成(システムプロンプト付き)
174    pub fn dpo_with_system(system: &str, prompt: &str, chosen: &str, rejected: &str) -> Self {
175        Self {
176            system: Some(system.to_string()),
177            prompt: prompt.to_string(),
178            chosen: chosen.to_string(),
179            rejected: Some(rejected.to_string()),
180            format: TrainingFormat::Dpo,
181            metadata: TrainingMetadata::default(),
182        }
183    }
184
185    // ========================================================================
186    // Builder methods
187    // ========================================================================
188
189    /// Episode ID を設定
190    pub fn with_episode_id(mut self, episode_id: String) -> Self {
191        self.metadata.episode_id = Some(episode_id);
192        self
193    }
194
195    /// Outcome スコアを設定
196    pub fn with_outcome_score(mut self, score: f64) -> Self {
197        self.metadata.outcome_score = Some(score);
198        self
199    }
200
201    /// モデル名を設定
202    pub fn with_model(mut self, model: &str) -> Self {
203        self.metadata.model = Some(model.to_string());
204        self
205    }
206
207    /// LoRA 名を設定
208    pub fn with_lora(mut self, lora: Option<String>) -> Self {
209        self.metadata.lora = lora;
210        self
211    }
212
213    /// Strategy 名を設定
214    pub fn with_strategy(mut self, strategy: &str) -> Self {
215        self.metadata.strategy_name = Some(strategy.to_string());
216        self
217    }
218
219    /// シナリオ名を設定
220    pub fn with_scenario(mut self, scenario: &str) -> Self {
221        self.metadata.scenario_name = Some(scenario.to_string());
222        self
223    }
224
225    /// カスタムメタデータを追加
226    pub fn with_custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
227        self.metadata.custom.insert(key.into(), value.into());
228        self
229    }
230
231    // ========================================================================
232    // Accessors
233    // ========================================================================
234
235    /// SFT 形式かどうか
236    pub fn is_sft(&self) -> bool {
237        matches!(self.format, TrainingFormat::Sft)
238    }
239
240    /// DPO 形式かどうか
241    pub fn is_dpo(&self) -> bool {
242        matches!(self.format, TrainingFormat::Dpo)
243    }
244
245    /// 有効なデータかどうか(prompt と chosen が非空)
246    pub fn is_valid(&self) -> bool {
247        !self.prompt.is_empty() && !self.chosen.is_empty()
248    }
249
250    /// DPO として有効かどうか(rejected も必要)
251    pub fn is_valid_dpo(&self) -> bool {
252        self.is_valid()
253            && self
254                .rejected
255                .as_ref()
256                .map(|r| !r.is_empty())
257                .unwrap_or(false)
258    }
259}
260
261// ============================================================================
262// Conversation Format (for JSONL output)
263// ============================================================================
264
/// Conversation-style training data (for JSONL output).
///
/// Intended to follow the Hugging Face "conversations" format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationData {
    /// Turns of the conversation, in order.
    pub conversations: Vec<ConversationTurn>,

    /// Metadata (optional). Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<TrainingMetadata>,
}
277
/// A single turn of a conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationTurn {
    /// Role of the speaker.
    pub role: ConversationRole,

    /// Text of the utterance.
    pub content: String,
}
287
/// Role of the speaker in a conversation turn.
///
/// Variant names are serialized in lowercase (`rename_all = "lowercase"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ConversationRole {
    System,
    User,
    Assistant,
}
296
297impl From<&TrainingData> for ConversationData {
298    fn from(data: &TrainingData) -> Self {
299        let mut conversations = Vec::new();
300
301        // System prompt (optional)
302        if let Some(system) = &data.system {
303            conversations.push(ConversationTurn {
304                role: ConversationRole::System,
305                content: system.clone(),
306            });
307        }
308
309        // User prompt
310        conversations.push(ConversationTurn {
311            role: ConversationRole::User,
312            content: data.prompt.clone(),
313        });
314
315        // Assistant response (chosen)
316        conversations.push(ConversationTurn {
317            role: ConversationRole::Assistant,
318            content: data.chosen.clone(),
319        });
320
321        Self {
322            conversations,
323            metadata: Some(data.metadata.clone()),
324        }
325    }
326}
327
328impl TrainingData {
329    /// Conversation 形式に変換
330    pub fn to_conversation(&self) -> ConversationData {
331        ConversationData::from(self)
332    }
333}
334
335// ============================================================================
336// Tests
337// ============================================================================
338
// Unit tests covering the constructors, builder methods, conversation
// conversion and JSON (de)serialization of the types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // `sft_simple` stores prompt/response and leaves system/rejected unset.
    #[test]
    fn test_sft_simple() {
        let data = TrainingData::sft_simple("What action?", "CheckStatus");

        assert_eq!(data.prompt, "What action?");
        assert_eq!(data.chosen, "CheckStatus");
        assert!(data.system.is_none());
        assert!(data.rejected.is_none());
        assert!(data.is_sft());
        assert!(data.is_valid());
    }

    // `sft` additionally stores the system prompt.
    #[test]
    fn test_sft_with_system() {
        let data = TrainingData::sft("You are an agent.", "What to do?", "CheckStatus");

        assert_eq!(data.system, Some("You are an agent.".to_string()));
        assert_eq!(data.prompt, "What to do?");
        assert_eq!(data.chosen, "CheckStatus");
        assert!(data.is_sft());
    }

    // `dpo` stores both responses and yields a record that is valid for DPO.
    #[test]
    fn test_dpo() {
        let data = TrainingData::dpo("What action?", "CheckStatus", "InvalidAction");

        assert_eq!(data.chosen, "CheckStatus");
        assert_eq!(data.rejected, Some("InvalidAction".to_string()));
        assert!(data.is_dpo());
        assert!(data.is_valid_dpo());
    }

    // Every builder method writes to its corresponding metadata field.
    #[test]
    fn test_builder_methods() {
        let data = TrainingData::sft_simple("prompt", "response")
            .with_episode_id("ep_001".to_string())
            .with_outcome_score(0.85)
            .with_model("qwen2.5")
            .with_lora(Some("my_lora".to_string()))
            .with_strategy("worker_action")
            .with_scenario("troubleshooting")
            .with_custom("key", "value");

        assert_eq!(data.metadata.episode_id, Some("ep_001".to_string()));
        assert_eq!(data.metadata.outcome_score, Some(0.85));
        assert_eq!(data.metadata.model, Some("qwen2.5".to_string()));
        assert_eq!(data.metadata.lora, Some("my_lora".to_string()));
        assert_eq!(
            data.metadata.strategy_name,
            Some("worker_action".to_string())
        );
        assert_eq!(
            data.metadata.scenario_name,
            Some("troubleshooting".to_string())
        );
        assert_eq!(data.metadata.custom.get("key"), Some(&"value".to_string()));
    }

    // With a system prompt the conversation has three turns: system, user, assistant.
    #[test]
    fn test_to_conversation() {
        let data = TrainingData::sft("System prompt", "User prompt", "Assistant response");

        let conv = data.to_conversation();

        assert_eq!(conv.conversations.len(), 3);
        assert_eq!(conv.conversations[0].role, ConversationRole::System);
        assert_eq!(conv.conversations[0].content, "System prompt");
        assert_eq!(conv.conversations[1].role, ConversationRole::User);
        assert_eq!(conv.conversations[1].content, "User prompt");
        assert_eq!(conv.conversations[2].role, ConversationRole::Assistant);
        assert_eq!(conv.conversations[2].content, "Assistant response");
    }

    // Without a system prompt only the user and assistant turns are produced.
    #[test]
    fn test_to_conversation_no_system() {
        let data = TrainingData::sft_simple("prompt", "response");

        let conv = data.to_conversation();

        assert_eq!(conv.conversations.len(), 2);
        assert_eq!(conv.conversations[0].role, ConversationRole::User);
        assert_eq!(conv.conversations[1].role, ConversationRole::Assistant);
    }

    // A JSON round trip preserves prompt, chosen and the episode ID.
    #[test]
    fn test_serialization() {
        let data =
            TrainingData::sft_simple("prompt", "response").with_episode_id("ep_001".to_string());

        let json = serde_json::to_string(&data).unwrap();
        let deserialized: TrainingData = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.prompt, data.prompt);
        assert_eq!(deserialized.chosen, data.chosen);
        assert_eq!(deserialized.metadata.episode_id, data.metadata.episode_id);
    }

    #[test]
    fn test_conversation_serialization() {
        let data = TrainingData::sft("System", "User", "Assistant");
        let conv = data.to_conversation();

        let json = serde_json::to_string(&conv).unwrap();

        // Verify that the output uses the conversations format
        // (top-level "conversations" key, lowercase role names).
        assert!(json.contains("\"conversations\""));
        assert!(json.contains("\"role\""));
        assert!(json.contains("\"system\""));
        assert!(json.contains("\"user\""));
        assert!(json.contains("\"assistant\""));
    }
}