// swarm_engine_core/learn/training.rs

//! TrainingData - data format for LoRA training
//!
//! ## Design
//!
//! - **Both DPO and SFT**: designed with future DPO support in mind
//! - **Metadata attached**: keeps tracking information such as model name, LoRA, and episode ID
//! - **Builder pattern**: flexible construction without sacrificing readability
//!
//! ## Example
//!
//! ```rust
//! use swarm_engine_core::learn::{TrainingData, TrainingFormat};
//!
//! // SFT format (simple)
//! let sft = TrainingData::sft_simple("What action?", "CheckStatus");
//!
//! // SFT format (with a system prompt)
//! let sft_with_system = TrainingData::sft(
//!     "You are an agent.",
//!     "What should I do?",
//!     "CheckStatus"
//! );
//!
//! // Attach metadata
//! let with_meta = TrainingData::sft_simple("prompt", "response")
//!     .with_episode_id("ep_001".to_string())
//!     .with_model("qwen2.5")
//!     .with_outcome_score(1.0);
//! ```

31use serde::{Deserialize, Serialize};
32
33// ============================================================================
34// TrainingFormat
35// ============================================================================
36
37/// 学習形式
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
39pub enum TrainingFormat {
40    /// Supervised Fine-Tuning(chosen のみ使用)
41    #[default]
42    Sft,
43    /// Direct Preference Optimization(chosen + rejected)
44    Dpo,
45}
46
47// ============================================================================
48// TrainingMetadata
49// ============================================================================
50
51/// 学習データのメタデータ
52#[derive(Debug, Clone, Default, Serialize, Deserialize)]
53pub struct TrainingMetadata {
54    /// 元の Episode ID
55    pub episode_id: Option<String>,
56
57    /// Outcome スコア(0.0-1.0)
58    pub outcome_score: Option<f64>,
59
60    /// 使用モデル名
61    pub model: Option<String>,
62
63    /// 使用 LoRA 名
64    pub lora: Option<String>,
65
66    /// Strategy 名
67    pub strategy_name: Option<String>,
68
69    /// シナリオ名
70    pub scenario_name: Option<String>,
71
72    /// 追加のカスタムメタデータ
73    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
74    pub custom: std::collections::HashMap<String, String>,
75}
76
77impl TrainingMetadata {
78    pub fn new() -> Self {
79        Self::default()
80    }
81}
82
83// ============================================================================
84// TrainingData
85// ============================================================================
86
87/// LoRA 学習用データ形式
88///
89/// DPO/SFT 両対応を想定した設計。
90/// 初期実装は SFT のみで開始し、データが蓄積されたら DPO に移行予定。
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct TrainingData {
93    /// システムプロンプト(オプション)
94    #[serde(default, skip_serializing_if = "Option::is_none")]
95    pub system: Option<String>,
96
97    /// 入力プロンプト(ユーザー入力)
98    pub prompt: String,
99
100    /// 選択された応答(成功ケース)
101    pub chosen: String,
102
103    /// 拒否された応答(失敗ケース、DPO 用)
104    /// SFT の場合は None
105    #[serde(default, skip_serializing_if = "Option::is_none")]
106    pub rejected: Option<String>,
107
108    /// 学習形式
109    pub format: TrainingFormat,
110
111    /// メタデータ
112    #[serde(default)]
113    pub metadata: TrainingMetadata,
114}
115
116impl TrainingData {
117    // ========================================================================
118    // Constructors
119    // ========================================================================
120
121    /// SFT 形式のデータを作成(システムプロンプト付き)
122    ///
123    /// # Arguments
124    /// * `system` - システムプロンプト
125    /// * `prompt` - ユーザープロンプト
126    /// * `response` - アシスタントレスポンス(chosen)
127    pub fn sft(system: &str, prompt: &str, response: &str) -> Self {
128        Self {
129            system: Some(system.to_string()),
130            prompt: prompt.to_string(),
131            chosen: response.to_string(),
132            rejected: None,
133            format: TrainingFormat::Sft,
134            metadata: TrainingMetadata::default(),
135        }
136    }
137
138    /// SFT 形式のデータを作成(シンプル)
139    ///
140    /// # Arguments
141    /// * `prompt` - プロンプト
142    /// * `response` - レスポンス(chosen)
143    pub fn sft_simple(prompt: &str, response: &str) -> Self {
144        Self {
145            system: None,
146            prompt: prompt.to_string(),
147            chosen: response.to_string(),
148            rejected: None,
149            format: TrainingFormat::Sft,
150            metadata: TrainingMetadata::default(),
151        }
152    }
153
154    /// DPO 形式のデータを作成
155    ///
156    /// # Arguments
157    /// * `prompt` - プロンプト
158    /// * `chosen` - 選択された応答(成功ケース)
159    /// * `rejected` - 拒否された応答(失敗ケース)
160    pub fn dpo(prompt: &str, chosen: &str, rejected: &str) -> Self {
161        Self {
162            system: None,
163            prompt: prompt.to_string(),
164            chosen: chosen.to_string(),
165            rejected: Some(rejected.to_string()),
166            format: TrainingFormat::Dpo,
167            metadata: TrainingMetadata::default(),
168        }
169    }
170
171    /// DPO 形式のデータを作成(システムプロンプト付き)
172    pub fn dpo_with_system(system: &str, prompt: &str, chosen: &str, rejected: &str) -> Self {
173        Self {
174            system: Some(system.to_string()),
175            prompt: prompt.to_string(),
176            chosen: chosen.to_string(),
177            rejected: Some(rejected.to_string()),
178            format: TrainingFormat::Dpo,
179            metadata: TrainingMetadata::default(),
180        }
181    }
182
183    // ========================================================================
184    // Builder methods
185    // ========================================================================
186
187    /// Episode ID を設定
188    pub fn with_episode_id(mut self, episode_id: String) -> Self {
189        self.metadata.episode_id = Some(episode_id);
190        self
191    }
192
193    /// Outcome スコアを設定
194    pub fn with_outcome_score(mut self, score: f64) -> Self {
195        self.metadata.outcome_score = Some(score);
196        self
197    }
198
199    /// モデル名を設定
200    pub fn with_model(mut self, model: &str) -> Self {
201        self.metadata.model = Some(model.to_string());
202        self
203    }
204
205    /// LoRA 名を設定
206    pub fn with_lora(mut self, lora: Option<String>) -> Self {
207        self.metadata.lora = lora;
208        self
209    }
210
211    /// Strategy 名を設定
212    pub fn with_strategy(mut self, strategy: &str) -> Self {
213        self.metadata.strategy_name = Some(strategy.to_string());
214        self
215    }
216
217    /// シナリオ名を設定
218    pub fn with_scenario(mut self, scenario: &str) -> Self {
219        self.metadata.scenario_name = Some(scenario.to_string());
220        self
221    }
222
223    /// カスタムメタデータを追加
224    pub fn with_custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
225        self.metadata.custom.insert(key.into(), value.into());
226        self
227    }
228
229    // ========================================================================
230    // Accessors
231    // ========================================================================
232
233    /// SFT 形式かどうか
234    pub fn is_sft(&self) -> bool {
235        matches!(self.format, TrainingFormat::Sft)
236    }
237
238    /// DPO 形式かどうか
239    pub fn is_dpo(&self) -> bool {
240        matches!(self.format, TrainingFormat::Dpo)
241    }
242
243    /// 有効なデータかどうか(prompt と chosen が非空)
244    pub fn is_valid(&self) -> bool {
245        !self.prompt.is_empty() && !self.chosen.is_empty()
246    }
247
248    /// DPO として有効かどうか(rejected も必要)
249    pub fn is_valid_dpo(&self) -> bool {
250        self.is_valid()
251            && self
252                .rejected
253                .as_ref()
254                .map(|r| !r.is_empty())
255                .unwrap_or(false)
256    }
257}
258
259// ============================================================================
260// Conversation Format (for JSONL output)
261// ============================================================================
262
263/// 会話形式の学習データ(JSONL 出力用)
264///
265/// Hugging Face の conversations 形式に準拠。
266#[derive(Debug, Clone, Serialize, Deserialize)]
267pub struct ConversationData {
268    /// 会話のターン
269    pub conversations: Vec<ConversationTurn>,
270
271    /// メタデータ(オプション)
272    #[serde(default, skip_serializing_if = "Option::is_none")]
273    pub metadata: Option<TrainingMetadata>,
274}
275
276/// 会話のターン
277#[derive(Debug, Clone, Serialize, Deserialize)]
278pub struct ConversationTurn {
279    /// 発話者の役割
280    pub role: ConversationRole,
281
282    /// 発話内容
283    pub content: String,
284}
285
286/// 発話者の役割
287#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
288#[serde(rename_all = "lowercase")]
289pub enum ConversationRole {
290    System,
291    User,
292    Assistant,
293}
294
295impl From<&TrainingData> for ConversationData {
296    fn from(data: &TrainingData) -> Self {
297        let mut conversations = Vec::new();
298
299        // System prompt (optional)
300        if let Some(system) = &data.system {
301            conversations.push(ConversationTurn {
302                role: ConversationRole::System,
303                content: system.clone(),
304            });
305        }
306
307        // User prompt
308        conversations.push(ConversationTurn {
309            role: ConversationRole::User,
310            content: data.prompt.clone(),
311        });
312
313        // Assistant response (chosen)
314        conversations.push(ConversationTurn {
315            role: ConversationRole::Assistant,
316            content: data.chosen.clone(),
317        });
318
319        Self {
320            conversations,
321            metadata: Some(data.metadata.clone()),
322        }
323    }
324}
325
326impl TrainingData {
327    /// Conversation 形式に変換
328    pub fn to_conversation(&self) -> ConversationData {
329        ConversationData::from(self)
330    }
331}
332
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sft_simple() {
        let data = TrainingData::sft_simple("What action?", "CheckStatus");

        assert_eq!(data.prompt, "What action?");
        assert_eq!(data.chosen, "CheckStatus");
        assert!(data.system.is_none());
        assert!(data.rejected.is_none());
        assert!(data.is_sft());
        assert!(data.is_valid());
    }

    #[test]
    fn test_sft_with_system() {
        let data = TrainingData::sft("You are an agent.", "What to do?", "CheckStatus");

        assert_eq!(data.system, Some("You are an agent.".to_string()));
        assert_eq!(data.prompt, "What to do?");
        assert_eq!(data.chosen, "CheckStatus");
        assert!(data.is_sft());
    }

    #[test]
    fn test_dpo() {
        let data = TrainingData::dpo("What action?", "CheckStatus", "InvalidAction");

        assert_eq!(data.chosen, "CheckStatus");
        assert_eq!(data.rejected, Some("InvalidAction".to_string()));
        assert!(data.is_dpo());
        assert!(data.is_valid_dpo());
    }

    #[test]
    fn test_dpo_with_system() {
        let data = TrainingData::dpo_with_system(
            "You are an agent.",
            "What action?",
            "CheckStatus",
            "InvalidAction",
        );

        assert_eq!(data.system, Some("You are an agent.".to_string()));
        assert_eq!(data.prompt, "What action?");
        assert_eq!(data.chosen, "CheckStatus");
        assert_eq!(data.rejected, Some("InvalidAction".to_string()));
        assert!(data.is_dpo());
        assert!(!data.is_sft());
        assert!(data.is_valid_dpo());
    }

    #[test]
    fn test_validity_edge_cases() {
        assert!(!TrainingData::sft_simple("", "response").is_valid());
        assert!(!TrainingData::sft_simple("prompt", "").is_valid());
        // SFT records carry no rejected response, so they are never valid DPO pairs.
        assert!(!TrainingData::sft_simple("prompt", "response").is_valid_dpo());
        // An empty rejected response does not count.
        assert!(!TrainingData::dpo("prompt", "chosen", "").is_valid_dpo());
        // The base prompt/chosen validity is required as well.
        assert!(!TrainingData::dpo("", "chosen", "rejected").is_valid_dpo());
    }

    #[test]
    fn test_builder_methods() {
        let data = TrainingData::sft_simple("prompt", "response")
            .with_episode_id("ep_001".to_string())
            .with_outcome_score(0.85)
            .with_model("qwen2.5")
            .with_lora(Some("my_lora".to_string()))
            .with_strategy("worker_action")
            .with_scenario("troubleshooting")
            .with_custom("key", "value");

        assert_eq!(data.metadata.episode_id, Some("ep_001".to_string()));
        assert_eq!(data.metadata.outcome_score, Some(0.85));
        assert_eq!(data.metadata.model, Some("qwen2.5".to_string()));
        assert_eq!(data.metadata.lora, Some("my_lora".to_string()));
        assert_eq!(
            data.metadata.strategy_name,
            Some("worker_action".to_string())
        );
        assert_eq!(
            data.metadata.scenario_name,
            Some("troubleshooting".to_string())
        );
        assert_eq!(data.metadata.custom.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_to_conversation() {
        let data = TrainingData::sft("System prompt", "User prompt", "Assistant response");

        let conv = data.to_conversation();

        assert_eq!(conv.conversations.len(), 3);
        assert_eq!(conv.conversations[0].role, ConversationRole::System);
        assert_eq!(conv.conversations[0].content, "System prompt");
        assert_eq!(conv.conversations[1].role, ConversationRole::User);
        assert_eq!(conv.conversations[1].content, "User prompt");
        assert_eq!(conv.conversations[2].role, ConversationRole::Assistant);
        assert_eq!(conv.conversations[2].content, "Assistant response");
    }

    #[test]
    fn test_to_conversation_no_system() {
        let data = TrainingData::sft_simple("prompt", "response");

        let conv = data.to_conversation();

        assert_eq!(conv.conversations.len(), 2);
        assert_eq!(conv.conversations[0].role, ConversationRole::User);
        assert_eq!(conv.conversations[1].role, ConversationRole::Assistant);
    }

    #[test]
    fn test_to_conversation_drops_rejected() {
        let data = TrainingData::dpo("prompt", "chosen", "rejected");

        let conv = data.to_conversation();

        // Only the chosen response becomes the assistant turn.
        assert_eq!(conv.conversations.len(), 2);
        assert_eq!(conv.conversations[1].content, "chosen");
        assert!(conv.conversations.iter().all(|t| t.content != "rejected"));
    }

    #[test]
    fn test_serialization() {
        let data =
            TrainingData::sft_simple("prompt", "response").with_episode_id("ep_001".to_string());

        let json = serde_json::to_string(&data).unwrap();
        let deserialized: TrainingData = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.prompt, data.prompt);
        assert_eq!(deserialized.chosen, data.chosen);
        assert_eq!(deserialized.format, data.format);
        assert_eq!(deserialized.metadata.episode_id, data.metadata.episode_id);
    }

    #[test]
    fn test_serialization_omits_none_optionals() {
        let data = TrainingData::sft_simple("prompt", "response");

        let value = serde_json::to_value(&data).unwrap();

        assert!(value.get("system").is_none());
        assert!(value.get("rejected").is_none());
        assert_eq!(value["format"], "Sft");
    }

    #[test]
    fn test_deserialization_without_metadata() {
        let json = r#"{"prompt":"p","chosen":"c","rejected":"r","format":"Dpo"}"#;

        let data: TrainingData = serde_json::from_str(json).unwrap();

        assert!(data.is_dpo());
        assert!(data.is_valid_dpo());
        assert!(data.system.is_none());
        assert!(data.metadata.episode_id.is_none());
        assert!(data.metadata.custom.is_empty());
    }

    #[test]
    fn test_conversation_serialization() {
        let data = TrainingData::sft("System", "User", "Assistant");
        let conv = data.to_conversation();

        let json = serde_json::to_string(&conv).unwrap();

        // The output must use the `conversations` layout with lowercase roles.
        assert!(json.contains("\"conversations\""));
        assert!(json.contains("\"role\""));
        assert!(json.contains("\"system\""));
        assert!(json.contains("\"user\""));
        assert!(json.contains("\"assistant\""));
    }
}
448        assert!(json.contains("\"system\""));
449        assert!(json.contains("\"user\""));
450        assert!(json.contains("\"assistant\""));
451    }
452}