1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use crate::error::{Result, SqzError};
5
/// Stateless entry point for decoding, encoding, and validating [`Preset`]s as TOML.
///
/// All functionality is exposed as associated functions; the type carries no state.
pub struct PresetParser;
8
9impl PresetParser {
10 pub fn parse(toml_str: &str) -> Result<Preset> {
12 let preset: Preset = toml::from_str(toml_str)?;
13 Self::validate(&preset)?;
14 Ok(preset)
15 }
16
17 pub fn to_toml(preset: &Preset) -> Result<String> {
19 Ok(toml::to_string_pretty(preset)?)
20 }
21
22 pub fn validate(preset: &Preset) -> Result<()> {
24 if preset.preset.name.is_empty() {
25 return Err(SqzError::PresetValidation {
26 field: "preset.name".to_string(),
27 message: "must not be empty".to_string(),
28 });
29 }
30
31 if preset.preset.version.is_empty() {
32 return Err(SqzError::PresetValidation {
33 field: "preset.version".to_string(),
34 message: "must not be empty".to_string(),
35 });
36 }
37
38 let wt = preset.budget.warning_threshold;
39 if !(wt > 0.0 && wt < 1.0) {
40 return Err(SqzError::PresetValidation {
41 field: "budget.warning_threshold".to_string(),
42 message: "must be between 0.0 and 1.0".to_string(),
43 });
44 }
45
46 let ct = preset.budget.ceiling_threshold;
47 if !(ct > 0.0 && ct < 1.0) || ct <= wt {
48 return Err(SqzError::PresetValidation {
49 field: "budget.ceiling_threshold".to_string(),
50 message: "must be between 0.0 and 1.0 and greater than warning_threshold"
51 .to_string(),
52 });
53 }
54
55 let max_tools = preset.tool_selection.max_tools;
56 if !(1..=50).contains(&max_tools) {
57 return Err(SqzError::PresetValidation {
58 field: "tool_selection.max_tools".to_string(),
59 message: "must be between 1 and 50".to_string(),
60 });
61 }
62
63 let st = preset.tool_selection.similarity_threshold;
64 if !(st > 0.0 && st < 1.0) {
65 return Err(SqzError::PresetValidation {
66 field: "tool_selection.similarity_threshold".to_string(),
67 message: "must be between 0.0 and 1.0".to_string(),
68 });
69 }
70
71 let cxt = preset.model.complexity_threshold;
72 if !(cxt > 0.0 && cxt < 1.0) {
73 return Err(SqzError::PresetValidation {
74 field: "model.complexity_threshold".to_string(),
75 message: "must be between 0.0 and 1.0".to_string(),
76 });
77 }
78
79 Ok(())
80 }
81}
82
/// A complete compression preset.
///
/// Each field maps to a same-named TOML table (`[preset]`, `[compression]`,
/// `[tool_selection]`, `[budget]`, `[terse_mode]`, `[model]`). None of the six
/// has a serde default, so all are required when deserializing. Use
/// [`PresetParser::parse`] to decode and validate in one step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Preset {
    /// Identifying metadata: name, version, description.
    pub preset: PresetMeta,
    /// Stage list and per-stage compression settings.
    pub compression: CompressionConfig,
    /// Tool-selection settings.
    pub tool_selection: ToolSelectionConfig,
    /// Budget thresholds and per-agent values.
    pub budget: BudgetConfig,
    /// Terse-mode toggle and level.
    pub terse_mode: TerseModeConfig,
    /// Model identifiers, threshold and optional pricing.
    pub model: ModelConfig,
}
92
/// The `[preset]` table: identifying metadata for a preset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PresetMeta {
    /// Preset name (required); `PresetParser::validate` rejects an empty string.
    pub name: String,
    /// Preset version string (required); `PresetParser::validate` rejects an empty string.
    pub version: String,
    /// Free-form description; empty when omitted.
    #[serde(default)]
    pub description: String,
}
100
/// The `[compression]` table: the stage list plus optional per-stage settings.
///
/// Every per-stage sub-table is optional and deserializes to `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionConfig {
    /// Stage names; empty when omitted. `Preset::default` lists one name per
    /// per-stage field below, in field order.
    /// NOTE(review): presumably this is the pipeline execution order — confirm
    /// against the stage runner.
    #[serde(default)]
    pub stages: Vec<String>,
    /// Settings for the `keep_fields` stage.
    pub keep_fields: Option<KeepFieldsConfig>,
    /// Settings for the `strip_fields` stage.
    pub strip_fields: Option<StripFieldsConfig>,
    /// Settings for the `condense` stage.
    pub condense: Option<CondenseConfig>,
    /// Settings for the `strip_nulls` stage.
    pub strip_nulls: Option<StripNullsConfig>,
    /// Settings for the `flatten` stage.
    pub flatten: Option<FlattenConfig>,
    /// Settings for the `truncate_strings` stage.
    pub truncate_strings: Option<TruncateStringsConfig>,
    /// Settings for the `collapse_arrays` stage.
    pub collapse_arrays: Option<CollapseArraysConfig>,
    /// Settings for the `custom_transforms` stage.
    pub custom_transforms: Option<CustomTransformsConfig>,
}
116
/// Settings for the `keep_fields` compression stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeepFieldsConfig {
    /// Whether the stage is active (required). `Preset::default` sets this to `false`.
    pub enabled: bool,
    /// Field names for the stage; empty when omitted.
    /// NOTE(review): by the stage name, presumably the fields to retain — confirm.
    #[serde(default)]
    pub fields: Vec<String>,
}
123
/// Settings for the `strip_fields` compression stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StripFieldsConfig {
    /// Whether the stage is active (required).
    pub enabled: bool,
    /// Field names for the stage; empty when omitted. `Preset::default` includes
    /// the dotted entry `metadata.internal_id`.
    /// NOTE(review): presumably dotted entries address nested fields — confirm.
    #[serde(default)]
    pub fields: Vec<String>,
}
130
/// Settings for the `condense` compression stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CondenseConfig {
    /// Whether the stage is active (required).
    pub enabled: bool,
    /// Repetition limit used by the stage; 3 when omitted.
    #[serde(default = "default_max_repeated_lines")]
    pub max_repeated_lines: u32,
}

/// Serde default for [`CondenseConfig::max_repeated_lines`].
fn default_max_repeated_lines() -> u32 {
    3
}
141
/// Settings for the `strip_nulls` compression stage (only an on/off switch).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StripNullsConfig {
    /// Whether the stage is active (required).
    pub enabled: bool,
}
146
/// Settings for the `flatten` compression stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlattenConfig {
    /// Whether the stage is active (required).
    pub enabled: bool,
    /// Depth limit used by the stage; 3 when omitted.
    #[serde(default = "default_max_depth")]
    pub max_depth: u32,
}

/// Serde default for [`FlattenConfig::max_depth`].
fn default_max_depth() -> u32 {
    3
}
157
/// Settings for the `truncate_strings` compression stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TruncateStringsConfig {
    /// Whether the stage is active (required).
    pub enabled: bool,
    /// Length limit used by the stage; 500 when omitted.
    /// NOTE(review): whether this counts bytes or characters is decided by the
    /// stage implementation, not here — confirm.
    #[serde(default = "default_max_length")]
    pub max_length: u32,
}

/// Serde default for [`TruncateStringsConfig::max_length`].
fn default_max_length() -> u32 {
    500
}
168
/// Settings for the `collapse_arrays` compression stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollapseArraysConfig {
    /// Whether the stage is active (required).
    pub enabled: bool,
    /// Item limit used by the stage; 5 when omitted.
    #[serde(default = "default_max_items")]
    pub max_items: u32,
    /// Summary text template; empty when omitted (note that `Preset::default`
    /// uses `"... and {remaining} more items"` instead).
    /// NOTE(review): `{remaining}` is presumably substituted by the stage — confirm.
    #[serde(default)]
    pub summary_template: String,
}

/// Serde default for [`CollapseArraysConfig::max_items`].
fn default_max_items() -> u32 {
    5
}
181
/// Settings for the `custom_transforms` compression stage.
///
/// Only an on/off switch is configured here; no transform definitions live in
/// the preset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomTransformsConfig {
    /// Whether the stage is active (required).
    pub enabled: bool,
}
186
/// The `[tool_selection]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSelectionConfig {
    /// Tool-count limit; 5 when omitted. `PresetParser::validate` requires `1..=50`.
    #[serde(default = "default_max_tools")]
    pub max_tools: usize,
    /// Similarity threshold; 0.7 when omitted. `PresetParser::validate` requires
    /// it to lie strictly between 0.0 and 1.0.
    #[serde(default = "default_similarity_threshold")]
    pub similarity_threshold: f64,
    /// Tool names; empty when omitted.
    /// NOTE(review): presumably tools offered regardless of similarity — confirm.
    #[serde(default)]
    pub default_tools: Vec<String>,
}

/// Serde default for [`ToolSelectionConfig::max_tools`].
fn default_max_tools() -> usize {
    5
}

/// Serde default for [`ToolSelectionConfig::similarity_threshold`].
fn default_similarity_threshold() -> f64 {
    0.7
}
206
207#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct BudgetConfig {
211 #[serde(default = "default_warning_threshold")]
212 pub warning_threshold: f64,
213 #[serde(default = "default_ceiling_threshold")]
214 pub ceiling_threshold: f64,
215 #[serde(default = "default_window_size")]
216 pub default_window_size: u32,
217 #[serde(default)]
218 pub agents: HashMap<String, f64>,
219}
220
221fn default_warning_threshold() -> f64 {
222 0.70
223}
224
225fn default_ceiling_threshold() -> f64 {
226 0.85
227}
228
229fn default_window_size() -> u32 {
230 200_000
231}
232
/// The `[terse_mode]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TerseModeConfig {
    /// Whether terse mode is on (required).
    pub enabled: bool,
    /// Terseness level; `moderate` when omitted.
    #[serde(default = "default_terse_level")]
    pub level: TerseLevel,
}

/// Terse-mode level, written in TOML as `"minimal"`, `"moderate"` or `"verbose"`.
///
/// NOTE(review): the effect of each level is defined by whatever consumes this
/// config, not in this module.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TerseLevel {
    Minimal,
    Moderate,
    Verbose,
}

/// Serde default for [`TerseModeConfig::level`].
fn default_terse_level() -> TerseLevel {
    TerseLevel::Moderate
}
253
/// The `[model]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model family (required), e.g. `"anthropic"` in `Preset::default`.
    pub family: String,
    /// Primary model identifier; empty when omitted.
    #[serde(default)]
    pub primary: String,
    /// Local model identifier; empty when omitted.
    #[serde(default)]
    pub local: String,
    /// Complexity threshold; 0.4 when omitted. `PresetParser::validate` requires
    /// it to lie strictly between 0.0 and 1.0.
    /// NOTE(review): presumably decides between `primary` and `local` — confirm.
    #[serde(default = "default_complexity_threshold")]
    pub complexity_threshold: f64,
    /// Optional pricing table; `None` when absent.
    pub pricing: Option<ModelPricingConfig>,
}

/// Serde default for [`ModelConfig::complexity_threshold`].
fn default_complexity_threshold() -> f64 {
    0.4
}
271
/// Optional pricing settings nested under the `[model]` table.
///
/// None of these values are checked by `PresetParser::validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricingConfig {
    /// Input price per 1,000 units (required).
    /// NOTE(review): presumably per 1k tokens; currency is not specified here.
    pub input_per_1k: f64,
    /// Output price per 1,000 units (required).
    pub output_per_1k: f64,
    /// Cache-read discount; 0.0 when omitted (`Preset::default` uses 0.9).
    /// NOTE(review): presumably a fraction taken off the input price — confirm.
    #[serde(default)]
    pub cache_read_discount: f64,
}
279
280impl Default for Preset {
281 fn default() -> Self {
282 Preset {
283 preset: PresetMeta {
284 name: "default".to_string(),
285 version: "1.0".to_string(),
286 description: "Default compression preset for general development".to_string(),
287 },
288 compression: CompressionConfig {
289 stages: vec![
290 "keep_fields".to_string(),
291 "strip_fields".to_string(),
292 "condense".to_string(),
293 "strip_nulls".to_string(),
294 "flatten".to_string(),
295 "truncate_strings".to_string(),
296 "collapse_arrays".to_string(),
297 "custom_transforms".to_string(),
298 ],
299 keep_fields: Some(KeepFieldsConfig {
300 enabled: false,
301 fields: vec![
302 "id".to_string(),
303 "name".to_string(),
304 "type".to_string(),
305 "status".to_string(),
306 "error".to_string(),
307 "message".to_string(),
308 ],
309 }),
310 strip_fields: Some(StripFieldsConfig {
311 enabled: true,
312 fields: vec![
313 "metadata.internal_id".to_string(),
314 "debug_info".to_string(),
315 "trace_id".to_string(),
316 ],
317 }),
318 condense: Some(CondenseConfig {
319 enabled: true,
320 max_repeated_lines: 3,
321 }),
322 strip_nulls: Some(StripNullsConfig { enabled: true }),
323 flatten: Some(FlattenConfig {
324 enabled: true,
325 max_depth: 3,
326 }),
327 truncate_strings: Some(TruncateStringsConfig {
328 enabled: true,
329 max_length: 500,
330 }),
331 collapse_arrays: Some(CollapseArraysConfig {
332 enabled: true,
333 max_items: 5,
334 summary_template: "... and {remaining} more items".to_string(),
335 }),
336 custom_transforms: Some(CustomTransformsConfig { enabled: true }),
337 },
338 tool_selection: ToolSelectionConfig {
339 max_tools: 5,
340 similarity_threshold: 0.7,
341 default_tools: vec![
342 "read_file".to_string(),
343 "write_file".to_string(),
344 "search".to_string(),
345 ],
346 },
347 budget: BudgetConfig {
348 warning_threshold: 0.70,
349 ceiling_threshold: 0.85,
350 default_window_size: 200_000,
351 agents: {
352 let mut m = HashMap::new();
353 m.insert("parent".to_string(), 0.60);
354 m.insert("child".to_string(), 0.20);
355 m
356 },
357 },
358 terse_mode: TerseModeConfig {
359 enabled: true,
360 level: TerseLevel::Moderate,
361 },
362 model: ModelConfig {
363 family: "anthropic".to_string(),
364 primary: "claude-sonnet-4-20250514".to_string(),
365 local: "llama-3.1-8b".to_string(),
366 complexity_threshold: 0.4,
367 pricing: Some(ModelPricingConfig {
368 input_per_1k: 0.003,
369 output_per_1k: 0.015,
370 cache_read_discount: 0.9,
371 }),
372 },
373 }
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // ----- Strategies for valid presets -------------------------------------

    /// Non-empty (1–32 character) identifier-like strings for names, versions and families.
    fn arb_nonempty_string() -> impl Strategy<Value = String> {
        // proptest interprets a string-slice strategy as a regex for generated values.
        "[a-zA-Z0-9_\\-\\.]{1,32}"
    }

    /// Values strictly inside (0.0, 1.0), on a 1/10_000 grid.
    fn arb_open_unit() -> impl Strategy<Value = f64> {
        (1u32..=9999u32).prop_map(|n| f64::from(n) / 10_000.0)
    }

    /// Budgets with `0.0 < warning_threshold < ceiling_threshold < 1.0`.
    fn arb_budget_config() -> impl Strategy<Value = BudgetConfig> {
        // Capping the warning at 0.8999 always leaves room for a strictly larger ceiling.
        (1u32..=8999u32).prop_flat_map(|w_raw| {
            let warning_threshold = f64::from(w_raw) / 10_000.0;
            ((w_raw + 1)..=9999u32).prop_map(move |c_raw| BudgetConfig {
                warning_threshold,
                ceiling_threshold: f64::from(c_raw) / 10_000.0,
                default_window_size: 200_000,
                // Kept empty so the textual round-trip comparison does not depend
                // on the order in which map entries are serialized.
                agents: Default::default(),
            })
        })
    }

    /// Tool selection with `max_tools` in `1..=50` and a similarity threshold in (0.0, 1.0).
    fn arb_tool_selection_config() -> impl Strategy<Value = ToolSelectionConfig> {
        (1usize..=50usize, arb_open_unit()).prop_map(|(max_tools, similarity_threshold)| {
            ToolSelectionConfig {
                max_tools,
                similarity_threshold,
                default_tools: Vec::new(),
            }
        })
    }

    /// Model config with a non-empty family and a complexity threshold in (0.0, 1.0).
    fn arb_model_config() -> impl Strategy<Value = ModelConfig> {
        (arb_nonempty_string(), arb_open_unit()).prop_map(|(family, complexity_threshold)| {
            ModelConfig {
                family,
                primary: String::new(),
                local: String::new(),
                complexity_threshold,
                pricing: None,
            }
        })
    }

    /// Presets that satisfy every rule enforced by `PresetParser::validate`.
    fn arb_preset() -> impl Strategy<Value = Preset> {
        (
            arb_nonempty_string(),
            arb_nonempty_string(),
            arb_budget_config(),
            arb_tool_selection_config(),
            arb_model_config(),
        )
            .prop_map(|(name, version, budget, tool_selection, model)| Preset {
                preset: PresetMeta {
                    name,
                    version,
                    description: String::new(),
                },
                compression: CompressionConfig {
                    stages: Vec::new(),
                    keep_fields: None,
                    strip_fields: None,
                    condense: None,
                    strip_nulls: None,
                    flatten: None,
                    truncate_strings: None,
                    collapse_arrays: None,
                    custom_transforms: None,
                },
                tool_selection,
                budget,
                terse_mode: TerseModeConfig {
                    enabled: false,
                    level: TerseLevel::Moderate,
                },
                model,
            })
    }

    // ----- Strategies for invalid values --------------------------------------

    /// Values that fail an open-interval (0.0, 1.0) check: both endpoints,
    /// negatives in [-1.0, -0.0001], values in [1.0001, 2.0], NaN and both infinities.
    fn arb_outside_open_unit() -> impl Strategy<Value = f64> {
        prop_oneof![
            Just(0.0_f64),
            Just(1.0_f64),
            Just(f64::NAN),
            Just(f64::INFINITY),
            Just(f64::NEG_INFINITY),
            (1u32..=10_000u32).prop_map(|n| -(f64::from(n) / 10_000.0)),
            (10_001u32..=20_000u32).prop_map(|n| f64::from(n) / 10_000.0),
        ]
    }

    /// `max_tools` values outside `1..=50`.
    fn arb_invalid_max_tools() -> impl Strategy<Value = usize> {
        prop_oneof![Just(0usize), 51usize..=200usize]
    }

    /// `(warning, ceiling)` pairs, both inside (0.0, 1.0), with `ceiling <= warning`.
    fn arb_ceiling_not_above_warning() -> impl Strategy<Value = (f64, f64)> {
        (1u32..=9999u32).prop_flat_map(|w_raw| {
            (1u32..=w_raw).prop_map(move |c_raw| {
                (f64::from(w_raw) / 10_000.0, f64::from(c_raw) / 10_000.0)
            })
        })
    }

    /// Fails the current case unless `validate` rejects `preset` with an error naming `field`.
    fn expect_field_error(
        preset: &Preset,
        field: &str,
    ) -> std::result::Result<(), TestCaseError> {
        let err_msg = match PresetParser::validate(preset) {
            Ok(()) => {
                return Err(TestCaseError::fail(format!(
                    "expected a validation error mentioning '{}', but validation passed",
                    field
                )));
            }
            Err(err) => err.to_string(),
        };
        prop_assert!(
            err_msg.contains(field),
            "error message '{}' does not mention '{}'",
            err_msg,
            field
        );
        Ok(())
    }

    // ----- Example-based tests ------------------------------------------------

    #[test]
    fn default_preset_passes_validation() {
        PresetParser::validate(&Preset::default())
            .expect("Preset::default() must satisfy PresetParser::validate");
    }

    #[test]
    fn prop_empty_preset_name_error_mentions_field() {
        let mut preset = Preset::default();
        preset.preset.name = String::new();
        expect_field_error(&preset, "preset.name").unwrap();
    }

    #[test]
    fn empty_preset_version_error_mentions_field() {
        let mut preset = Preset::default();
        preset.preset.version = String::new();
        expect_field_error(&preset, "preset.version").unwrap();
    }

    // ----- Property-based tests -----------------------------------------------

    proptest! {
        /// Serializing, parsing and re-serializing a valid preset yields identical TOML.
        #[test]
        fn prop_preset_toml_round_trip(preset in arb_preset()) {
            let toml1 = PresetParser::to_toml(&preset)
                .expect("to_toml should not fail on a valid preset");
            let parsed = PresetParser::parse(&toml1)
                .expect("parse should not fail on a valid TOML string");
            let toml2 = PresetParser::to_toml(&parsed)
                .expect("to_toml should not fail on re-parsed preset");

            prop_assert_eq!(
                &toml1,
                &toml2,
                "TOML round-trip mismatch:\nfirst: {}\nsecond: {}",
                toml1,
                toml2
            );
        }

        /// Every preset produced by `arb_preset` is accepted by `validate`.
        #[test]
        fn prop_generated_presets_pass_validation(preset in arb_preset()) {
            let result = PresetParser::validate(&preset);
            prop_assert!(result.is_ok(), "valid preset rejected: {:?}", result);
        }

        #[test]
        fn prop_invalid_warning_threshold_error_mentions_field(
            invalid_wt in arb_outside_open_unit()
        ) {
            let mut preset = Preset::default();
            preset.budget.warning_threshold = invalid_wt;
            preset.budget.ceiling_threshold = 0.85;
            expect_field_error(&preset, "budget.warning_threshold")?;
        }

        #[test]
        fn prop_invalid_ceiling_threshold_error_mentions_field(
            invalid_ct in arb_outside_open_unit()
        ) {
            let mut preset = Preset::default();
            preset.budget.warning_threshold = 0.70;
            preset.budget.ceiling_threshold = invalid_ct;
            expect_field_error(&preset, "budget.ceiling_threshold")?;
        }

        #[test]
        fn prop_ceiling_not_above_warning_error_mentions_field(
            pair in arb_ceiling_not_above_warning()
        ) {
            let (warning, ceiling) = pair;
            let mut preset = Preset::default();
            preset.budget.warning_threshold = warning;
            preset.budget.ceiling_threshold = ceiling;
            expect_field_error(&preset, "budget.ceiling_threshold")?;
        }

        #[test]
        fn prop_invalid_max_tools_error_mentions_field(
            invalid_mt in arb_invalid_max_tools()
        ) {
            let mut preset = Preset::default();
            preset.tool_selection.max_tools = invalid_mt;
            expect_field_error(&preset, "tool_selection.max_tools")?;
        }

        #[test]
        fn prop_invalid_similarity_threshold_error_mentions_field(
            invalid_st in arb_outside_open_unit()
        ) {
            let mut preset = Preset::default();
            preset.tool_selection.similarity_threshold = invalid_st;
            expect_field_error(&preset, "tool_selection.similarity_threshold")?;
        }

        #[test]
        fn prop_invalid_complexity_threshold_error_mentions_field(
            invalid_cxt in arb_outside_open_unit()
        ) {
            let mut preset = Preset::default();
            preset.model.complexity_threshold = invalid_cxt;
            expect_field_error(&preset, "model.complexity_threshold")?;
        }
    }
}