1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use crate::error::{Result, SqzError};
5
6pub struct PresetParser;
45
46impl PresetParser {
47 pub fn parse(toml_str: &str) -> Result<Preset> {
49 let preset: Preset = toml::from_str(toml_str)?;
50 Self::validate(&preset)?;
51 Ok(preset)
52 }
53
54 pub fn to_toml(preset: &Preset) -> Result<String> {
56 Ok(toml::to_string_pretty(preset)?)
57 }
58
59 pub fn validate(preset: &Preset) -> Result<()> {
61 if preset.preset.name.is_empty() {
62 return Err(SqzError::PresetValidation {
63 field: "preset.name".to_string(),
64 message: "must not be empty".to_string(),
65 });
66 }
67
68 if preset.preset.version.is_empty() {
69 return Err(SqzError::PresetValidation {
70 field: "preset.version".to_string(),
71 message: "must not be empty".to_string(),
72 });
73 }
74
75 let wt = preset.budget.warning_threshold;
76 if !(wt > 0.0 && wt < 1.0) {
77 return Err(SqzError::PresetValidation {
78 field: "budget.warning_threshold".to_string(),
79 message: "must be between 0.0 and 1.0".to_string(),
80 });
81 }
82
83 let ct = preset.budget.ceiling_threshold;
84 if !(ct > 0.0 && ct < 1.0) || ct <= wt {
85 return Err(SqzError::PresetValidation {
86 field: "budget.ceiling_threshold".to_string(),
87 message: "must be between 0.0 and 1.0 and greater than warning_threshold"
88 .to_string(),
89 });
90 }
91
92 let max_tools = preset.tool_selection.max_tools;
93 if !(1..=50).contains(&max_tools) {
94 return Err(SqzError::PresetValidation {
95 field: "tool_selection.max_tools".to_string(),
96 message: "must be between 1 and 50".to_string(),
97 });
98 }
99
100 let st = preset.tool_selection.similarity_threshold;
101 if !(st > 0.0 && st < 1.0) {
102 return Err(SqzError::PresetValidation {
103 field: "tool_selection.similarity_threshold".to_string(),
104 message: "must be between 0.0 and 1.0".to_string(),
105 });
106 }
107
108 let cxt = preset.model.complexity_threshold;
109 if !(cxt > 0.0 && cxt < 1.0) {
110 return Err(SqzError::PresetValidation {
111 field: "model.complexity_threshold".to_string(),
112 message: "must be between 0.0 and 1.0".to_string(),
113 });
114 }
115
116 Ok(())
117 }
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct Preset {
127 pub preset: PresetHeader,
128 pub compression: CompressionConfig,
129 pub tool_selection: ToolSelectionConfig,
130 pub budget: BudgetConfig,
131 pub terse_mode: TerseModeConfig,
132 pub model: ModelConfig,
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct PresetHeader {
138 pub name: String,
140 pub version: String,
142 #[serde(default)]
144 pub description: String,
145}
146
147pub type PresetMeta = PresetHeader;
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct CompressionConfig {
154 #[serde(default)]
155 pub stages: Vec<String>,
156 pub keep_fields: Option<KeepFieldsConfig>,
157 pub strip_fields: Option<StripFieldsConfig>,
158 pub condense: Option<CondenseConfig>,
159 pub git_diff_fold: Option<GitDiffFoldConfig>,
160 pub strip_nulls: Option<StripNullsConfig>,
161 pub flatten: Option<FlattenConfig>,
162 pub truncate_strings: Option<TruncateStringsConfig>,
163 pub collapse_arrays: Option<CollapseArraysConfig>,
164 pub custom_transforms: Option<CustomTransformsConfig>,
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct GitDiffFoldConfig {
169 pub enabled: bool,
170 #[serde(default = "default_max_context_lines")]
171 pub max_context_lines: u32,
172}
173
/// Serde fallback for [`GitDiffFoldConfig::max_context_lines`], used when the
/// input omits that key.
fn default_max_context_lines() -> u32 {
    const MAX_CONTEXT_LINES_FALLBACK: u32 = 2;
    MAX_CONTEXT_LINES_FALLBACK
}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct KeepFieldsConfig {
180 pub enabled: bool,
181 #[serde(default)]
182 pub fields: Vec<String>,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct StripFieldsConfig {
187 pub enabled: bool,
188 #[serde(default)]
189 pub fields: Vec<String>,
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct CondenseConfig {
194 pub enabled: bool,
195 #[serde(default = "default_max_repeated_lines")]
196 pub max_repeated_lines: u32,
197}
198
/// Serde fallback for [`CondenseConfig::max_repeated_lines`], used when the
/// input omits that key.
fn default_max_repeated_lines() -> u32 {
    const MAX_REPEATED_LINES_FALLBACK: u32 = 3;
    MAX_REPEATED_LINES_FALLBACK
}
202
203#[derive(Debug, Clone, Serialize, Deserialize)]
204pub struct StripNullsConfig {
205 pub enabled: bool,
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct FlattenConfig {
210 pub enabled: bool,
211 #[serde(default = "default_max_depth")]
212 pub max_depth: u32,
213}
214
/// Serde fallback for [`FlattenConfig::max_depth`], used when the input omits
/// that key.
fn default_max_depth() -> u32 {
    const MAX_DEPTH_FALLBACK: u32 = 3;
    MAX_DEPTH_FALLBACK
}
218
219#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct TruncateStringsConfig {
221 pub enabled: bool,
222 #[serde(default = "default_max_length")]
223 pub max_length: u32,
224}
225
/// Serde fallback for [`TruncateStringsConfig::max_length`], used when the
/// input omits that key.
fn default_max_length() -> u32 {
    const MAX_LENGTH_FALLBACK: u32 = 500;
    MAX_LENGTH_FALLBACK
}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct CollapseArraysConfig {
232 pub enabled: bool,
233 #[serde(default = "default_max_items")]
234 pub max_items: u32,
235 #[serde(default)]
236 pub summary_template: String,
237}
238
/// Serde fallback for [`CollapseArraysConfig::max_items`], used when the input
/// omits that key.
fn default_max_items() -> u32 {
    const MAX_ITEMS_FALLBACK: u32 = 5;
    MAX_ITEMS_FALLBACK
}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct CustomTransformsConfig {
245 pub enabled: bool,
246}
247
248#[derive(Debug, Clone, Serialize, Deserialize)]
251pub struct ToolSelectionConfig {
252 #[serde(default = "default_max_tools")]
253 pub max_tools: usize,
254 #[serde(default = "default_similarity_threshold")]
255 pub similarity_threshold: f64,
256 #[serde(default)]
257 pub default_tools: Vec<String>,
258}
259
/// Serde fallback for [`ToolSelectionConfig::max_tools`], used when the input
/// omits that key.
fn default_max_tools() -> usize {
    const MAX_TOOLS_FALLBACK: usize = 5;
    MAX_TOOLS_FALLBACK
}
263
/// Serde fallback for [`ToolSelectionConfig::similarity_threshold`], used when
/// the input omits that key.
fn default_similarity_threshold() -> f64 {
    const SIMILARITY_THRESHOLD_FALLBACK: f64 = 0.7;
    SIMILARITY_THRESHOLD_FALLBACK
}
267
268#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct BudgetConfig {
272 #[serde(default = "default_warning_threshold")]
273 pub warning_threshold: f64,
274 #[serde(default = "default_ceiling_threshold")]
275 pub ceiling_threshold: f64,
276 #[serde(default = "default_window_size")]
277 pub default_window_size: u32,
278 #[serde(default)]
279 pub agents: HashMap<String, f64>,
280}
281
/// Serde fallback for [`BudgetConfig::warning_threshold`], used when the input
/// omits that key.
fn default_warning_threshold() -> f64 {
    const WARNING_THRESHOLD_FALLBACK: f64 = 0.70;
    WARNING_THRESHOLD_FALLBACK
}
285
/// Serde fallback for [`BudgetConfig::ceiling_threshold`], used when the input
/// omits that key.
fn default_ceiling_threshold() -> f64 {
    const CEILING_THRESHOLD_FALLBACK: f64 = 0.85;
    CEILING_THRESHOLD_FALLBACK
}
289
/// Serde fallback for [`BudgetConfig::default_window_size`], used when the
/// input omits that key.
fn default_window_size() -> u32 {
    const WINDOW_SIZE_FALLBACK: u32 = 200_000;
    WINDOW_SIZE_FALLBACK
}
293
294#[derive(Debug, Clone, Serialize, Deserialize)]
297pub struct TerseModeConfig {
298 pub enabled: bool,
299 #[serde(default = "default_terse_level")]
300 pub level: TerseLevel,
301}
302
303#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
304#[serde(rename_all = "lowercase")]
305pub enum TerseLevel {
306 Minimal,
307 Moderate,
308 Verbose,
309}
310
311fn default_terse_level() -> TerseLevel {
312 TerseLevel::Moderate
313}
314
315#[derive(Debug, Clone, Serialize, Deserialize)]
318pub struct ModelConfig {
319 pub family: String,
320 #[serde(default)]
321 pub primary: String,
322 #[serde(default)]
323 pub local: String,
324 #[serde(default = "default_complexity_threshold")]
325 pub complexity_threshold: f64,
326 pub pricing: Option<ModelPricingConfig>,
327}
328
/// Serde fallback for [`ModelConfig::complexity_threshold`], used when the
/// input omits that key.
fn default_complexity_threshold() -> f64 {
    const COMPLEXITY_THRESHOLD_FALLBACK: f64 = 0.4;
    COMPLEXITY_THRESHOLD_FALLBACK
}
332
333#[derive(Debug, Clone, Serialize, Deserialize)]
334pub struct ModelPricingConfig {
335 pub input_per_1k: f64,
336 pub output_per_1k: f64,
337 #[serde(default)]
338 pub cache_read_discount: f64,
339}
340
341impl Default for Preset {
342 fn default() -> Self {
343 Preset {
344 preset: PresetMeta {
345 name: "default".to_string(),
346 version: "1.0".to_string(),
347 description: "Default compression preset for general development".to_string(),
348 },
349 compression: CompressionConfig {
350 stages: vec![
351 "keep_fields".to_string(),
352 "strip_fields".to_string(),
353 "condense".to_string(),
354 "strip_nulls".to_string(),
355 "flatten".to_string(),
356 "truncate_strings".to_string(),
357 "collapse_arrays".to_string(),
358 "custom_transforms".to_string(),
359 ],
360 keep_fields: Some(KeepFieldsConfig {
361 enabled: false,
362 fields: vec![
363 "id".to_string(),
364 "name".to_string(),
365 "type".to_string(),
366 "status".to_string(),
367 "error".to_string(),
368 "message".to_string(),
369 ],
370 }),
371 strip_fields: Some(StripFieldsConfig {
372 enabled: true,
373 fields: vec![
374 "metadata.internal_id".to_string(),
375 "debug_info".to_string(),
376 "trace_id".to_string(),
377 ],
378 }),
379 condense: Some(CondenseConfig {
380 enabled: true,
381 max_repeated_lines: 3,
382 }),
383 git_diff_fold: Some(GitDiffFoldConfig {
384 enabled: true,
385 max_context_lines: 2,
386 }),
387 strip_nulls: Some(StripNullsConfig { enabled: true }),
388 flatten: Some(FlattenConfig {
389 enabled: true,
390 max_depth: 3,
391 }),
392 truncate_strings: Some(TruncateStringsConfig {
393 enabled: true,
394 max_length: 500,
395 }),
396 collapse_arrays: Some(CollapseArraysConfig {
397 enabled: true,
398 max_items: 5,
399 summary_template: "... and {remaining} more items".to_string(),
400 }),
401 custom_transforms: Some(CustomTransformsConfig { enabled: true }),
402 },
403 tool_selection: ToolSelectionConfig {
404 max_tools: 5,
405 similarity_threshold: 0.7,
406 default_tools: vec![
407 "read_file".to_string(),
408 "write_file".to_string(),
409 "search".to_string(),
410 ],
411 },
412 budget: BudgetConfig {
413 warning_threshold: 0.70,
414 ceiling_threshold: 0.85,
415 default_window_size: 200_000,
416 agents: {
417 let mut m = HashMap::new();
418 m.insert("parent".to_string(), 0.60);
419 m.insert("child".to_string(), 0.20);
420 m
421 },
422 },
423 terse_mode: TerseModeConfig {
424 enabled: true,
425 level: TerseLevel::Moderate,
426 },
427 model: ModelConfig {
428 family: "anthropic".to_string(),
429 primary: "claude-sonnet-4-20250514".to_string(),
430 local: "llama-3.1-8b".to_string(),
431 complexity_threshold: 0.4,
432 pricing: Some(ModelPricingConfig {
433 input_per_1k: 0.003,
434 output_per_1k: 0.015,
435 cache_read_discount: 0.9,
436 }),
437 },
438 }
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Warning threshold, in ten-thousandths, pinned while the ceiling
    /// threshold is being varied (0.70, the same value the default preset uses).
    const PINNED_WARNING_RAW: u32 = 7_000;

    /// Converts an integer count of ten-thousandths into a fraction.
    fn ten_thousandths(raw: u32) -> f64 {
        f64::from(raw) / 10_000.0
    }

    /// Describes what went wrong when `preset` is NOT rejected with an error
    /// whose rendered text names `field`; `None` means the expectation held.
    fn rejection_problem(preset: &Preset, field: &str) -> Option<String> {
        match PresetParser::validate(preset) {
            Ok(()) => Some(format!("expected a validation error for '{}'", field)),
            Err(err) => {
                let rendered = err.to_string();
                if rendered.contains(field) {
                    None
                } else {
                    Some(format!(
                        "error message '{}' does not mention '{}'",
                        rendered, field
                    ))
                }
            }
        }
    }

    // ---- strategies producing valid values -------------------------------

    /// Non-empty strings of 1 to 32 characters drawn from `[a-zA-Z0-9_.-]`.
    fn arb_nonempty_string() -> impl Strategy<Value = String> {
        "[a-zA-Z0-9_\\-\\.]{1,32}"
    }

    /// Values strictly between 0.0 and 1.0, on a 1/10000 grid.
    fn arb_open_unit() -> impl Strategy<Value = f64> {
        (1u32..=9999u32).prop_map(ten_thousandths)
    }

    /// Budgets whose ceiling is strictly greater than the warning threshold,
    /// both strictly inside (0, 1).
    fn arb_budget_config() -> impl Strategy<Value = BudgetConfig> {
        (1u32..=8999u32).prop_flat_map(|warning_raw| {
            ((warning_raw + 1)..=9999u32).prop_map(move |ceiling_raw| BudgetConfig {
                warning_threshold: ten_thousandths(warning_raw),
                ceiling_threshold: ten_thousandths(ceiling_raw),
                default_window_size: 200_000,
                agents: Default::default(),
            })
        })
    }

    fn arb_tool_selection_config() -> impl Strategy<Value = ToolSelectionConfig> {
        (1usize..=50usize, arb_open_unit()).prop_map(|(max_tools, similarity_threshold)| {
            ToolSelectionConfig {
                max_tools,
                similarity_threshold,
                default_tools: vec![],
            }
        })
    }

    fn arb_model_config() -> impl Strategy<Value = ModelConfig> {
        (arb_nonempty_string(), arb_open_unit()).prop_map(|(family, complexity_threshold)| {
            ModelConfig {
                family,
                primary: String::new(),
                local: String::new(),
                complexity_threshold,
                pricing: None,
            }
        })
    }

    /// Presets that satisfy every rule in `PresetParser::validate`.
    fn arb_preset() -> impl Strategy<Value = Preset> {
        (
            arb_nonempty_string(),
            arb_nonempty_string(),
            arb_budget_config(),
            arb_tool_selection_config(),
            arb_model_config(),
        )
            .prop_map(|(name, version, budget, tool_selection, model)| Preset {
                preset: PresetMeta {
                    name,
                    version,
                    description: String::new(),
                },
                compression: CompressionConfig {
                    stages: vec![],
                    keep_fields: None,
                    strip_fields: None,
                    condense: None,
                    git_diff_fold: None,
                    strip_nulls: None,
                    flatten: None,
                    truncate_strings: None,
                    collapse_arrays: None,
                    custom_transforms: None,
                },
                tool_selection,
                budget,
                terse_mode: TerseModeConfig {
                    enabled: false,
                    level: TerseLevel::Moderate,
                },
                model,
            })
    }

    // ---- strategies producing invalid values -----------------------------

    /// Values that are NOT strictly between 0.0 and 1.0: both endpoints, NaN,
    /// negatives in [-1, 0), and values in (1, 2].
    fn arb_outside_open_unit() -> impl Strategy<Value = f64> {
        prop_oneof![
            Just(0.0_f64),
            Just(1.0_f64),
            Just(f64::NAN),
            (1u32..=10000u32).prop_map(|n| -ten_thousandths(n)),
            (10001u32..=20000u32).prop_map(ten_thousandths),
        ]
    }

    fn arb_invalid_warning_threshold() -> impl Strategy<Value = f64> {
        arb_outside_open_unit()
    }

    /// Ceiling values that are invalid while the warning threshold is pinned
    /// at `PINNED_WARNING_RAW`: anything outside (0, 1), plus in-range values
    /// that are not greater than the warning threshold.
    fn arb_invalid_ceiling_threshold() -> impl Strategy<Value = f64> {
        prop_oneof![
            arb_outside_open_unit(),
            (1u32..=PINNED_WARNING_RAW).prop_map(ten_thousandths),
        ]
    }

    fn arb_invalid_max_tools() -> impl Strategy<Value = usize> {
        prop_oneof![Just(0usize), (51usize..=200usize),]
    }

    fn arb_invalid_similarity_threshold() -> impl Strategy<Value = f64> {
        arb_outside_open_unit()
    }

    fn arb_invalid_complexity_threshold() -> impl Strategy<Value = f64> {
        arb_outside_open_unit()
    }

    // ---- round trip -------------------------------------------------------

    proptest! {
        /// Serializing, parsing, and serializing again yields identical TOML.
        #[test]
        fn prop_preset_toml_round_trip(preset in arb_preset()) {
            let toml1 = PresetParser::to_toml(&preset)
                .expect("to_toml should not fail on a valid preset");

            let parsed = PresetParser::parse(&toml1)
                .expect("parse should not fail on a valid TOML string");

            let toml2 = PresetParser::to_toml(&parsed)
                .expect("to_toml should not fail on re-parsed preset");

            prop_assert_eq!(
                &toml1,
                &toml2,
                "TOML round-trip mismatch:\nfirst: {}\nsecond: {}",
                toml1,
                toml2
            );
        }
    }

    // ---- rejection: fixed inputs -----------------------------------------

    #[test]
    fn empty_preset_name_error_mentions_field() {
        let mut preset = Preset::default();
        preset.preset.name = String::new();

        assert_eq!(rejection_problem(&preset, "preset.name"), None);
    }

    #[test]
    fn empty_preset_version_error_mentions_field() {
        let mut preset = Preset::default();
        preset.preset.version = String::new();

        assert_eq!(rejection_problem(&preset, "preset.version"), None);
    }

    // ---- rejection: generated inputs --------------------------------------

    proptest! {
        #[test]
        fn prop_invalid_warning_threshold_error_mentions_field(
            invalid_wt in arb_invalid_warning_threshold()
        ) {
            let mut preset = Preset::default();
            preset.budget.warning_threshold = invalid_wt;
            preset.budget.ceiling_threshold = 0.85;

            prop_assert_eq!(
                rejection_problem(&preset, "budget.warning_threshold"),
                None,
                "warning_threshold={}",
                invalid_wt
            );
        }

        #[test]
        fn prop_invalid_ceiling_threshold_error_mentions_field(
            invalid_ct in arb_invalid_ceiling_threshold()
        ) {
            let mut preset = Preset::default();
            preset.budget.warning_threshold = ten_thousandths(PINNED_WARNING_RAW);
            preset.budget.ceiling_threshold = invalid_ct;

            prop_assert_eq!(
                rejection_problem(&preset, "budget.ceiling_threshold"),
                None,
                "ceiling_threshold={}",
                invalid_ct
            );
        }

        #[test]
        fn prop_invalid_max_tools_error_mentions_field(
            invalid_mt in arb_invalid_max_tools()
        ) {
            let mut preset = Preset::default();
            preset.tool_selection.max_tools = invalid_mt;

            prop_assert_eq!(
                rejection_problem(&preset, "tool_selection.max_tools"),
                None,
                "max_tools={}",
                invalid_mt
            );
        }

        #[test]
        fn prop_invalid_similarity_threshold_error_mentions_field(
            invalid_st in arb_invalid_similarity_threshold()
        ) {
            let mut preset = Preset::default();
            preset.tool_selection.similarity_threshold = invalid_st;

            prop_assert_eq!(
                rejection_problem(&preset, "tool_selection.similarity_threshold"),
                None,
                "similarity_threshold={}",
                invalid_st
            );
        }

        #[test]
        fn prop_invalid_complexity_threshold_error_mentions_field(
            invalid_cxt in arb_invalid_complexity_threshold()
        ) {
            let mut preset = Preset::default();
            preset.model.complexity_threshold = invalid_cxt;

            prop_assert_eq!(
                rejection_problem(&preset, "model.complexity_threshold"),
                None,
                "complexity_threshold={}",
                invalid_cxt
            );
        }
    }
}