Skip to main content

swarm_engine_core/learn/learn_model/
dpo.rs

1//! DPO (Direct Preference Optimization) LearnModel
2//!
3//! group_id でグループ化された Episode を比較し、DPO 学習用データを生成する。
4
5use std::collections::HashMap;
6
7use super::super::episode::{Episode, EpisodeContext, Outcome};
8use super::super::record::Record;
9use super::super::training::TrainingData;
10use super::{LearnError, LearnModel};
11use crate::types::GroupId;
12
13/// DPO 学習用の比較ペア
14///
15/// 同じ group_id 内の成功/失敗 Episode から生成される。
16#[derive(Debug, Clone)]
17pub struct DpoPair {
18    /// 成功した Episode
19    pub chosen: Episode,
20    /// 失敗した Episode
21    pub rejected: Episode,
22    /// 共通の group_id
23    pub group_id: GroupId,
24    /// 品質差(chosen.score - rejected.score)
25    pub quality_gap: f64,
26}
27
28impl DpoPair {
29    /// 新しい DpoPair を作成
30    pub fn new(chosen: Episode, rejected: Episode, group_id: GroupId) -> Self {
31        let chosen_score = chosen.outcome.score();
32        let rejected_score = rejected.outcome.score();
33        let quality_gap = chosen_score - rejected_score;
34
35        Self {
36            chosen,
37            rejected,
38            group_id,
39            quality_gap,
40        }
41    }
42}
43
44/// DPO LearnModel の設定
45#[derive(Debug, Clone)]
46pub struct DpoConfig {
47    /// 最小品質差(この差未満のペアは除外)
48    pub min_quality_gap: f64,
49    /// 最大ペア数(None なら無制限)
50    pub max_pairs: Option<usize>,
51    /// 同じエピソードの重複使用を許可
52    pub allow_reuse: bool,
53}
54
55impl Default for DpoConfig {
56    fn default() -> Self {
57        Self {
58            min_quality_gap: 0.1, // 10% 以上の差
59            max_pairs: None,
60            allow_reuse: true,
61        }
62    }
63}
64
65/// 汎用 DPO LearnModel
66///
67/// group_id でグループ化された Episode を比較し、DPO 学習用データを生成する。
68///
69/// ## 設計思想
70///
71/// DPO 学習では「同じ条件で複数回実行した結果を比較」する。
72/// - group_id: 同じ条件での実行グループ(Eval -n 5 で 5 回実行など)
73/// - 成功 Episode と失敗 Episode をペアにして比較
74///
75/// ## 使用方法
76///
77/// ```ignore
78/// // Eval で group_id 付きの Episode を収集
79/// let episodes: Vec<Episode> = ...;
80///
81/// // DPO ペアを生成
82/// let dpo_learn = DpoLearnModel::new();
83/// let pairs = dpo_learn.build_pairs(&episodes);
84///
85/// // TrainingData に変換
86/// let training_data: Vec<TrainingData> = pairs
87///     .iter()
88///     .filter_map(|pair| dpo_learn.convert_pair(pair).ok())
89///     .collect();
90/// ```
91pub struct DpoLearnModel<F>
92where
93    F: Fn(&Episode) -> Option<(String, String)> + Send + Sync,
94{
95    /// システムプロンプト
96    system_prompt: String,
97    /// 設定
98    config: DpoConfig,
99    /// Episode から (prompt, response) を抽出する関数
100    extractor: F,
101}
102
103impl<F> DpoLearnModel<F>
104where
105    F: Fn(&Episode) -> Option<(String, String)> + Send + Sync,
106{
107    /// 新しい DpoLearnModel を作成
108    pub fn new(extractor: F) -> Self {
109        Self {
110            system_prompt: String::new(),
111            config: DpoConfig::default(),
112            extractor,
113        }
114    }
115
116    /// システムプロンプトを設定
117    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
118        self.system_prompt = prompt.into();
119        self
120    }
121
122    /// 設定を適用
123    pub fn with_config(mut self, config: DpoConfig) -> Self {
124        self.config = config;
125        self
126    }
127
128    /// 最小品質差を設定
129    pub fn with_min_quality_gap(mut self, gap: f64) -> Self {
130        self.config.min_quality_gap = gap;
131        self
132    }
133
134    /// 最大ペア数を設定
135    pub fn with_max_pairs(mut self, max: usize) -> Self {
136        self.config.max_pairs = Some(max);
137        self
138    }
139
140    /// group_id でグループ化された Episode から DPO ペアを生成
141    pub fn build_pairs(&self, episodes: &[Episode]) -> Vec<DpoPair> {
142        // group_id でグループ化
143        let mut by_group: HashMap<GroupId, Vec<&Episode>> = HashMap::new();
144        for ep in episodes {
145            if let Some(gid) = ep.group_id {
146                by_group.entry(gid).or_default().push(ep);
147            }
148        }
149
150        let mut pairs = Vec::new();
151
152        for (group_id, group_episodes) in by_group {
153            // 成功/失敗で分類
154            let (successes, failures): (Vec<_>, Vec<_>) = group_episodes
155                .into_iter()
156                .partition(|ep| ep.outcome.is_success());
157
158            if successes.is_empty() || failures.is_empty() {
159                continue;
160            }
161
162            // スコアでソート(高い順)
163            let mut sorted_successes: Vec<_> = successes;
164            sorted_successes.sort_by(|a, b| {
165                let a_score = a.outcome.score();
166                let b_score = b.outcome.score();
167                b_score
168                    .partial_cmp(&a_score)
169                    .unwrap_or(std::cmp::Ordering::Equal)
170            });
171
172            // スコアでソート(低い順)
173            let mut sorted_failures: Vec<_> = failures;
174            sorted_failures.sort_by(|a, b| {
175                let a_score = a.outcome.score();
176                let b_score = b.outcome.score();
177                a_score
178                    .partial_cmp(&b_score)
179                    .unwrap_or(std::cmp::Ordering::Equal)
180            });
181
182            // ペア作成
183            for success_ep in &sorted_successes {
184                for failure_ep in &sorted_failures {
185                    let chosen_score = success_ep.outcome.score();
186                    let rejected_score = failure_ep.outcome.score();
187                    let gap = chosen_score - rejected_score;
188
189                    if gap < self.config.min_quality_gap {
190                        continue;
191                    }
192
193                    let pair = DpoPair::new((*success_ep).clone(), (*failure_ep).clone(), group_id);
194                    pairs.push(pair);
195
196                    if !self.config.allow_reuse {
197                        break;
198                    }
199                }
200
201                if !self.config.allow_reuse {
202                    break;
203                }
204            }
205        }
206
207        // 品質差でソート(大きい順)
208        pairs.sort_by(|a, b| {
209            b.quality_gap
210                .partial_cmp(&a.quality_gap)
211                .unwrap_or(std::cmp::Ordering::Equal)
212        });
213
214        // 最大数で制限
215        if let Some(max) = self.config.max_pairs {
216            pairs.truncate(max);
217        }
218
219        pairs
220    }
221
222    /// DPO ペアを TrainingData に変換
223    pub fn convert_pair(&self, pair: &DpoPair) -> Result<TrainingData, LearnError> {
224        let (chosen_prompt, chosen_response) = (self.extractor)(&pair.chosen)
225            .ok_or_else(|| LearnError::MissingData("chosen prompt/response".into()))?;
226
227        let (rejected_prompt, rejected_response) = (self.extractor)(&pair.rejected)
228            .ok_or_else(|| LearnError::MissingData("rejected prompt/response".into()))?;
229
230        // prompt が一致することを確認(正規化後)
231        if chosen_prompt != rejected_prompt {
232            return Err(LearnError::InvalidEpisode(format!(
233                "Prompt mismatch: '{}' vs '{}'",
234                chosen_prompt, rejected_prompt
235            )));
236        }
237
238        let training = if self.system_prompt.is_empty() {
239            TrainingData::dpo(&chosen_prompt, &chosen_response, &rejected_response)
240        } else {
241            TrainingData::dpo_with_system(
242                &self.system_prompt,
243                &chosen_prompt,
244                &chosen_response,
245                &rejected_response,
246            )
247        };
248
249        Ok(training
250            .with_episode_id(pair.chosen.id.to_string())
251            .with_custom("rejected_episode_id", pair.rejected.id.to_string())
252            .with_custom("quality_gap", pair.quality_gap.to_string())
253            .with_custom("group_id", pair.group_id.0.to_string()))
254    }
255
256    /// 複数のペアを一括変換
257    pub fn convert_pairs(&self, pairs: &[DpoPair]) -> Vec<TrainingData> {
258        pairs
259            .iter()
260            .filter_map(|pair| self.convert_pair(pair).ok())
261            .collect()
262    }
263}
264
265/// LearnModel trait の実装(Record ベースの Episode 構築用)
266///
267/// DPO は通常、既存の Episode を比較するため、build_episodes は空を返す。
268/// 実際の DPO ペア生成は build_pairs メソッドを使用。
269impl<F> LearnModel for DpoLearnModel<F>
270where
271    F: Fn(&Episode) -> Option<(String, String)> + Send + Sync,
272{
273    fn name(&self) -> &str {
274        "dpo"
275    }
276
277    fn objective(&self) -> &str {
278        "Learn preferences from success/failure Episode pairs within the same group"
279    }
280
281    fn build_episodes(&self, _records: &[Record]) -> Vec<Episode> {
282        // DPO は既存の Episode を比較するため、Record から Episode は生成しない
283        vec![]
284    }
285
286    fn evaluate(&self, _context: &EpisodeContext) -> Outcome {
287        // DpoLearnModel は複数 Episode を group_id でグルーピングし、
288        // 成功/失敗のペアを比較して学習する。
289        // 個々の Episode を evaluate() するのは設計として不適切。
290        //
291        // DPO のフロー:
292        //   1. Eval 実行時に Episode が生成される(Outcome は Eval 側で設定)
293        //   2. build_pairs() で group_id ごとにグルーピング
294        //   3. 成功/失敗 Episode のペアから TrainingData を生成
295        //
296        // この evaluate() が呼ばれるのは実装ミス。
297        panic!(
298            "DpoLearnModel::evaluate() should not be called.\n\
299             DPO learning compares multiple Episodes by group_id, not individual Episode evaluation.\n\
300             Use build_pairs() to generate training pairs from Episodes."
301        );
302    }
303
304    fn convert(&self, _episode: &Episode) -> Result<TrainingData, LearnError> {
305        // 単一の Episode からは DPO TrainingData は生成できない
306        // convert_pair を使用すること
307        Err(LearnError::InvalidEpisode(
308            "DPO requires pairs, use convert_pair instead".into(),
309        ))
310    }
311}
312
313#[cfg(test)]
314mod tests {
315    use super::*;
316    use crate::learn::episode::EpisodeBuilder;
317    use crate::learn::record::ActionRecord;
318    use crate::types::TaskId;
319
320    fn create_test_episode(
321        task_id: TaskId,
322        group_id: GroupId,
323        success: bool,
324        score: f64,
325    ) -> Episode {
326        let outcome = if success {
327            Outcome::success(score)
328        } else {
329            Outcome::failure("test failure")
330        };
331
332        EpisodeBuilder::default()
333            .learn_model("test")
334            .task_id(task_id)
335            .group_id(group_id)
336            .record(ActionRecord::new(1, 0, "TestAction").success(success))
337            .outcome(outcome)
338            .build()
339    }
340
341    fn test_extractor(ep: &Episode) -> Option<(String, String)> {
342        // テスト用: 固定の prompt/response を返す
343        Some((
344            "test prompt".to_string(),
345            format!("response for {:?}", ep.id),
346        ))
347    }
348
349    #[test]
350    fn test_build_pairs_basic() {
351        let group_id = GroupId::new();
352        let task1 = TaskId::new();
353        let task2 = TaskId::new();
354
355        let episodes = vec![
356            create_test_episode(task1, group_id, true, 0.9),
357            create_test_episode(task2, group_id, false, 0.0),
358        ];
359
360        let dpo = DpoLearnModel::new(test_extractor);
361        let pairs = dpo.build_pairs(&episodes);
362
363        assert_eq!(pairs.len(), 1);
364        assert!(pairs[0].quality_gap > 0.0);
365    }
366
367    #[test]
368    fn test_build_pairs_different_groups() {
369        let group1 = GroupId::new();
370        let group2 = GroupId::new();
371
372        let episodes = vec![
373            create_test_episode(TaskId::new(), group1, true, 0.9),
374            create_test_episode(TaskId::new(), group2, false, 0.0),
375        ];
376
377        let dpo = DpoLearnModel::new(test_extractor);
378        let pairs = dpo.build_pairs(&episodes);
379
380        // 異なる group_id なのでペアにならない
381        assert!(pairs.is_empty());
382    }
383
384    #[test]
385    fn test_min_quality_gap() {
386        let group_id = GroupId::new();
387
388        let episodes = vec![
389            create_test_episode(TaskId::new(), group_id, true, 0.6),
390            create_test_episode(TaskId::new(), group_id, false, 0.0),
391        ];
392
393        // 0.5 以上の差を要求
394        let dpo = DpoLearnModel::new(test_extractor).with_min_quality_gap(0.5);
395        let pairs = dpo.build_pairs(&episodes);
396
397        // 0.6 - 0.0 = 0.6 なのでペアになる
398        assert_eq!(pairs.len(), 1);
399
400        // 0.7 以上の差を要求
401        let dpo = DpoLearnModel::new(test_extractor).with_min_quality_gap(0.7);
402        let pairs = dpo.build_pairs(&episodes);
403
404        // 差が足りないのでペアにならない
405        assert!(pairs.is_empty());
406    }
407}