// symbi_runtime/reasoning/pipeline_config.rs

//! Declarative pipeline configuration.
//!
//! Supports TOML-based pipeline definitions for director-critic patterns
//! and other multi-agent workflows. Enterprise admins can define mandatory
//! quality gates that policies reference.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

10/// Top-level pipeline configuration, typically loaded from TOML.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct PipelineConfig {
13    /// Named pipeline definitions.
14    pub pipeline: HashMap<String, PipelineDefinition>,
15}
16
17/// A single pipeline definition.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct PipelineDefinition {
20    /// Pipeline type (e.g., "director_critic", "chain", "map_reduce").
21    #[serde(rename = "type")]
22    pub pipeline_type: String,
23
24    /// Director/orchestrator configuration.
25    #[serde(default)]
26    pub director: Option<DirectorConfig>,
27
28    /// Critic configuration.
29    #[serde(default)]
30    pub critic: Option<CriticConfig>,
31
32    /// Convergence criteria.
33    #[serde(default)]
34    pub convergence: Option<ConvergenceConfig>,
35
36    /// Chain steps (for chain-type pipelines).
37    #[serde(default)]
38    pub steps: Vec<StepConfig>,
39}
40
41/// Director/orchestrator model configuration.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct DirectorConfig {
44    /// Model to use (e.g., "slm", "claude-sonnet", "gpt-4").
45    pub model: String,
46
47    /// Temperature for generation.
48    #[serde(default = "default_temperature")]
49    pub temperature: f64,
50
51    /// System prompt override.
52    #[serde(default)]
53    pub system_prompt: Option<String>,
54
55    /// Maximum tokens for each director response.
56    #[serde(default)]
57    pub max_tokens: Option<u32>,
58}
59
60/// Critic evaluation configuration.
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct CriticConfig {
63    /// Model to use for critique.
64    pub model: String,
65
66    /// Evaluation mode: "binary", "score", or "rubric".
67    #[serde(default = "default_evaluation_mode")]
68    pub evaluation_mode: EvaluationMode,
69
70    /// Score threshold for approval (0.0 - 1.0).
71    #[serde(default = "default_threshold")]
72    pub threshold: f64,
73
74    /// Rubric for multi-dimension evaluation.
75    #[serde(default)]
76    pub rubric: HashMap<String, RubricDimension>,
77
78    /// System prompt for the critic.
79    #[serde(default)]
80    pub system_prompt: Option<String>,
81}
82
83/// Evaluation mode for critics.
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
85#[serde(rename_all = "snake_case")]
86pub enum EvaluationMode {
87    /// Simple approve/reject.
88    Binary,
89    /// Single numeric score.
90    Score,
91    /// Multi-dimension rubric with weights.
92    Rubric,
93}
94
95/// A single dimension in a rubric-based evaluation.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct RubricDimension {
98    /// Weight of this dimension in the overall score.
99    pub weight: f64,
100
101    /// Description of what this dimension evaluates.
102    #[serde(default)]
103    pub description: Option<String>,
104}
105
106/// Convergence criteria for iterative patterns.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108#[serde(tag = "type", rename_all = "snake_case")]
109pub enum ConvergenceConfig {
110    /// Run for exactly N rounds.
111    FixedRounds { rounds: u32 },
112
113    /// Run until improvement drops below threshold, with min/max bounds.
114    AdaptiveBreak {
115        min_rounds: u32,
116        max_rounds: u32,
117        improvement_threshold: f64,
118    },
119
120    /// Run until a score threshold is met.
121    ScoreThreshold { target_score: f64, max_rounds: u32 },
122}
123
124/// A step in a chain-type pipeline.
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct StepConfig {
127    /// Step name.
128    pub name: String,
129
130    /// Model to use for this step.
131    pub model: String,
132
133    /// System prompt for this step.
134    #[serde(default)]
135    pub system_prompt: Option<String>,
136
137    /// Output format: "text", "json", or a schema name.
138    #[serde(default)]
139    pub output_format: Option<String>,
140}
141
/// Serde default for `DirectorConfig::temperature` when the key is omitted.
fn default_temperature() -> f64 {
    const DEFAULT_TEMPERATURE: f64 = 0.7;
    DEFAULT_TEMPERATURE
}

146fn default_evaluation_mode() -> EvaluationMode {
147    EvaluationMode::Score
148}
149
/// Serde default for `CriticConfig::threshold` when the key is omitted.
fn default_threshold() -> f64 {
    const DEFAULT_THRESHOLD: f64 = 0.7;
    DEFAULT_THRESHOLD
}

154impl PipelineConfig {
155    /// Parse a pipeline configuration from TOML.
156    pub fn from_toml(toml_str: &str) -> Result<Self, PipelineConfigError> {
157        toml::from_str(toml_str).map_err(|e| PipelineConfigError::ParseError {
158            message: e.to_string(),
159        })
160    }
161
162    /// Serialize to TOML.
163    pub fn to_toml(&self) -> Result<String, PipelineConfigError> {
164        toml::to_string_pretty(self).map_err(|e| PipelineConfigError::SerializeError {
165            message: e.to_string(),
166        })
167    }
168
169    /// Get a pipeline definition by name.
170    pub fn get_pipeline(&self, name: &str) -> Option<&PipelineDefinition> {
171        self.pipeline.get(name)
172    }
173
174    /// Validate the configuration.
175    pub fn validate(&self) -> Result<(), PipelineConfigError> {
176        for (name, pipeline) in &self.pipeline {
177            // Validate critic threshold
178            if let Some(critic) = &pipeline.critic {
179                if !(0.0..=1.0).contains(&critic.threshold) {
180                    return Err(PipelineConfigError::ValidationError {
181                        field: format!("pipeline.{}.critic.threshold", name),
182                        message: "Threshold must be between 0.0 and 1.0".into(),
183                    });
184                }
185
186                // Validate rubric weights sum to ~1.0 if rubric mode
187                if critic.evaluation_mode == EvaluationMode::Rubric && !critic.rubric.is_empty() {
188                    let total_weight: f64 = critic.rubric.values().map(|d| d.weight).sum();
189                    if (total_weight - 1.0).abs() > 0.01 {
190                        return Err(PipelineConfigError::ValidationError {
191                            field: format!("pipeline.{}.critic.rubric", name),
192                            message: format!(
193                                "Rubric weights must sum to 1.0, got {}",
194                                total_weight
195                            ),
196                        });
197                    }
198                }
199            }
200
201            // Validate convergence bounds
202            if let Some(ConvergenceConfig::AdaptiveBreak {
203                min_rounds,
204                max_rounds,
205                ..
206            }) = &pipeline.convergence
207            {
208                if min_rounds > max_rounds {
209                    return Err(PipelineConfigError::ValidationError {
210                        field: format!("pipeline.{}.convergence", name),
211                        message: "min_rounds must be <= max_rounds".into(),
212                    });
213                }
214            }
215        }
216
217        Ok(())
218    }
219}
220
221/// Errors from pipeline configuration.
222#[derive(Debug, thiserror::Error)]
223pub enum PipelineConfigError {
224    #[error("TOML parse error: {message}")]
225    ParseError { message: String },
226
227    #[error("TOML serialization error: {message}")]
228    SerializeError { message: String },
229
230    #[error("Validation error in '{field}': {message}")]
231    ValidationError { field: String, message: String },
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `src`, panicking if the TOML is rejected.
    fn parse(src: &str) -> PipelineConfig {
        PipelineConfig::from_toml(src).unwrap()
    }

    #[test]
    fn test_parse_director_critic_pipeline() {
        let src = r#"
[pipeline.compliance_review]
type = "director_critic"

[pipeline.compliance_review.director]
model = "slm"
temperature = 0.7

[pipeline.compliance_review.critic]
model = "claude-sonnet"
evaluation_mode = "rubric"
threshold = 0.85

[pipeline.compliance_review.critic.rubric.accuracy]
weight = 0.4

[pipeline.compliance_review.critic.rubric.compliance]
weight = 0.3

[pipeline.compliance_review.critic.rubric.completeness]
weight = 0.3

[pipeline.compliance_review.convergence]
type = "adaptive_break"
min_rounds = 1
max_rounds = 3
improvement_threshold = 0.05
"#;

        let cfg = parse(src);
        assert!(cfg.validate().is_ok());

        let def = cfg.get_pipeline("compliance_review").unwrap();
        assert_eq!(def.pipeline_type, "director_critic");

        let critic = def.critic.as_ref().unwrap();
        assert_eq!(critic.model, "claude-sonnet");
        assert_eq!(critic.evaluation_mode, EvaluationMode::Rubric);
        assert_eq!(critic.rubric.len(), 3);

        let director = def.director.as_ref().unwrap();
        assert_eq!(director.model, "slm");
        assert!((director.temperature - 0.7).abs() < f64::EPSILON);
    }

    #[test]
    fn test_parse_fixed_rounds_convergence() {
        let src = r#"
[pipeline.simple]
type = "director_critic"

[pipeline.simple.convergence]
type = "fixed_rounds"
rounds = 3
"#;

        let cfg = parse(src);
        let convergence = cfg
            .get_pipeline("simple")
            .unwrap()
            .convergence
            .as_ref()
            .unwrap();
        if let ConvergenceConfig::FixedRounds { rounds } = convergence {
            assert_eq!(*rounds, 3);
        } else {
            panic!("Expected FixedRounds");
        }
    }

    #[test]
    fn test_parse_score_threshold_convergence() {
        let src = r#"
[pipeline.quality]
type = "director_critic"

[pipeline.quality.convergence]
type = "score_threshold"
target_score = 0.9
max_rounds = 5
"#;

        let cfg = parse(src);
        let convergence = cfg
            .get_pipeline("quality")
            .unwrap()
            .convergence
            .as_ref()
            .unwrap();
        if let ConvergenceConfig::ScoreThreshold {
            target_score,
            max_rounds,
        } = convergence
        {
            assert!((target_score - 0.9).abs() < f64::EPSILON);
            assert_eq!(*max_rounds, 5);
        } else {
            panic!("Expected ScoreThreshold");
        }
    }

    #[test]
    fn test_validate_invalid_threshold() {
        let src = r#"
[pipeline.bad]
type = "director_critic"

[pipeline.bad.critic]
model = "test"
threshold = 1.5
"#;

        assert!(parse(src).validate().is_err());
    }

    #[test]
    fn test_validate_rubric_weights() {
        let src = r#"
[pipeline.bad]
type = "director_critic"

[pipeline.bad.critic]
model = "test"
evaluation_mode = "rubric"
threshold = 0.5

[pipeline.bad.critic.rubric.a]
weight = 0.3

[pipeline.bad.critic.rubric.b]
weight = 0.3
"#;

        let message = parse(src).validate().unwrap_err().to_string();
        assert!(message.contains("sum to 1.0"));
    }

    #[test]
    fn test_validate_convergence_bounds() {
        let src = r#"
[pipeline.bad]
type = "director_critic"

[pipeline.bad.convergence]
type = "adaptive_break"
min_rounds = 5
max_rounds = 3
improvement_threshold = 0.05
"#;

        let message = parse(src).validate().unwrap_err().to_string();
        assert!(message.contains("min_rounds"));
    }

    #[test]
    fn test_roundtrip_serialization() {
        let src = r#"
[pipeline.test]
type = "chain"

[[pipeline.test.steps]]
name = "summarize"
model = "slm"

[[pipeline.test.steps]]
name = "refine"
model = "claude-sonnet"
"#;

        let original = parse(src);
        let restored = parse(&original.to_toml().unwrap());

        let def = restored.get_pipeline("test").unwrap();
        let step_names: Vec<&str> = def.steps.iter().map(|step| step.name.as_str()).collect();
        assert_eq!(step_names, ["summarize", "refine"]);
    }

    #[test]
    fn test_default_values() {
        let src = r#"
[pipeline.minimal]
type = "director_critic"

[pipeline.minimal.director]
model = "slm"

[pipeline.minimal.critic]
model = "test"
"#;

        let cfg = parse(src);
        let def = cfg.get_pipeline("minimal").unwrap();

        let director = def.director.as_ref().unwrap();
        assert!((director.temperature - 0.7).abs() < f64::EPSILON);

        let critic = def.critic.as_ref().unwrap();
        assert_eq!(critic.evaluation_mode, EvaluationMode::Score);
        assert!((critic.threshold - 0.7).abs() < f64::EPSILON);
    }
}