1use serde::{Deserialize, Serialize};
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
39#[derive(Default)]
40pub enum TrainingFormat {
41 #[default]
43 Sft,
44 Dpo,
46}
47
48
49#[derive(Debug, Clone, Default, Serialize, Deserialize)]
55pub struct TrainingMetadata {
56 pub episode_id: Option<String>,
58
59 pub outcome_score: Option<f64>,
61
62 pub model: Option<String>,
64
65 pub lora: Option<String>,
67
68 pub strategy_name: Option<String>,
70
71 pub scenario_name: Option<String>,
73
74 #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
76 pub custom: std::collections::HashMap<String, String>,
77}
78
79impl TrainingMetadata {
80 pub fn new() -> Self {
81 Self::default()
82 }
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct TrainingData {
95 #[serde(default, skip_serializing_if = "Option::is_none")]
97 pub system: Option<String>,
98
99 pub prompt: String,
101
102 pub chosen: String,
104
105 #[serde(default, skip_serializing_if = "Option::is_none")]
108 pub rejected: Option<String>,
109
110 pub format: TrainingFormat,
112
113 #[serde(default)]
115 pub metadata: TrainingMetadata,
116}
117
118impl TrainingData {
119 pub fn sft(system: &str, prompt: &str, response: &str) -> Self {
130 Self {
131 system: Some(system.to_string()),
132 prompt: prompt.to_string(),
133 chosen: response.to_string(),
134 rejected: None,
135 format: TrainingFormat::Sft,
136 metadata: TrainingMetadata::default(),
137 }
138 }
139
140 pub fn sft_simple(prompt: &str, response: &str) -> Self {
146 Self {
147 system: None,
148 prompt: prompt.to_string(),
149 chosen: response.to_string(),
150 rejected: None,
151 format: TrainingFormat::Sft,
152 metadata: TrainingMetadata::default(),
153 }
154 }
155
156 pub fn dpo(prompt: &str, chosen: &str, rejected: &str) -> Self {
163 Self {
164 system: None,
165 prompt: prompt.to_string(),
166 chosen: chosen.to_string(),
167 rejected: Some(rejected.to_string()),
168 format: TrainingFormat::Dpo,
169 metadata: TrainingMetadata::default(),
170 }
171 }
172
173 pub fn dpo_with_system(system: &str, prompt: &str, chosen: &str, rejected: &str) -> Self {
175 Self {
176 system: Some(system.to_string()),
177 prompt: prompt.to_string(),
178 chosen: chosen.to_string(),
179 rejected: Some(rejected.to_string()),
180 format: TrainingFormat::Dpo,
181 metadata: TrainingMetadata::default(),
182 }
183 }
184
185 pub fn with_episode_id(mut self, episode_id: String) -> Self {
191 self.metadata.episode_id = Some(episode_id);
192 self
193 }
194
195 pub fn with_outcome_score(mut self, score: f64) -> Self {
197 self.metadata.outcome_score = Some(score);
198 self
199 }
200
201 pub fn with_model(mut self, model: &str) -> Self {
203 self.metadata.model = Some(model.to_string());
204 self
205 }
206
207 pub fn with_lora(mut self, lora: Option<String>) -> Self {
209 self.metadata.lora = lora;
210 self
211 }
212
213 pub fn with_strategy(mut self, strategy: &str) -> Self {
215 self.metadata.strategy_name = Some(strategy.to_string());
216 self
217 }
218
219 pub fn with_scenario(mut self, scenario: &str) -> Self {
221 self.metadata.scenario_name = Some(scenario.to_string());
222 self
223 }
224
225 pub fn with_custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
227 self.metadata.custom.insert(key.into(), value.into());
228 self
229 }
230
231 pub fn is_sft(&self) -> bool {
237 matches!(self.format, TrainingFormat::Sft)
238 }
239
240 pub fn is_dpo(&self) -> bool {
242 matches!(self.format, TrainingFormat::Dpo)
243 }
244
245 pub fn is_valid(&self) -> bool {
247 !self.prompt.is_empty() && !self.chosen.is_empty()
248 }
249
250 pub fn is_valid_dpo(&self) -> bool {
252 self.is_valid()
253 && self
254 .rejected
255 .as_ref()
256 .map(|r| !r.is_empty())
257 .unwrap_or(false)
258 }
259}
260
261#[derive(Debug, Clone, Serialize, Deserialize)]
269pub struct ConversationData {
270 pub conversations: Vec<ConversationTurn>,
272
273 #[serde(default, skip_serializing_if = "Option::is_none")]
275 pub metadata: Option<TrainingMetadata>,
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct ConversationTurn {
281 pub role: ConversationRole,
283
284 pub content: String,
286}
287
288#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
290#[serde(rename_all = "lowercase")]
291pub enum ConversationRole {
292 System,
293 User,
294 Assistant,
295}
296
297impl From<&TrainingData> for ConversationData {
298 fn from(data: &TrainingData) -> Self {
299 let mut conversations = Vec::new();
300
301 if let Some(system) = &data.system {
303 conversations.push(ConversationTurn {
304 role: ConversationRole::System,
305 content: system.clone(),
306 });
307 }
308
309 conversations.push(ConversationTurn {
311 role: ConversationRole::User,
312 content: data.prompt.clone(),
313 });
314
315 conversations.push(ConversationTurn {
317 role: ConversationRole::Assistant,
318 content: data.chosen.clone(),
319 });
320
321 Self {
322 conversations,
323 metadata: Some(data.metadata.clone()),
324 }
325 }
326}
327
328impl TrainingData {
329 pub fn to_conversation(&self) -> ConversationData {
331 ConversationData::from(self)
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// `sft_simple` keeps prompt/response and leaves `system` and `rejected` unset.
    #[test]
    fn test_sft_simple() {
        let sample = TrainingData::sft_simple("What action?", "CheckStatus");

        assert_eq!(sample.prompt, "What action?");
        assert_eq!(sample.chosen, "CheckStatus");
        assert!(sample.system.is_none());
        assert!(sample.rejected.is_none());
        assert!(sample.is_sft());
        assert!(sample.is_valid());
    }

    /// `sft` additionally stores the system prompt.
    #[test]
    fn test_sft_with_system() {
        let sample = TrainingData::sft("You are an agent.", "What to do?", "CheckStatus");

        assert_eq!(sample.system, Some("You are an agent.".to_string()));
        assert_eq!(sample.prompt, "What to do?");
        assert_eq!(sample.chosen, "CheckStatus");
        assert!(sample.is_sft());
    }

    /// `dpo` stores both responses and yields a valid DPO record.
    #[test]
    fn test_dpo() {
        let pair = TrainingData::dpo("What action?", "CheckStatus", "InvalidAction");

        assert_eq!(pair.chosen, "CheckStatus");
        assert_eq!(pair.rejected, Some("InvalidAction".to_string()));
        assert!(pair.is_dpo());
        assert!(pair.is_valid_dpo());
    }

    /// Every `with_*` builder lands in the matching metadata field.
    #[test]
    fn test_builder_methods() {
        let built = TrainingData::sft_simple("prompt", "response")
            .with_episode_id("ep_001".to_string())
            .with_outcome_score(0.85)
            .with_model("qwen2.5")
            .with_lora(Some("my_lora".to_string()))
            .with_strategy("worker_action")
            .with_scenario("troubleshooting")
            .with_custom("key", "value");

        let meta = &built.metadata;
        assert_eq!(meta.episode_id, Some("ep_001".to_string()));
        assert_eq!(meta.outcome_score, Some(0.85));
        assert_eq!(meta.model, Some("qwen2.5".to_string()));
        assert_eq!(meta.lora, Some("my_lora".to_string()));
        assert_eq!(meta.strategy_name, Some("worker_action".to_string()));
        assert_eq!(meta.scenario_name, Some("troubleshooting".to_string()));
        assert_eq!(meta.custom.get("key"), Some(&"value".to_string()));
    }

    /// With a system prompt the conversation has three turns, in order.
    #[test]
    fn test_to_conversation() {
        let sample = TrainingData::sft("System prompt", "User prompt", "Assistant response");

        let conv = sample.to_conversation();

        let turns: Vec<(ConversationRole, &str)> = conv
            .conversations
            .iter()
            .map(|t| (t.role, t.content.as_str()))
            .collect();
        assert_eq!(
            turns,
            vec![
                (ConversationRole::System, "System prompt"),
                (ConversationRole::User, "User prompt"),
                (ConversationRole::Assistant, "Assistant response"),
            ]
        );
    }

    /// Without a system prompt only the user and assistant turns are produced.
    #[test]
    fn test_to_conversation_no_system() {
        let sample = TrainingData::sft_simple("prompt", "response");

        let conv = sample.to_conversation();

        let roles: Vec<ConversationRole> = conv.conversations.iter().map(|t| t.role).collect();
        assert_eq!(
            roles,
            vec![ConversationRole::User, ConversationRole::Assistant]
        );
    }

    /// A record survives a JSON round trip.
    #[test]
    fn test_serialization() {
        let original =
            TrainingData::sft_simple("prompt", "response").with_episode_id("ep_001".to_string());

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: TrainingData = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.prompt, original.prompt);
        assert_eq!(decoded.chosen, original.chosen);
        assert_eq!(decoded.metadata.episode_id, original.metadata.episode_id);
    }

    /// The conversation JSON exposes the expected keys and lowercase role names.
    #[test]
    fn test_conversation_serialization() {
        let conv = TrainingData::sft("System", "User", "Assistant").to_conversation();

        let encoded = serde_json::to_string(&conv).unwrap();

        assert!(encoded.contains("\"conversations\""));
        assert!(encoded.contains("\"role\""));
        assert!(encoded.contains("\"system\""));
        assert!(encoded.contains("\"user\""));
        assert!(encoded.contains("\"assistant\""));
    }
}