Skip to main content

swink_agent_eval/
training.rs

//! RL-compatible training-format trace export (feature: `training-export`).
//!
//! Exports [`Invocation`] traces collected during eval runs into formats
//! compatible with LLM fine-tuning pipelines: ChatML/SFT, DPO pairs, and
//! ShareGPT.
//!
//! # Quick Start
//!
//! ```rust,ignore
//! use swink_agent_eval::training::{
//!     ChatMlExporter, ExportOptions, ScoredTrace, TrainingExporter, TrainingFormat,
//! };
//!
//! // `traces` is a `Vec<ScoredTrace>`, e.g. built with `ScoredTrace::from_case_result`.
//! let exporter = ChatMlExporter;
//! let opts = ExportOptions::default();
//! let bytes = exporter.export(&traces, &opts)?;
//! ```
18
19use std::path::PathBuf;
20
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23
24use crate::EvalCaseResult;
25use crate::report::{Reporter, ReporterError, ReporterOutput};
26use crate::types::Invocation;
27
28// ─── Error ──────────────────────────────────────────────────────────────────
29
30/// Errors produced by training-format exporters.
31#[derive(Debug, Error)]
32pub enum ExportError {
33    /// JSON serialization failed.
34    #[error("serialization error: {0}")]
35    Serialization(#[from] serde_json::Error),
36    /// The requested format is not yet fully implemented (stubs).
37    #[error("format not fully implemented: {0:?}")]
38    NotImplemented(TrainingFormat),
39    /// No traces survived the quality threshold filter.
40    #[error("no traces passed the quality threshold ({threshold})")]
41    NothingToExport { threshold: f32 },
42}
43
44// ─── Format Enum ────────────────────────────────────────────────────────────
45
46/// Supported training-data output formats.
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
48#[serde(rename_all = "snake_case")]
49#[non_exhaustive]
50pub enum TrainingFormat {
51    /// Conversation-style JSONL with system/user/assistant turns and tool calls.
52    ChatMlSft,
53    /// Chosen/rejected pairs from high-score vs low-score traces on the same case.
54    DpoPairs,
55    /// Community ShareGPT conversation format.
56    ShareGpt,
57}
58
59// ─── Options ────────────────────────────────────────────────────────────────
60
61/// Options controlling how traces are exported.
62#[derive(Debug, Clone)]
63pub struct ExportOptions {
64    /// Target output format.
65    pub format: TrainingFormat,
66    /// Minimum score `[0.0, 1.0]` a trace must have to be included.
67    /// Traces with `score < quality_threshold` are filtered out.
68    pub quality_threshold: f32,
69    /// When `true`, per-record metadata (model, temperature proxy, eval case
70    /// ID, timestamp) is included in the exported records.
71    pub include_metadata: bool,
72}
73
74impl Default for ExportOptions {
75    fn default() -> Self {
76        Self {
77            format: TrainingFormat::ChatMlSft,
78            quality_threshold: 0.0,
79            include_metadata: true,
80        }
81    }
82}
83
84impl ExportOptions {
85    /// Create options targeting ChatML/SFT format with a quality gate.
86    #[must_use]
87    pub fn chatml_sft(quality_threshold: f32) -> Self {
88        Self {
89            format: TrainingFormat::ChatMlSft,
90            quality_threshold,
91            include_metadata: true,
92        }
93    }
94
95    /// Create options targeting DPO pairs with a quality gate.
96    #[must_use]
97    pub fn dpo_pairs(quality_threshold: f32) -> Self {
98        Self {
99            format: TrainingFormat::DpoPairs,
100            quality_threshold,
101            include_metadata: true,
102        }
103    }
104
105    /// Create options targeting ShareGPT format.
106    #[must_use]
107    pub fn sharegpt() -> Self {
108        Self {
109            format: TrainingFormat::ShareGpt,
110            quality_threshold: 0.0,
111            include_metadata: true,
112        }
113    }
114}
115
116// ─── ScoredTrace ────────────────────────────────────────────────────────────
117
118/// An [`Invocation`] paired with a quality score and the originating case.
119///
120/// Built from [`EvalCaseResult`] values collected during a run.
121#[derive(Debug, Clone)]
122pub struct ScoredTrace {
123    /// The captured execution trace.
124    pub invocation: Invocation,
125    /// Aggregate quality score for this trace, typically the mean of all
126    /// evaluator scores, in `[0.0, 1.0]`.
127    pub score: f64,
128    /// Identifier of the eval case this trace was produced by.
129    pub case_id: String,
130}
131
132impl ScoredTrace {
133    /// Construct a `ScoredTrace` from an [`EvalCaseResult`].
134    ///
135    /// `score` is the mean of all metric scores (0.0 when there are none).
136    #[must_use]
137    pub fn from_case_result(result: &EvalCaseResult) -> Self {
138        let score = if result.metric_results.is_empty() {
139            0.0
140        } else {
141            let sum: f64 = result.metric_results.iter().map(|m| m.score.value).sum();
142            #[allow(clippy::cast_precision_loss)]
143            let mean = sum / result.metric_results.len() as f64;
144            mean
145        };
146        Self {
147            invocation: result.invocation.clone(),
148            score,
149            case_id: result.case_id.clone(),
150        }
151    }
152}
153
154// ─── Trait ──────────────────────────────────────────────────────────────────
155
156/// Converts a slice of scored traces into a training-data byte payload.
157///
158/// Implementations are stateless; all configuration is passed via
159/// [`ExportOptions`].
160pub trait TrainingExporter: Send + Sync {
161    /// Export `traces` according to `opts`.
162    ///
163    /// Returns a `Vec<u8>` whose encoding depends on the implementation
164    /// (typically UTF-8 JSONL). Returns [`ExportError::NothingToExport`] when
165    /// every trace is below `opts.quality_threshold`.
166    fn export(&self, traces: &[ScoredTrace], opts: &ExportOptions) -> Result<Vec<u8>, ExportError>;
167}
168
169// ─── ChatML/SFT ─────────────────────────────────────────────────────────────
170
/// Exporter that writes conversations as JSONL in the ChatML style.
///
/// Every trace that passes the quality gate becomes one line holding a single
/// JSON object. The layout follows the OpenAI ChatML convention that many
/// fine-tuning platforms accept:
///
/// ```json
/// {"messages": [
///   {"role": "system",    "content": "You are a helpful agent."},
///   {"role": "user",      "content": "What is 2+2?"},
///   {"role": "assistant", "content": "4", "tool_calls": [...]}
/// ], "metadata": {...}}
/// ```
///
/// An assistant turn's tool calls are emitted as an array of
/// `{id, type, function: {name, arguments}}` objects, i.e. the OpenAI
/// tool-call shape, so fine-tuning pipelines can consume them without a
/// further transformation step.
#[derive(Clone, Copy, Debug, Default)]
pub struct ChatMlExporter;
190
191impl TrainingExporter for ChatMlExporter {
192    fn export(&self, traces: &[ScoredTrace], opts: &ExportOptions) -> Result<Vec<u8>, ExportError> {
193        let threshold = f64::from(opts.quality_threshold);
194        let qualified: Vec<&ScoredTrace> = traces.iter().filter(|t| t.score >= threshold).collect();
195
196        if qualified.is_empty() {
197            return Err(ExportError::NothingToExport {
198                threshold: opts.quality_threshold,
199            });
200        }
201
202        let mut out = Vec::new();
203        for trace in qualified {
204            let record = build_chatml_record(trace, opts);
205            serde_json::to_writer(&mut out, &record)?;
206            out.push(b'\n');
207        }
208        Ok(out)
209    }
210}
211
212// ─── ChatML helpers ─────────────────────────────────────────────────────────
213
214#[derive(Serialize)]
215struct ChatMlRecord<'a> {
216    messages: Vec<ChatMlMessage>,
217    #[serde(skip_serializing_if = "Option::is_none")]
218    metadata: Option<ChatMlMetadata<'a>>,
219}
220
221#[derive(Serialize)]
222struct ChatMlMessage {
223    role: &'static str,
224    content: String,
225    #[serde(skip_serializing_if = "Option::is_none")]
226    tool_calls: Option<Vec<ChatMlToolCall>>,
227}
228
229#[derive(Serialize)]
230struct ChatMlToolCall {
231    id: String,
232    #[serde(rename = "type")]
233    call_type: &'static str,
234    function: ChatMlFunction,
235}
236
237#[derive(Serialize)]
238struct ChatMlFunction {
239    name: String,
240    arguments: String,
241}
242
243#[derive(Serialize)]
244struct ChatMlMetadata<'a> {
245    case_id: &'a str,
246    score: f64,
247    model_id: String,
248    turns: usize,
249}
250
251fn build_chatml_record<'a>(trace: &'a ScoredTrace, opts: &ExportOptions) -> ChatMlRecord<'a> {
252    let inv = &trace.invocation;
253    let mut messages: Vec<ChatMlMessage> = Vec::new();
254
255    // System message — derive from the first user message context if there is
256    // relevant text, otherwise use an empty string.  The system prompt is not
257    // stored on `Invocation` directly; we emit a placeholder so downstream
258    // pipelines always have a system slot to fill from their own case data.
259    messages.push(ChatMlMessage {
260        role: "system",
261        content: String::new(),
262        tool_calls: None,
263    });
264
265    for turn in &inv.turns {
266        // User turn: synthesise from tool results of the *previous* turn or
267        // from the first turn where we have no tool results to carry.
268        // For turn 0 the user message is implicit (not stored in Invocation).
269        // We add a user placeholder only for turn 0.
270        if turn.turn_index == 0 {
271            messages.push(ChatMlMessage {
272                role: "user",
273                content: String::new(), // prompt not stored in Invocation
274                tool_calls: None,
275            });
276        }
277
278        // Assistant message
279        let content = extract_assistant_text(&turn.assistant_message);
280        let tool_calls: Vec<ChatMlToolCall> = turn
281            .tool_calls
282            .iter()
283            .map(|tc| ChatMlToolCall {
284                id: tc.id.clone(),
285                call_type: "function",
286                function: ChatMlFunction {
287                    name: tc.name.clone(),
288                    arguments: tc.arguments.to_string(),
289                },
290            })
291            .collect();
292
293        messages.push(ChatMlMessage {
294            role: "assistant",
295            content,
296            tool_calls: if tool_calls.is_empty() {
297                None
298            } else {
299                Some(tool_calls)
300            },
301        });
302    }
303
304    // Final response appended as a last assistant message if not already
305    // captured via the last turn's assistant message.
306    if let Some(response) = &inv.final_response {
307        let needs_patch = messages
308            .last()
309            .is_some_and(|last| last.role == "assistant" && last.content.is_empty());
310        let needs_append = messages.last().is_some_and(|last| last.role != "assistant");
311
312        if needs_patch && !response.is_empty() {
313            if let Some(last_mut) = messages.last_mut() {
314                last_mut.content.clone_from(response);
315            }
316        } else if needs_append {
317            messages.push(ChatMlMessage {
318                role: "assistant",
319                content: response.clone(),
320                tool_calls: None,
321            });
322        }
323    }
324
325    let metadata = if opts.include_metadata {
326        Some(ChatMlMetadata {
327            case_id: &trace.case_id,
328            score: trace.score,
329            model_id: inv.model.model_id.clone(),
330            turns: inv.turns.len(),
331        })
332    } else {
333        None
334    };
335
336    ChatMlRecord { messages, metadata }
337}
338
339fn extract_assistant_text(msg: &swink_agent::AssistantMessage) -> String {
340    use swink_agent::ContentBlock;
341    msg.content
342        .iter()
343        .filter_map(|block| {
344            if let ContentBlock::Text { text } = block {
345                Some(text.as_str())
346            } else {
347                None
348            }
349        })
350        .collect::<Vec<_>>()
351        .join("")
352}
353
354// ─── DPO Pairs ──────────────────────────────────────────────────────────────
355
/// Exporter producing chosen/rejected pairs for DPO (Direct Preference
/// Optimization).
///
/// Traces that pass the quality gate are bucketed by `case_id`. In every
/// bucket the top-scoring trace is emitted as `chosen` and the
/// bottom-scoring one as `rejected`; buckets holding fewer than two traces
/// yield no pair.
///
/// Output schema (one JSON object per line):
///
/// ```json
/// {"case_id": "...", "chosen": {...chatml record...}, "rejected": {...chatml record...}}
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct DpoExporter;
369
370/// A single DPO pair (one JSONL record).
371#[derive(Serialize)]
372struct DpoPairRecord {
373    case_id: String,
374    chosen: serde_json::Value,
375    rejected: serde_json::Value,
376}
377
378impl TrainingExporter for DpoExporter {
379    fn export(&self, traces: &[ScoredTrace], opts: &ExportOptions) -> Result<Vec<u8>, ExportError> {
380        let threshold = f64::from(opts.quality_threshold);
381        let qualified: Vec<&ScoredTrace> = traces.iter().filter(|t| t.score >= threshold).collect();
382
383        // Group by case_id
384        let mut by_case: std::collections::HashMap<&str, Vec<&ScoredTrace>> =
385            std::collections::HashMap::new();
386        for trace in &qualified {
387            by_case
388                .entry(trace.case_id.as_str())
389                .or_default()
390                .push(trace);
391        }
392
393        let mut pairs: Vec<DpoPairRecord> = Vec::new();
394        for (case_id, mut group) in by_case {
395            if group.len() < 2 {
396                continue;
397            }
398            // Sort by score descending
399            group.sort_by(|a, b| {
400                b.score
401                    .partial_cmp(&a.score)
402                    .unwrap_or(std::cmp::Ordering::Equal)
403            });
404            let chosen_trace = group[0];
405            let rejected_trace = group[group.len() - 1];
406
407            let chosen_record = build_chatml_record(chosen_trace, opts);
408            let rejected_record = build_chatml_record(rejected_trace, opts);
409
410            pairs.push(DpoPairRecord {
411                case_id: case_id.to_string(),
412                chosen: serde_json::to_value(chosen_record)?,
413                rejected: serde_json::to_value(rejected_record)?,
414            });
415        }
416
417        if pairs.is_empty() {
418            return Err(ExportError::NothingToExport {
419                threshold: opts.quality_threshold,
420            });
421        }
422
423        let mut out = Vec::new();
424        for pair in &pairs {
425            serde_json::to_writer(&mut out, pair)?;
426            out.push(b'\n');
427        }
428        Ok(out)
429    }
430}
431
432// ─── ShareGPT ───────────────────────────────────────────────────────────────
433
/// Exporter for the community ShareGPT conversation layout.
///
/// Output schema (one JSON object per line; a `metadata` object is added
/// alongside `conversations` when metadata is enabled):
///
/// ```json
/// {"conversations": [
///   {"from": "system", "value": "..."},
///   {"from": "human",  "value": "..."},
///   {"from": "gpt",    "value": "..."}
/// ]}
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct ShareGptExporter;
447
448#[derive(Serialize)]
449struct ShareGptRecord {
450    conversations: Vec<ShareGptTurn>,
451    #[serde(skip_serializing_if = "Option::is_none")]
452    metadata: Option<serde_json::Value>,
453}
454
455#[derive(Serialize)]
456struct ShareGptTurn {
457    from: &'static str,
458    value: String,
459}
460
461impl TrainingExporter for ShareGptExporter {
462    fn export(&self, traces: &[ScoredTrace], opts: &ExportOptions) -> Result<Vec<u8>, ExportError> {
463        let threshold = f64::from(opts.quality_threshold);
464        let qualified: Vec<&ScoredTrace> = traces.iter().filter(|t| t.score >= threshold).collect();
465
466        if qualified.is_empty() {
467            return Err(ExportError::NothingToExport {
468                threshold: opts.quality_threshold,
469            });
470        }
471
472        let mut out = Vec::new();
473        for trace in qualified {
474            let record = build_sharegpt_record(trace, opts);
475            serde_json::to_writer(&mut out, &record)?;
476            out.push(b'\n');
477        }
478        Ok(out)
479    }
480}
481
482fn build_sharegpt_record(trace: &ScoredTrace, opts: &ExportOptions) -> ShareGptRecord {
483    let inv = &trace.invocation;
484    let mut conversations: Vec<ShareGptTurn> = Vec::new();
485
486    // System placeholder
487    conversations.push(ShareGptTurn {
488        from: "system",
489        value: String::new(),
490    });
491
492    for turn in &inv.turns {
493        if turn.turn_index == 0 {
494            conversations.push(ShareGptTurn {
495                from: "human",
496                value: String::new(), // user prompt not stored in Invocation
497            });
498        }
499        let content = extract_assistant_text(&turn.assistant_message);
500        conversations.push(ShareGptTurn {
501            from: "gpt",
502            value: content,
503        });
504    }
505
506    // Final response patch — same logic as ChatML.
507    if let Some(response) = &inv.final_response {
508        let needs_patch = conversations
509            .last()
510            .is_some_and(|last| last.from == "gpt" && last.value.is_empty());
511        let needs_append = conversations.last().is_some_and(|last| last.from != "gpt");
512
513        if needs_patch && !response.is_empty() {
514            if let Some(last_mut) = conversations.last_mut() {
515                last_mut.value.clone_from(response);
516            }
517        } else if needs_append {
518            conversations.push(ShareGptTurn {
519                from: "gpt",
520                value: response.clone(),
521            });
522        }
523    }
524
525    let metadata = if opts.include_metadata {
526        Some(serde_json::json!({
527            "case_id": trace.case_id,
528            "score": trace.score,
529        }))
530    } else {
531        None
532    };
533
534    ShareGptRecord {
535        conversations,
536        metadata,
537    }
538}
539
540// ─── Dispatch helper ─────────────────────────────────────────────────────────
541
542/// Dispatch export to the appropriate exporter based on `opts.format`.
543pub fn export_traces(traces: &[ScoredTrace], opts: &ExportOptions) -> Result<Vec<u8>, ExportError> {
544    match opts.format {
545        TrainingFormat::ChatMlSft => ChatMlExporter.export(traces, opts),
546        TrainingFormat::DpoPairs => DpoExporter.export(traces, opts),
547        TrainingFormat::ShareGpt => ShareGptExporter.export(traces, opts),
548    }
549}
550
551// ─── TrainingReporter ────────────────────────────────────────────────────────
552
553/// A [`Reporter`] that exports all eval results as training data.
554///
555/// Implements the existing [`Reporter`] trait so it can be composed with other
556/// reporters in the eval runner pipeline.
557///
558/// The reporter converts each [`EvalCaseResult`] into a [`ScoredTrace`],
559/// applies the configured [`ExportOptions`], and writes the export artifact to
560/// the configured output path (or returns it as [`ReporterOutput::Artifact`]).
561#[derive(Debug, Clone)]
562pub struct TrainingReporter {
563    opts: ExportOptions,
564    /// Suggested output file path. Callers may override.
565    output_path: PathBuf,
566}
567
568impl TrainingReporter {
569    /// Create a new reporter with explicit options and output path.
570    #[must_use]
571    pub fn new(opts: ExportOptions, output_path: impl Into<PathBuf>) -> Self {
572        Self {
573            opts,
574            output_path: output_path.into(),
575        }
576    }
577
578    /// Create a ChatML/SFT reporter writing to `output_path`.
579    #[must_use]
580    pub fn chatml_sft(quality_threshold: f32, output_path: impl Into<PathBuf>) -> Self {
581        Self::new(ExportOptions::chatml_sft(quality_threshold), output_path)
582    }
583
584    /// Create a DPO pairs reporter writing to `output_path`.
585    #[must_use]
586    pub fn dpo_pairs(quality_threshold: f32, output_path: impl Into<PathBuf>) -> Self {
587        Self::new(ExportOptions::dpo_pairs(quality_threshold), output_path)
588    }
589
590    /// Create a ShareGPT reporter writing to `output_path`.
591    #[must_use]
592    pub fn sharegpt(output_path: impl Into<PathBuf>) -> Self {
593        Self::new(ExportOptions::sharegpt(), output_path)
594    }
595}
596
597impl Reporter for TrainingReporter {
598    fn render(&self, result: &EvalSetResult) -> Result<ReporterOutput, ReporterError> {
599        let traces: Vec<ScoredTrace> = result
600            .case_results
601            .iter()
602            .map(ScoredTrace::from_case_result)
603            .collect();
604
605        let bytes =
606            export_traces(&traces, &self.opts).map_err(|e| ReporterError::Format(e.to_string()))?;
607
608        Ok(ReporterOutput::Artifact {
609            path: self.output_path.clone(),
610            bytes,
611        })
612    }
613}
614
615// Import needed for Reporter impl
616use crate::types::EvalSetResult;