1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct PipelineConfig {
13 pub pipeline: HashMap<String, PipelineDefinition>,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct PipelineDefinition {
20 #[serde(rename = "type")]
22 pub pipeline_type: String,
23
24 #[serde(default)]
26 pub director: Option<DirectorConfig>,
27
28 #[serde(default)]
30 pub critic: Option<CriticConfig>,
31
32 #[serde(default)]
34 pub convergence: Option<ConvergenceConfig>,
35
36 #[serde(default)]
38 pub steps: Vec<StepConfig>,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct DirectorConfig {
44 pub model: String,
46
47 #[serde(default = "default_temperature")]
49 pub temperature: f64,
50
51 #[serde(default)]
53 pub system_prompt: Option<String>,
54
55 #[serde(default)]
57 pub max_tokens: Option<u32>,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct CriticConfig {
63 pub model: String,
65
66 #[serde(default = "default_evaluation_mode")]
68 pub evaluation_mode: EvaluationMode,
69
70 #[serde(default = "default_threshold")]
72 pub threshold: f64,
73
74 #[serde(default)]
76 pub rubric: HashMap<String, RubricDimension>,
77
78 #[serde(default)]
80 pub system_prompt: Option<String>,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
85#[serde(rename_all = "snake_case")]
86pub enum EvaluationMode {
87 Binary,
89 Score,
91 Rubric,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct RubricDimension {
98 pub weight: f64,
100
101 #[serde(default)]
103 pub description: Option<String>,
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108#[serde(tag = "type", rename_all = "snake_case")]
109pub enum ConvergenceConfig {
110 FixedRounds { rounds: u32 },
112
113 AdaptiveBreak {
115 min_rounds: u32,
116 max_rounds: u32,
117 improvement_threshold: f64,
118 },
119
120 ScoreThreshold { target_score: f64, max_rounds: u32 },
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct StepConfig {
127 pub name: String,
129
130 pub model: String,
132
133 #[serde(default)]
135 pub system_prompt: Option<String>,
136
137 #[serde(default)]
139 pub output_format: Option<String>,
140}
141
/// Serde fallback for `DirectorConfig::temperature` when the key is omitted.
fn default_temperature() -> f64 {
    0.7_f64
}
145
146fn default_evaluation_mode() -> EvaluationMode {
147 EvaluationMode::Score
148}
149
/// Serde fallback for `CriticConfig::threshold` when the key is omitted.
fn default_threshold() -> f64 {
    0.7_f64
}
153
154impl PipelineConfig {
155 pub fn from_toml(toml_str: &str) -> Result<Self, PipelineConfigError> {
157 toml::from_str(toml_str).map_err(|e| PipelineConfigError::ParseError {
158 message: e.to_string(),
159 })
160 }
161
162 pub fn to_toml(&self) -> Result<String, PipelineConfigError> {
164 toml::to_string_pretty(self).map_err(|e| PipelineConfigError::SerializeError {
165 message: e.to_string(),
166 })
167 }
168
169 pub fn get_pipeline(&self, name: &str) -> Option<&PipelineDefinition> {
171 self.pipeline.get(name)
172 }
173
174 pub fn validate(&self) -> Result<(), PipelineConfigError> {
176 for (name, pipeline) in &self.pipeline {
177 if let Some(critic) = &pipeline.critic {
179 if !(0.0..=1.0).contains(&critic.threshold) {
180 return Err(PipelineConfigError::ValidationError {
181 field: format!("pipeline.{}.critic.threshold", name),
182 message: "Threshold must be between 0.0 and 1.0".into(),
183 });
184 }
185
186 if critic.evaluation_mode == EvaluationMode::Rubric && !critic.rubric.is_empty() {
188 let total_weight: f64 = critic.rubric.values().map(|d| d.weight).sum();
189 if (total_weight - 1.0).abs() > 0.01 {
190 return Err(PipelineConfigError::ValidationError {
191 field: format!("pipeline.{}.critic.rubric", name),
192 message: format!(
193 "Rubric weights must sum to 1.0, got {}",
194 total_weight
195 ),
196 });
197 }
198 }
199 }
200
201 if let Some(ConvergenceConfig::AdaptiveBreak {
203 min_rounds,
204 max_rounds,
205 ..
206 }) = &pipeline.convergence
207 {
208 if min_rounds > max_rounds {
209 return Err(PipelineConfigError::ValidationError {
210 field: format!("pipeline.{}.convergence", name),
211 message: "min_rounds must be <= max_rounds".into(),
212 });
213 }
214 }
215 }
216
217 Ok(())
218 }
219}
220
221#[derive(Debug, thiserror::Error)]
223pub enum PipelineConfigError {
224 #[error("TOML parse error: {message}")]
225 ParseError { message: String },
226
227 #[error("TOML serialization error: {message}")]
228 SerializeError { message: String },
229
230 #[error("Validation error in '{field}': {message}")]
231 ValidationError { field: String, message: String },
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source`, panicking (failing the test) if it is rejected.
    fn parse(source: &str) -> PipelineConfig {
        PipelineConfig::from_toml(source).unwrap()
    }

    /// A full director/critic pipeline with a three-dimension rubric and adaptive convergence
    /// parses, validates and exposes the expected values.
    #[test]
    fn test_parse_director_critic_pipeline() {
        let source = r#"
[pipeline.compliance_review]
type = "director_critic"

[pipeline.compliance_review.director]
model = "slm"
temperature = 0.7

[pipeline.compliance_review.critic]
model = "claude-sonnet"
evaluation_mode = "rubric"
threshold = 0.85

[pipeline.compliance_review.critic.rubric.accuracy]
weight = 0.4

[pipeline.compliance_review.critic.rubric.compliance]
weight = 0.3

[pipeline.compliance_review.critic.rubric.completeness]
weight = 0.3

[pipeline.compliance_review.convergence]
type = "adaptive_break"
min_rounds = 1
max_rounds = 3
improvement_threshold = 0.05
"#;

        let config = parse(source);
        assert!(config.validate().is_ok());

        let definition = config.get_pipeline("compliance_review").unwrap();
        assert_eq!(definition.pipeline_type, "director_critic");

        let director = definition.director.as_ref().unwrap();
        assert_eq!(director.model, "slm");
        assert!((director.temperature - 0.7).abs() < f64::EPSILON);

        let critic = definition.critic.as_ref().unwrap();
        assert_eq!(critic.model, "claude-sonnet");
        assert_eq!(critic.evaluation_mode, EvaluationMode::Rubric);
        assert_eq!(critic.rubric.len(), 3);
    }

    /// `type = "fixed_rounds"` selects the `FixedRounds` variant.
    #[test]
    fn test_parse_fixed_rounds_convergence() {
        let source = r#"
[pipeline.simple]
type = "director_critic"

[pipeline.simple.convergence]
type = "fixed_rounds"
rounds = 3
"#;

        let config = parse(source);
        match config
            .get_pipeline("simple")
            .and_then(|definition| definition.convergence.as_ref())
        {
            Some(ConvergenceConfig::FixedRounds { rounds }) => assert_eq!(*rounds, 3),
            _ => panic!("Expected FixedRounds"),
        }
    }

    /// `type = "score_threshold"` selects the `ScoreThreshold` variant.
    #[test]
    fn test_parse_score_threshold_convergence() {
        let source = r#"
[pipeline.quality]
type = "director_critic"

[pipeline.quality.convergence]
type = "score_threshold"
target_score = 0.9
max_rounds = 5
"#;

        let config = parse(source);
        match config
            .get_pipeline("quality")
            .and_then(|definition| definition.convergence.as_ref())
        {
            Some(ConvergenceConfig::ScoreThreshold {
                target_score,
                max_rounds,
            }) => {
                assert!((target_score - 0.9).abs() < f64::EPSILON);
                assert_eq!(*max_rounds, 5);
            }
            _ => panic!("Expected ScoreThreshold"),
        }
    }

    /// A critic threshold above 1.0 fails validation.
    #[test]
    fn test_validate_invalid_threshold() {
        let source = r#"
[pipeline.bad]
type = "director_critic"

[pipeline.bad.critic]
model = "test"
threshold = 1.5
"#;

        assert!(parse(source).validate().is_err());
    }

    /// Rubric weights summing to 0.6 are rejected with the "sum to 1.0" message.
    #[test]
    fn test_validate_rubric_weights() {
        let source = r#"
[pipeline.bad]
type = "director_critic"

[pipeline.bad.critic]
model = "test"
evaluation_mode = "rubric"
threshold = 0.5

[pipeline.bad.critic.rubric.a]
weight = 0.3

[pipeline.bad.critic.rubric.b]
weight = 0.3
"#;

        let message = parse(source).validate().unwrap_err().to_string();
        assert!(message.contains("sum to 1.0"));
    }

    /// `min_rounds > max_rounds` in adaptive convergence is rejected.
    #[test]
    fn test_validate_convergence_bounds() {
        let source = r#"
[pipeline.bad]
type = "director_critic"

[pipeline.bad.convergence]
type = "adaptive_break"
min_rounds = 5
max_rounds = 3
improvement_threshold = 0.05
"#;

        let message = parse(source).validate().unwrap_err().to_string();
        assert!(message.contains("min_rounds"));
    }

    /// Serializing and re-parsing keeps the steps array and its order.
    #[test]
    fn test_roundtrip_serialization() {
        let source = r#"
[pipeline.test]
type = "chain"

[[pipeline.test.steps]]
name = "summarize"
model = "slm"

[[pipeline.test.steps]]
name = "refine"
model = "claude-sonnet"
"#;

        let serialized = parse(source).to_toml().unwrap();
        let restored = parse(&serialized);

        let steps = &restored.get_pipeline("test").unwrap().steps;
        assert_eq!(steps.len(), 2);
        let step_names: Vec<&str> = steps.iter().map(|step| step.name.as_str()).collect();
        assert_eq!(step_names, ["summarize", "refine"]);
    }

    /// Omitted optional keys fall back to their serde defaults.
    #[test]
    fn test_default_values() {
        let source = r#"
[pipeline.minimal]
type = "director_critic"

[pipeline.minimal.director]
model = "slm"

[pipeline.minimal.critic]
model = "test"
"#;

        let config = parse(source);
        let definition = config.get_pipeline("minimal").unwrap();

        let temperature = definition.director.as_ref().unwrap().temperature;
        assert!((temperature - 0.7).abs() < f64::EPSILON);

        let critic = definition.critic.as_ref().unwrap();
        assert_eq!(critic.evaluation_mode, EvaluationMode::Score);
        assert!((critic.threshold - 0.7).abs() < f64::EPSILON);
    }
}
434}