1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use crate::error::{Result, SqzError};
5
6pub struct PresetParser;
8
9impl PresetParser {
10 pub fn parse(toml_str: &str) -> Result<Preset> {
12 let preset: Preset = toml::from_str(toml_str)?;
13 Self::validate(&preset)?;
14 Ok(preset)
15 }
16
17 pub fn to_toml(preset: &Preset) -> Result<String> {
19 Ok(toml::to_string_pretty(preset)?)
20 }
21
22 pub fn validate(preset: &Preset) -> Result<()> {
24 if preset.preset.name.is_empty() {
25 return Err(SqzError::PresetValidation {
26 field: "preset.name".to_string(),
27 message: "must not be empty".to_string(),
28 });
29 }
30
31 if preset.preset.version.is_empty() {
32 return Err(SqzError::PresetValidation {
33 field: "preset.version".to_string(),
34 message: "must not be empty".to_string(),
35 });
36 }
37
38 let wt = preset.budget.warning_threshold;
39 if !(wt > 0.0 && wt < 1.0) {
40 return Err(SqzError::PresetValidation {
41 field: "budget.warning_threshold".to_string(),
42 message: "must be between 0.0 and 1.0".to_string(),
43 });
44 }
45
46 let ct = preset.budget.ceiling_threshold;
47 if !(ct > 0.0 && ct < 1.0) || ct <= wt {
48 return Err(SqzError::PresetValidation {
49 field: "budget.ceiling_threshold".to_string(),
50 message: "must be between 0.0 and 1.0 and greater than warning_threshold"
51 .to_string(),
52 });
53 }
54
55 let max_tools = preset.tool_selection.max_tools;
56 if !(1..=50).contains(&max_tools) {
57 return Err(SqzError::PresetValidation {
58 field: "tool_selection.max_tools".to_string(),
59 message: "must be between 1 and 50".to_string(),
60 });
61 }
62
63 let st = preset.tool_selection.similarity_threshold;
64 if !(st > 0.0 && st < 1.0) {
65 return Err(SqzError::PresetValidation {
66 field: "tool_selection.similarity_threshold".to_string(),
67 message: "must be between 0.0 and 1.0".to_string(),
68 });
69 }
70
71 let cxt = preset.model.complexity_threshold;
72 if !(cxt > 0.0 && cxt < 1.0) {
73 return Err(SqzError::PresetValidation {
74 field: "model.complexity_threshold".to_string(),
75 message: "must be between 0.0 and 1.0".to_string(),
76 });
77 }
78
79 Ok(())
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct Preset {
85 pub preset: PresetHeader,
86 pub compression: CompressionConfig,
87 pub tool_selection: ToolSelectionConfig,
88 pub budget: BudgetConfig,
89 pub terse_mode: TerseModeConfig,
90 pub model: ModelConfig,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct PresetHeader {
96 pub name: String,
98 pub version: String,
100 #[serde(default)]
102 pub description: String,
103}
104
105pub type PresetMeta = PresetHeader;
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct CompressionConfig {
112 #[serde(default)]
113 pub stages: Vec<String>,
114 pub keep_fields: Option<KeepFieldsConfig>,
115 pub strip_fields: Option<StripFieldsConfig>,
116 pub condense: Option<CondenseConfig>,
117 pub git_diff_fold: Option<GitDiffFoldConfig>,
118 pub strip_nulls: Option<StripNullsConfig>,
119 pub flatten: Option<FlattenConfig>,
120 pub truncate_strings: Option<TruncateStringsConfig>,
121 pub collapse_arrays: Option<CollapseArraysConfig>,
122 pub custom_transforms: Option<CustomTransformsConfig>,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct GitDiffFoldConfig {
127 pub enabled: bool,
128 #[serde(default = "default_max_context_lines")]
129 pub max_context_lines: u32,
130}
131
132fn default_max_context_lines() -> u32 {
133 2
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct KeepFieldsConfig {
138 pub enabled: bool,
139 #[serde(default)]
140 pub fields: Vec<String>,
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct StripFieldsConfig {
145 pub enabled: bool,
146 #[serde(default)]
147 pub fields: Vec<String>,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct CondenseConfig {
152 pub enabled: bool,
153 #[serde(default = "default_max_repeated_lines")]
154 pub max_repeated_lines: u32,
155}
156
157fn default_max_repeated_lines() -> u32 {
158 3
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct StripNullsConfig {
163 pub enabled: bool,
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct FlattenConfig {
168 pub enabled: bool,
169 #[serde(default = "default_max_depth")]
170 pub max_depth: u32,
171}
172
173fn default_max_depth() -> u32 {
174 3
175}
176
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct TruncateStringsConfig {
179 pub enabled: bool,
180 #[serde(default = "default_max_length")]
181 pub max_length: u32,
182}
183
184fn default_max_length() -> u32 {
185 500
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct CollapseArraysConfig {
190 pub enabled: bool,
191 #[serde(default = "default_max_items")]
192 pub max_items: u32,
193 #[serde(default)]
194 pub summary_template: String,
195}
196
197fn default_max_items() -> u32 {
198 5
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct CustomTransformsConfig {
203 pub enabled: bool,
204}
205
206#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct ToolSelectionConfig {
210 #[serde(default = "default_max_tools")]
211 pub max_tools: usize,
212 #[serde(default = "default_similarity_threshold")]
213 pub similarity_threshold: f64,
214 #[serde(default)]
215 pub default_tools: Vec<String>,
216}
217
218fn default_max_tools() -> usize {
219 5
220}
221
222fn default_similarity_threshold() -> f64 {
223 0.7
224}
225
226#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct BudgetConfig {
230 #[serde(default = "default_warning_threshold")]
231 pub warning_threshold: f64,
232 #[serde(default = "default_ceiling_threshold")]
233 pub ceiling_threshold: f64,
234 #[serde(default = "default_window_size")]
235 pub default_window_size: u32,
236 #[serde(default)]
237 pub agents: HashMap<String, f64>,
238}
239
240fn default_warning_threshold() -> f64 {
241 0.70
242}
243
244fn default_ceiling_threshold() -> f64 {
245 0.85
246}
247
248fn default_window_size() -> u32 {
249 200_000
250}
251
252#[derive(Debug, Clone, Serialize, Deserialize)]
255pub struct TerseModeConfig {
256 pub enabled: bool,
257 #[serde(default = "default_terse_level")]
258 pub level: TerseLevel,
259}
260
261#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
262#[serde(rename_all = "lowercase")]
263pub enum TerseLevel {
264 Minimal,
265 Moderate,
266 Verbose,
267}
268
269fn default_terse_level() -> TerseLevel {
270 TerseLevel::Moderate
271}
272
273#[derive(Debug, Clone, Serialize, Deserialize)]
276pub struct ModelConfig {
277 pub family: String,
278 #[serde(default)]
279 pub primary: String,
280 #[serde(default)]
281 pub local: String,
282 #[serde(default = "default_complexity_threshold")]
283 pub complexity_threshold: f64,
284 pub pricing: Option<ModelPricingConfig>,
285}
286
287fn default_complexity_threshold() -> f64 {
288 0.4
289}
290
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct ModelPricingConfig {
293 pub input_per_1k: f64,
294 pub output_per_1k: f64,
295 #[serde(default)]
296 pub cache_read_discount: f64,
297}
298
299impl Default for Preset {
300 fn default() -> Self {
301 Preset {
302 preset: PresetMeta {
303 name: "default".to_string(),
304 version: "1.0".to_string(),
305 description: "Default compression preset for general development".to_string(),
306 },
307 compression: CompressionConfig {
308 stages: vec![
309 "keep_fields".to_string(),
310 "strip_fields".to_string(),
311 "condense".to_string(),
312 "strip_nulls".to_string(),
313 "flatten".to_string(),
314 "truncate_strings".to_string(),
315 "collapse_arrays".to_string(),
316 "custom_transforms".to_string(),
317 ],
318 keep_fields: Some(KeepFieldsConfig {
319 enabled: false,
320 fields: vec![
321 "id".to_string(),
322 "name".to_string(),
323 "type".to_string(),
324 "status".to_string(),
325 "error".to_string(),
326 "message".to_string(),
327 ],
328 }),
329 strip_fields: Some(StripFieldsConfig {
330 enabled: true,
331 fields: vec![
332 "metadata.internal_id".to_string(),
333 "debug_info".to_string(),
334 "trace_id".to_string(),
335 ],
336 }),
337 condense: Some(CondenseConfig {
338 enabled: true,
339 max_repeated_lines: 3,
340 }),
341 git_diff_fold: Some(GitDiffFoldConfig {
342 enabled: true,
343 max_context_lines: 2,
344 }),
345 strip_nulls: Some(StripNullsConfig { enabled: true }),
346 flatten: Some(FlattenConfig {
347 enabled: true,
348 max_depth: 3,
349 }),
350 truncate_strings: Some(TruncateStringsConfig {
351 enabled: true,
352 max_length: 500,
353 }),
354 collapse_arrays: Some(CollapseArraysConfig {
355 enabled: true,
356 max_items: 5,
357 summary_template: "... and {remaining} more items".to_string(),
358 }),
359 custom_transforms: Some(CustomTransformsConfig { enabled: true }),
360 },
361 tool_selection: ToolSelectionConfig {
362 max_tools: 5,
363 similarity_threshold: 0.7,
364 default_tools: vec![
365 "read_file".to_string(),
366 "write_file".to_string(),
367 "search".to_string(),
368 ],
369 },
370 budget: BudgetConfig {
371 warning_threshold: 0.70,
372 ceiling_threshold: 0.85,
373 default_window_size: 200_000,
374 agents: {
375 let mut m = HashMap::new();
376 m.insert("parent".to_string(), 0.60);
377 m.insert("child".to_string(), 0.20);
378 m
379 },
380 },
381 terse_mode: TerseModeConfig {
382 enabled: true,
383 level: TerseLevel::Moderate,
384 },
385 model: ModelConfig {
386 family: "anthropic".to_string(),
387 primary: "claude-sonnet-4-20250514".to_string(),
388 local: "llama-3.1-8b".to_string(),
389 complexity_threshold: 0.4,
390 pricing: Some(ModelPricingConfig {
391 input_per_1k: 0.003,
392 output_per_1k: 0.015,
393 cache_read_discount: 0.9,
394 }),
395 },
396 }
397 }
398}
399
400#[cfg(test)]
405mod tests {
406 use super::*;
407 use proptest::prelude::*;
408
409 fn arb_nonempty_string() -> impl Strategy<Value = String> {
415 "[a-zA-Z0-9_\\-\\.]{1,32}".prop_map(|s| s)
416 }
417
418 fn arb_open_unit() -> impl Strategy<Value = f64> {
420 (1u32..=9999u32).prop_map(|n| n as f64 / 10_000.0)
421 }
422
423 fn arb_budget_config() -> impl Strategy<Value = BudgetConfig> {
425 (1u32..=8999u32).prop_flat_map(|w_raw| {
427 let warning = w_raw as f64 / 10_000.0; let c_min = (w_raw + 1) as f64 / 10_000.0;
430 let c_max = 9999.0_f64 / 10_000.0;
431 let c_min_int = w_raw + 1;
433 (c_min_int..=9999u32).prop_map(move |c_raw| {
434 let ceiling = c_raw as f64 / 10_000.0;
435 let _ = (c_min, c_max); BudgetConfig {
437 warning_threshold: warning,
438 ceiling_threshold: ceiling,
439 default_window_size: 200_000,
440 agents: Default::default(),
441 }
442 })
443 })
444 }
445
446 fn arb_tool_selection_config() -> impl Strategy<Value = ToolSelectionConfig> {
448 (1usize..=50usize, arb_open_unit()).prop_map(|(max_tools, similarity_threshold)| {
449 ToolSelectionConfig {
450 max_tools,
451 similarity_threshold,
452 default_tools: vec![],
453 }
454 })
455 }
456
457 fn arb_model_config() -> impl Strategy<Value = ModelConfig> {
459 (arb_nonempty_string(), arb_open_unit()).prop_map(|(family, complexity_threshold)| {
460 ModelConfig {
461 family,
462 primary: String::new(),
463 local: String::new(),
464 complexity_threshold,
465 pricing: None,
466 }
467 })
468 }
469
470 fn arb_preset() -> impl Strategy<Value = Preset> {
472 (
473 arb_nonempty_string(), arb_nonempty_string(), arb_budget_config(),
476 arb_tool_selection_config(),
477 arb_model_config(),
478 )
479 .prop_map(|(name, version, budget, tool_selection, model)| Preset {
480 preset: PresetMeta {
481 name,
482 version,
483 description: String::new(),
484 },
485 compression: CompressionConfig {
486 stages: vec![],
487 keep_fields: None,
488 strip_fields: None,
489 condense: None,
490 git_diff_fold: None,
491 strip_nulls: None,
492 flatten: None,
493 truncate_strings: None,
494 collapse_arrays: None,
495 custom_transforms: None,
496 },
497 tool_selection,
498 budget,
499 terse_mode: TerseModeConfig {
500 enabled: false,
501 level: TerseLevel::Moderate,
502 },
503 model,
504 })
505 }
506
507 proptest! {
513 #[test]
525 fn prop_preset_toml_round_trip(preset in arb_preset()) {
526 let toml1 = PresetParser::to_toml(&preset)
528 .expect("to_toml should not fail on a valid preset");
529
530 let parsed = PresetParser::parse(&toml1)
532 .expect("parse should not fail on a valid TOML string");
533
534 let toml2 = PresetParser::to_toml(&parsed)
536 .expect("to_toml should not fail on re-parsed preset");
537
538 prop_assert_eq!(
540 &toml1,
541 &toml2,
542 "TOML round-trip mismatch:\nfirst: {}\nsecond: {}",
543 toml1,
544 toml2
545 );
546 }
547 }
548
549 fn arb_invalid_warning_threshold() -> impl Strategy<Value = f64> {
556 prop_oneof![
557 Just(0.0_f64),
558 Just(1.0_f64),
559 (1u32..=10000u32).prop_map(|n| -(n as f64 / 10_000.0)),
561 (10001u32..=20000u32).prop_map(|n| n as f64 / 10_000.0),
563 ]
564 }
565
566 fn arb_invalid_ceiling_threshold() -> impl Strategy<Value = f64> {
568 prop_oneof![
569 Just(0.0_f64),
570 Just(1.0_f64),
571 (1u32..=10000u32).prop_map(|n| -(n as f64 / 10_000.0)),
572 (10001u32..=20000u32).prop_map(|n| n as f64 / 10_000.0),
573 ]
574 }
575
576 fn arb_invalid_max_tools() -> impl Strategy<Value = usize> {
578 prop_oneof![
579 Just(0usize),
580 (51usize..=200usize),
581 ]
582 }
583
584 fn arb_invalid_complexity_threshold() -> impl Strategy<Value = f64> {
586 prop_oneof![
587 Just(0.0_f64),
588 Just(1.0_f64),
589 (1u32..=10000u32).prop_map(|n| -(n as f64 / 10_000.0)),
590 (10001u32..=20000u32).prop_map(|n| n as f64 / 10_000.0),
591 ]
592 }
593
594 proptest! {
595 #[test]
600 fn prop_invalid_warning_threshold_error_mentions_field(
601 invalid_wt in arb_invalid_warning_threshold()
602 ) {
603 let mut preset = Preset::default();
604 preset.budget.warning_threshold = invalid_wt;
605 preset.budget.ceiling_threshold = 0.85;
611
612 let result = PresetParser::validate(&preset);
613 prop_assert!(result.is_err(), "expected validation error for warning_threshold={}", invalid_wt);
614 let err_msg = result.unwrap_err().to_string();
615 prop_assert!(
616 err_msg.contains("budget.warning_threshold"),
617 "error message '{}' does not mention 'budget.warning_threshold'",
618 err_msg
619 );
620 }
621
622 #[test]
627 fn prop_invalid_ceiling_threshold_error_mentions_field(
628 invalid_ct in arb_invalid_ceiling_threshold()
629 ) {
630 let mut preset = Preset::default();
631 preset.budget.warning_threshold = 0.70;
633 preset.budget.ceiling_threshold = invalid_ct;
634
635 let result = PresetParser::validate(&preset);
636 prop_assert!(result.is_err(), "expected validation error for ceiling_threshold={}", invalid_ct);
637 let err_msg = result.unwrap_err().to_string();
638 prop_assert!(
639 err_msg.contains("budget.ceiling_threshold"),
640 "error message '{}' does not mention 'budget.ceiling_threshold'",
641 err_msg
642 );
643 }
644
645 #[test]
650 fn prop_empty_preset_name_error_mentions_field(_dummy in 0u32..1u32) {
651 let mut preset = Preset::default();
652 preset.preset.name = String::new();
653
654 let result = PresetParser::validate(&preset);
655 prop_assert!(result.is_err(), "expected validation error for empty preset.name");
656 let err_msg = result.unwrap_err().to_string();
657 prop_assert!(
658 err_msg.contains("preset.name"),
659 "error message '{}' does not mention 'preset.name'",
660 err_msg
661 );
662 }
663
664 #[test]
669 fn prop_invalid_max_tools_error_mentions_field(
670 invalid_mt in arb_invalid_max_tools()
671 ) {
672 let mut preset = Preset::default();
673 preset.tool_selection.max_tools = invalid_mt;
674
675 let result = PresetParser::validate(&preset);
676 prop_assert!(result.is_err(), "expected validation error for max_tools={}", invalid_mt);
677 let err_msg = result.unwrap_err().to_string();
678 prop_assert!(
679 err_msg.contains("tool_selection.max_tools"),
680 "error message '{}' does not mention 'tool_selection.max_tools'",
681 err_msg
682 );
683 }
684
685 #[test]
690 fn prop_invalid_complexity_threshold_error_mentions_field(
691 invalid_cxt in arb_invalid_complexity_threshold()
692 ) {
693 let mut preset = Preset::default();
694 preset.model.complexity_threshold = invalid_cxt;
695
696 let result = PresetParser::validate(&preset);
697 prop_assert!(result.is_err(), "expected validation error for complexity_threshold={}", invalid_cxt);
698 let err_msg = result.unwrap_err().to_string();
699 prop_assert!(
700 err_msg.contains("model.complexity_threshold"),
701 "error message '{}' does not mention 'model.complexity_threshold'",
702 err_msg
703 );
704 }
705 }
706}