swarm_engine_core/learn/learn_model/
dpo.rs1use std::collections::HashMap;
6
7use super::super::episode::{Episode, EpisodeContext, Outcome};
8use super::super::record::Record;
9use super::super::training::TrainingData;
10use super::{LearnError, LearnModel};
11use crate::types::GroupId;
12
13#[derive(Debug, Clone)]
17pub struct DpoPair {
18 pub chosen: Episode,
20 pub rejected: Episode,
22 pub group_id: GroupId,
24 pub quality_gap: f64,
26}
27
28impl DpoPair {
29 pub fn new(chosen: Episode, rejected: Episode, group_id: GroupId) -> Self {
31 let chosen_score = chosen.outcome.score();
32 let rejected_score = rejected.outcome.score();
33 let quality_gap = chosen_score - rejected_score;
34
35 Self {
36 chosen,
37 rejected,
38 group_id,
39 quality_gap,
40 }
41 }
42}
43
/// Tunables controlling how DPO pairs are selected.
#[derive(Debug, Clone)]
pub struct DpoConfig {
    /// Smallest score difference (chosen minus rejected) a pair must have
    /// to be kept.
    pub min_quality_gap: f64,
    /// Upper bound on the number of pairs returned; `None` means no limit.
    pub max_pairs: Option<usize>,
    /// Whether pair generation may keep going after the first accepted pair
    /// of a group.
    pub allow_reuse: bool,
}

impl Default for DpoConfig {
    /// Defaults: a minimum gap of `0.1`, no pair limit, reuse enabled.
    fn default() -> Self {
        let min_quality_gap = 0.1;
        let max_pairs = None;
        let allow_reuse = true;

        Self {
            min_quality_gap,
            max_pairs,
            allow_reuse,
        }
    }
}
64
65pub struct DpoLearnModel<F>
92where
93 F: Fn(&Episode) -> Option<(String, String)> + Send + Sync,
94{
95 system_prompt: String,
97 config: DpoConfig,
99 extractor: F,
101}
102
103impl<F> DpoLearnModel<F>
104where
105 F: Fn(&Episode) -> Option<(String, String)> + Send + Sync,
106{
107 pub fn new(extractor: F) -> Self {
109 Self {
110 system_prompt: String::new(),
111 config: DpoConfig::default(),
112 extractor,
113 }
114 }
115
116 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
118 self.system_prompt = prompt.into();
119 self
120 }
121
122 pub fn with_config(mut self, config: DpoConfig) -> Self {
124 self.config = config;
125 self
126 }
127
128 pub fn with_min_quality_gap(mut self, gap: f64) -> Self {
130 self.config.min_quality_gap = gap;
131 self
132 }
133
134 pub fn with_max_pairs(mut self, max: usize) -> Self {
136 self.config.max_pairs = Some(max);
137 self
138 }
139
140 pub fn build_pairs(&self, episodes: &[Episode]) -> Vec<DpoPair> {
142 let mut by_group: HashMap<GroupId, Vec<&Episode>> = HashMap::new();
144 for ep in episodes {
145 if let Some(gid) = ep.group_id {
146 by_group.entry(gid).or_default().push(ep);
147 }
148 }
149
150 let mut pairs = Vec::new();
151
152 for (group_id, group_episodes) in by_group {
153 let (successes, failures): (Vec<_>, Vec<_>) = group_episodes
155 .into_iter()
156 .partition(|ep| ep.outcome.is_success());
157
158 if successes.is_empty() || failures.is_empty() {
159 continue;
160 }
161
162 let mut sorted_successes: Vec<_> = successes;
164 sorted_successes.sort_by(|a, b| {
165 let a_score = a.outcome.score();
166 let b_score = b.outcome.score();
167 b_score
168 .partial_cmp(&a_score)
169 .unwrap_or(std::cmp::Ordering::Equal)
170 });
171
172 let mut sorted_failures: Vec<_> = failures;
174 sorted_failures.sort_by(|a, b| {
175 let a_score = a.outcome.score();
176 let b_score = b.outcome.score();
177 a_score
178 .partial_cmp(&b_score)
179 .unwrap_or(std::cmp::Ordering::Equal)
180 });
181
182 for success_ep in &sorted_successes {
184 for failure_ep in &sorted_failures {
185 let chosen_score = success_ep.outcome.score();
186 let rejected_score = failure_ep.outcome.score();
187 let gap = chosen_score - rejected_score;
188
189 if gap < self.config.min_quality_gap {
190 continue;
191 }
192
193 let pair = DpoPair::new((*success_ep).clone(), (*failure_ep).clone(), group_id);
194 pairs.push(pair);
195
196 if !self.config.allow_reuse {
197 break;
198 }
199 }
200
201 if !self.config.allow_reuse {
202 break;
203 }
204 }
205 }
206
207 pairs.sort_by(|a, b| {
209 b.quality_gap
210 .partial_cmp(&a.quality_gap)
211 .unwrap_or(std::cmp::Ordering::Equal)
212 });
213
214 if let Some(max) = self.config.max_pairs {
216 pairs.truncate(max);
217 }
218
219 pairs
220 }
221
222 pub fn convert_pair(&self, pair: &DpoPair) -> Result<TrainingData, LearnError> {
224 let (chosen_prompt, chosen_response) = (self.extractor)(&pair.chosen)
225 .ok_or_else(|| LearnError::MissingData("chosen prompt/response".into()))?;
226
227 let (rejected_prompt, rejected_response) = (self.extractor)(&pair.rejected)
228 .ok_or_else(|| LearnError::MissingData("rejected prompt/response".into()))?;
229
230 if chosen_prompt != rejected_prompt {
232 return Err(LearnError::InvalidEpisode(format!(
233 "Prompt mismatch: '{}' vs '{}'",
234 chosen_prompt, rejected_prompt
235 )));
236 }
237
238 let training = if self.system_prompt.is_empty() {
239 TrainingData::dpo(&chosen_prompt, &chosen_response, &rejected_response)
240 } else {
241 TrainingData::dpo_with_system(
242 &self.system_prompt,
243 &chosen_prompt,
244 &chosen_response,
245 &rejected_response,
246 )
247 };
248
249 Ok(training
250 .with_episode_id(pair.chosen.id.to_string())
251 .with_custom("rejected_episode_id", pair.rejected.id.to_string())
252 .with_custom("quality_gap", pair.quality_gap.to_string())
253 .with_custom("group_id", pair.group_id.0.to_string()))
254 }
255
256 pub fn convert_pairs(&self, pairs: &[DpoPair]) -> Vec<TrainingData> {
258 pairs
259 .iter()
260 .filter_map(|pair| self.convert_pair(pair).ok())
261 .collect()
262 }
263}
264
265impl<F> LearnModel for DpoLearnModel<F>
270where
271 F: Fn(&Episode) -> Option<(String, String)> + Send + Sync,
272{
273 fn name(&self) -> &str {
274 "dpo"
275 }
276
277 fn objective(&self) -> &str {
278 "Learn preferences from success/failure Episode pairs within the same group"
279 }
280
281 fn build_episodes(&self, _records: &[Record]) -> Vec<Episode> {
282 vec![]
284 }
285
286 fn evaluate(&self, _context: &EpisodeContext) -> Outcome {
287 panic!(
298 "DpoLearnModel::evaluate() should not be called.\n\
299 DPO learning compares multiple Episodes by group_id, not individual Episode evaluation.\n\
300 Use build_pairs() to generate training pairs from Episodes."
301 );
302 }
303
304 fn convert(&self, _episode: &Episode) -> Result<TrainingData, LearnError> {
305 Err(LearnError::InvalidEpisode(
308 "DPO requires pairs, use convert_pair instead".into(),
309 ))
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::episode::EpisodeBuilder;
    use crate::learn::record::ActionRecord;
    use crate::types::TaskId;

    /// Builds a single-action episode in `group_id`.
    ///
    /// `score` is only used for successful episodes; failed ones are created
    /// via `Outcome::failure` and ignore it.
    fn create_test_episode(
        task_id: TaskId,
        group_id: GroupId,
        success: bool,
        score: f64,
    ) -> Episode {
        let outcome = match success {
            true => Outcome::success(score),
            false => Outcome::failure("test failure"),
        };
        let action = ActionRecord::new(1, 0, "TestAction").success(success);

        EpisodeBuilder::default()
            .learn_model("test")
            .task_id(task_id)
            .group_id(group_id)
            .record(action)
            .outcome(outcome)
            .build()
    }

    /// Extractor that gives every episode the same prompt and a response
    /// derived from its id.
    fn test_extractor(ep: &Episode) -> Option<(String, String)> {
        let prompt = "test prompt".to_string();
        let response = format!("response for {:?}", ep.id);
        Some((prompt, response))
    }

    /// One success and one failure in the same group give exactly one pair.
    #[test]
    fn test_build_pairs_basic() {
        let group_id = GroupId::new();
        let task1 = TaskId::new();
        let task2 = TaskId::new();

        let episodes = vec![
            create_test_episode(task1, group_id, true, 0.9),
            create_test_episode(task2, group_id, false, 0.0),
        ];

        let pairs = DpoLearnModel::new(test_extractor).build_pairs(&episodes);

        assert_eq!(pairs.len(), 1);
        assert!(pairs[0].quality_gap > 0.0);
    }

    /// Episodes from different groups are never paired with each other.
    #[test]
    fn test_build_pairs_different_groups() {
        let group1 = GroupId::new();
        let group2 = GroupId::new();

        let episodes = vec![
            create_test_episode(TaskId::new(), group1, true, 0.9),
            create_test_episode(TaskId::new(), group2, false, 0.0),
        ];

        let pairs = DpoLearnModel::new(test_extractor).build_pairs(&episodes);

        assert!(pairs.is_empty());
    }

    /// The minimum-gap threshold decides whether a pair is kept.
    #[test]
    fn test_min_quality_gap() {
        let group_id = GroupId::new();

        let episodes = vec![
            create_test_episode(TaskId::new(), group_id, true, 0.6),
            create_test_episode(TaskId::new(), group_id, false, 0.0),
        ];

        // Threshold 0.5: the pair is kept.
        let lenient = DpoLearnModel::new(test_extractor).with_min_quality_gap(0.5);
        let pairs = lenient.build_pairs(&episodes);
        assert_eq!(pairs.len(), 1);

        // Threshold 0.7: the pair is filtered out.
        let strict = DpoLearnModel::new(test_extractor).with_min_quality_gap(0.7);
        let pairs = strict.build_pairs(&episodes);
        assert!(pairs.is_empty());
    }
}