1use serde::{Deserialize, Serialize};
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
39pub enum TrainingFormat {
40 #[default]
42 Sft,
43 Dpo,
45}
46
47#[derive(Debug, Clone, Default, Serialize, Deserialize)]
53pub struct TrainingMetadata {
54 pub episode_id: Option<String>,
56
57 pub outcome_score: Option<f64>,
59
60 pub model: Option<String>,
62
63 pub lora: Option<String>,
65
66 pub strategy_name: Option<String>,
68
69 pub scenario_name: Option<String>,
71
72 #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
74 pub custom: std::collections::HashMap<String, String>,
75}
76
77impl TrainingMetadata {
78 pub fn new() -> Self {
79 Self::default()
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct TrainingData {
93 #[serde(default, skip_serializing_if = "Option::is_none")]
95 pub system: Option<String>,
96
97 pub prompt: String,
99
100 pub chosen: String,
102
103 #[serde(default, skip_serializing_if = "Option::is_none")]
106 pub rejected: Option<String>,
107
108 pub format: TrainingFormat,
110
111 #[serde(default)]
113 pub metadata: TrainingMetadata,
114}
115
116impl TrainingData {
117 pub fn sft(system: &str, prompt: &str, response: &str) -> Self {
128 Self {
129 system: Some(system.to_string()),
130 prompt: prompt.to_string(),
131 chosen: response.to_string(),
132 rejected: None,
133 format: TrainingFormat::Sft,
134 metadata: TrainingMetadata::default(),
135 }
136 }
137
138 pub fn sft_simple(prompt: &str, response: &str) -> Self {
144 Self {
145 system: None,
146 prompt: prompt.to_string(),
147 chosen: response.to_string(),
148 rejected: None,
149 format: TrainingFormat::Sft,
150 metadata: TrainingMetadata::default(),
151 }
152 }
153
154 pub fn dpo(prompt: &str, chosen: &str, rejected: &str) -> Self {
161 Self {
162 system: None,
163 prompt: prompt.to_string(),
164 chosen: chosen.to_string(),
165 rejected: Some(rejected.to_string()),
166 format: TrainingFormat::Dpo,
167 metadata: TrainingMetadata::default(),
168 }
169 }
170
171 pub fn dpo_with_system(system: &str, prompt: &str, chosen: &str, rejected: &str) -> Self {
173 Self {
174 system: Some(system.to_string()),
175 prompt: prompt.to_string(),
176 chosen: chosen.to_string(),
177 rejected: Some(rejected.to_string()),
178 format: TrainingFormat::Dpo,
179 metadata: TrainingMetadata::default(),
180 }
181 }
182
183 pub fn with_episode_id(mut self, episode_id: String) -> Self {
189 self.metadata.episode_id = Some(episode_id);
190 self
191 }
192
193 pub fn with_outcome_score(mut self, score: f64) -> Self {
195 self.metadata.outcome_score = Some(score);
196 self
197 }
198
199 pub fn with_model(mut self, model: &str) -> Self {
201 self.metadata.model = Some(model.to_string());
202 self
203 }
204
205 pub fn with_lora(mut self, lora: Option<String>) -> Self {
207 self.metadata.lora = lora;
208 self
209 }
210
211 pub fn with_strategy(mut self, strategy: &str) -> Self {
213 self.metadata.strategy_name = Some(strategy.to_string());
214 self
215 }
216
217 pub fn with_scenario(mut self, scenario: &str) -> Self {
219 self.metadata.scenario_name = Some(scenario.to_string());
220 self
221 }
222
223 pub fn with_custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
225 self.metadata.custom.insert(key.into(), value.into());
226 self
227 }
228
229 pub fn is_sft(&self) -> bool {
235 matches!(self.format, TrainingFormat::Sft)
236 }
237
238 pub fn is_dpo(&self) -> bool {
240 matches!(self.format, TrainingFormat::Dpo)
241 }
242
243 pub fn is_valid(&self) -> bool {
245 !self.prompt.is_empty() && !self.chosen.is_empty()
246 }
247
248 pub fn is_valid_dpo(&self) -> bool {
250 self.is_valid()
251 && self
252 .rejected
253 .as_ref()
254 .map(|r| !r.is_empty())
255 .unwrap_or(false)
256 }
257}
258
259#[derive(Debug, Clone, Serialize, Deserialize)]
267pub struct ConversationData {
268 pub conversations: Vec<ConversationTurn>,
270
271 #[serde(default, skip_serializing_if = "Option::is_none")]
273 pub metadata: Option<TrainingMetadata>,
274}
275
276#[derive(Debug, Clone, Serialize, Deserialize)]
278pub struct ConversationTurn {
279 pub role: ConversationRole,
281
282 pub content: String,
284}
285
286#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
288#[serde(rename_all = "lowercase")]
289pub enum ConversationRole {
290 System,
291 User,
292 Assistant,
293}
294
295impl From<&TrainingData> for ConversationData {
296 fn from(data: &TrainingData) -> Self {
297 let mut conversations = Vec::new();
298
299 if let Some(system) = &data.system {
301 conversations.push(ConversationTurn {
302 role: ConversationRole::System,
303 content: system.clone(),
304 });
305 }
306
307 conversations.push(ConversationTurn {
309 role: ConversationRole::User,
310 content: data.prompt.clone(),
311 });
312
313 conversations.push(ConversationTurn {
315 role: ConversationRole::Assistant,
316 content: data.chosen.clone(),
317 });
318
319 Self {
320 conversations,
321 metadata: Some(data.metadata.clone()),
322 }
323 }
324}
325
326impl TrainingData {
327 pub fn to_conversation(&self) -> ConversationData {
329 ConversationData::from(self)
330 }
331}
332
// Unit tests for constructors, builders, validation, conversation conversion and
// serde round-trips. JSON checks use `serde_json`, presumably a dev-dependency.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sft_simple() {
        let data = TrainingData::sft_simple("What action?", "CheckStatus");

        assert_eq!(data.prompt, "What action?");
        assert_eq!(data.chosen, "CheckStatus");
        assert!(data.system.is_none());
        assert!(data.rejected.is_none());
        assert!(data.is_sft());
        assert!(data.is_valid());
    }

    #[test]
    fn test_sft_with_system() {
        let data = TrainingData::sft("You are an agent.", "What to do?", "CheckStatus");

        assert_eq!(data.system, Some("You are an agent.".to_string()));
        assert_eq!(data.prompt, "What to do?");
        assert_eq!(data.chosen, "CheckStatus");
        assert!(data.is_sft());
    }

    #[test]
    fn test_dpo() {
        let data = TrainingData::dpo("What action?", "CheckStatus", "InvalidAction");

        assert_eq!(data.chosen, "CheckStatus");
        assert_eq!(data.rejected, Some("InvalidAction".to_string()));
        assert!(data.is_dpo());
        assert!(data.is_valid_dpo());
    }

    #[test]
    fn test_dpo_with_system() {
        let data = TrainingData::dpo_with_system("sys", "prompt", "good", "bad");

        assert_eq!(data.system.as_deref(), Some("sys"));
        assert_eq!(data.chosen, "good");
        assert_eq!(data.rejected.as_deref(), Some("bad"));
        assert!(data.is_dpo());
        assert!(!data.is_sft());
        assert!(data.is_valid_dpo());
    }

    #[test]
    fn test_validation_rejects_empty_fields() {
        assert!(!TrainingData::sft_simple("", "response").is_valid());
        assert!(!TrainingData::sft_simple("prompt", "").is_valid());
        // SFT samples carry no rejected response, so they never form a DPO pair.
        assert!(!TrainingData::sft_simple("prompt", "response").is_valid_dpo());
        assert!(!TrainingData::dpo("prompt", "chosen", "").is_valid_dpo());
        assert!(!TrainingData::dpo("", "chosen", "rejected").is_valid_dpo());
    }

    #[test]
    fn test_builder_methods() {
        let data = TrainingData::sft_simple("prompt", "response")
            .with_episode_id("ep_001".to_string())
            .with_outcome_score(0.85)
            .with_model("qwen2.5")
            .with_lora(Some("my_lora".to_string()))
            .with_strategy("worker_action")
            .with_scenario("troubleshooting")
            .with_custom("key", "value");

        assert_eq!(data.metadata.episode_id, Some("ep_001".to_string()));
        assert_eq!(data.metadata.outcome_score, Some(0.85));
        assert_eq!(data.metadata.model, Some("qwen2.5".to_string()));
        assert_eq!(data.metadata.lora, Some("my_lora".to_string()));
        assert_eq!(
            data.metadata.strategy_name,
            Some("worker_action".to_string())
        );
        assert_eq!(
            data.metadata.scenario_name,
            Some("troubleshooting".to_string())
        );
        assert_eq!(data.metadata.custom.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_to_conversation() {
        let data = TrainingData::sft("System prompt", "User prompt", "Assistant response");

        let conv = data.to_conversation();

        assert_eq!(conv.conversations.len(), 3);
        assert_eq!(conv.conversations[0].role, ConversationRole::System);
        assert_eq!(conv.conversations[0].content, "System prompt");
        assert_eq!(conv.conversations[1].role, ConversationRole::User);
        assert_eq!(conv.conversations[1].content, "User prompt");
        assert_eq!(conv.conversations[2].role, ConversationRole::Assistant);
        assert_eq!(conv.conversations[2].content, "Assistant response");
    }

    #[test]
    fn test_to_conversation_no_system() {
        let data = TrainingData::sft_simple("prompt", "response");

        let conv = data.to_conversation();

        assert_eq!(conv.conversations.len(), 2);
        assert_eq!(conv.conversations[0].role, ConversationRole::User);
        assert_eq!(conv.conversations[1].role, ConversationRole::Assistant);
    }

    #[test]
    fn test_dpo_conversation_keeps_only_chosen() {
        let conv = TrainingData::dpo("prompt", "good", "bad").to_conversation();

        assert_eq!(conv.conversations.len(), 2);
        assert_eq!(conv.conversations[0].content, "prompt");
        assert_eq!(conv.conversations[1].content, "good");
        assert!(conv.conversations.iter().all(|turn| turn.content != "bad"));
    }

    #[test]
    fn test_conversation_carries_metadata() {
        let conv = TrainingData::sft_simple("prompt", "response")
            .with_episode_id("ep_002".to_string())
            .to_conversation();

        let metadata = conv
            .metadata
            .expect("conversion always attaches the sample's metadata");
        assert_eq!(metadata.episode_id.as_deref(), Some("ep_002"));
    }

    #[test]
    fn test_serialization() {
        let data =
            TrainingData::sft_simple("prompt", "response").with_episode_id("ep_001".to_string());

        let json = serde_json::to_string(&data).unwrap();
        let deserialized: TrainingData = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.prompt, data.prompt);
        assert_eq!(deserialized.chosen, data.chosen);
        assert_eq!(deserialized.metadata.episode_id, data.metadata.episode_id);
    }

    #[test]
    fn test_serialization_omits_unset_optionals() {
        let json = serde_json::to_string(&TrainingData::sft_simple("prompt", "response")).unwrap();

        assert!(!json.contains("\"system\""));
        assert!(!json.contains("\"rejected\""));
        assert!(!json.contains("\"custom\""));
    }

    #[test]
    fn test_dpo_roundtrip_preserves_preference_pair() {
        let data = TrainingData::dpo("p", "c", "r").with_custom("k", "v");

        let json = serde_json::to_string(&data).unwrap();
        let back: TrainingData = serde_json::from_str(&json).unwrap();

        assert!(back.is_dpo());
        assert_eq!(back.chosen, "c");
        assert_eq!(back.rejected.as_deref(), Some("r"));
        assert_eq!(back.metadata.custom.get("k").map(String::as_str), Some("v"));
    }

    #[test]
    fn test_missing_metadata_defaults() {
        let mut value = serde_json::to_value(TrainingData::sft_simple("p", "r")).unwrap();
        let removed = value
            .as_object_mut()
            .expect("TrainingData serializes to a JSON object")
            .remove("metadata");
        assert!(removed.is_some());

        let back: TrainingData = serde_json::from_value(value).unwrap();

        assert!(back.is_sft());
        assert!(back.system.is_none());
        assert!(back.metadata.episode_id.is_none());
        assert!(back.metadata.custom.is_empty());
    }

    #[test]
    fn test_conversation_serialization() {
        let data = TrainingData::sft("System", "User", "Assistant");
        let conv = data.to_conversation();

        let json = serde_json::to_string(&conv).unwrap();

        assert!(json.contains("\"conversations\""));
        assert!(json.contains("\"role\""));
        assert!(json.contains("\"system\""));
        assert!(json.contains("\"user\""));
        assert!(json.contains("\"assistant\""));
    }
}