1use serde::{Deserialize, Serialize};
2
3use crate::safety::SafetyWarning;
4
/// Cheap token-count estimate for `s`, assuming roughly four bytes of text per token.
///
/// The estimate uses the UTF-8 byte length (not the `char` count) and rounds
/// down. Any input shorter than four bytes therefore estimates to zero tokens.
pub const fn estimate_tokens(s: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let byte_len = s.len();
    byte_len / BYTES_PER_TOKEN
}
11
/// Percentage by which `filtered_tokens` is smaller than `raw_tokens`.
///
/// Returns `0.0` when `raw_tokens` is zero, which avoids dividing by zero.
/// The result is negative when `filtered_tokens` exceeds `raw_tokens`.
// usize -> f64 can lose precision only above 2^53, which is not a concern for
// realistic token counts.
#[allow(clippy::cast_precision_loss)]
pub fn reduction_pct(raw_tokens: usize, filtered_tokens: usize) -> f64 {
    if raw_tokens == 0 {
        return 0.0;
    }
    let kept_fraction = filtered_tokens as f64 / raw_tokens as f64;
    (1.0 - kept_fraction) * 100.0
}
23
/// One example of a filter's effect: the unfiltered text (`raw`), the filtered
/// result, and line/token statistics for both.
///
/// Integer fields carry `ts(type = "number")` so the generated TypeScript
/// bindings type them as `number`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// In test builds, also derive a ts-rs TypeScript definition and export it to
// `tokf-server/generated/`.
#[cfg_attr(
    test,
    derive(ts_rs::TS),
    ts(export, export_to = "../../tokf-server/generated/")
)]
pub struct FilterExample {
    /// Name identifying this example (e.g. `"basic"`).
    pub name: String,
    /// Exit code recorded for this example.
    // NOTE(review): presumably the exit status of the command that produced
    // `raw` — confirm against the producer.
    #[cfg_attr(test, ts(type = "number"))]
    pub exit_code: i32,
    /// The unfiltered text.
    pub raw: String,
    /// The text after the filter was applied.
    pub filtered: String,
    /// Number of lines in `raw`.
    #[cfg_attr(test, ts(type = "number"))]
    pub raw_line_count: usize,
    /// Number of lines in `filtered`.
    #[cfg_attr(test, ts(type = "number"))]
    pub filtered_line_count: usize,
    /// Estimated tokens in `raw` (see `estimate_tokens`).
    /// Defaults to 0 when the field is missing from serialized input.
    #[serde(default)]
    #[cfg_attr(test, ts(type = "number"))]
    pub raw_tokens_est: usize,
    /// Estimated tokens in `filtered` (see `estimate_tokens`).
    /// Defaults to 0 when the field is missing from serialized input.
    #[serde(default)]
    #[cfg_attr(test, ts(type = "number"))]
    pub filtered_tokens_est: usize,
    /// Percentage token reduction from `raw` to `filtered` (see `reduction_pct`).
    /// Defaults to 0.0 when the field is missing from serialized input.
    #[serde(default)]
    pub reduction_pct: f64,
}
59
/// A collection of filter examples together with the safety-check outcome for them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// In test builds, also derive a ts-rs TypeScript definition and export it to
// `tokf-server/generated/`.
#[cfg_attr(
    test,
    derive(ts_rs::TS),
    ts(export, export_to = "../../tokf-server/generated/")
)]
pub struct FilterExamples {
    /// The individual examples.
    pub examples: Vec<FilterExample>,
    /// Safety-check outcome associated with these examples.
    pub safety: ExamplesSafety,
}
71
/// Safety-check outcome attached to a [`FilterExamples`] set.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// In test builds, also derive a ts-rs TypeScript definition and export it to
// `tokf-server/generated/`.
#[cfg_attr(
    test,
    derive(ts_rs::TS),
    ts(export, export_to = "../../tokf-server/generated/")
)]
pub struct ExamplesSafety {
    /// Whether the safety check passed.
    pub passed: bool,
    /// Warnings reported by the check, in serializable form.
    pub warnings: Vec<SafetyWarningDto>,
}
83
/// Serializable copy of a [`SafetyWarning`], with its kind rendered as a string.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// In test builds, also derive a ts-rs TypeScript definition and export it to
// `tokf-server/generated/`.
#[cfg_attr(
    test,
    derive(ts_rs::TS),
    ts(export, export_to = "../../tokf-server/generated/")
)]
pub struct SafetyWarningDto {
    /// Warning kind as a string (e.g. `"template_injection"`).
    pub kind: String,
    /// Human-readable description of the warning.
    pub message: String,
    /// Optional extra detail. The key is omitted from serialized output when `None`.
    // NOTE(review): because the key is omitted rather than written as `null`,
    // confirm the generated TypeScript type marks `detail` as optional.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
97
98impl From<&SafetyWarning> for SafetyWarningDto {
99 fn from(w: &SafetyWarning) -> Self {
100 Self {
101 kind: w.kind.as_str().to_string(),
102 message: w.message.clone(),
103 detail: w.detail.clone(),
104 }
105 }
106}
107
#[cfg(test)]
// In tests an unwrap failure is simply a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::safety::WarningKind;

    #[test]
    fn serialize_round_trip() {
        let raw = "line1\nline2\nline3";
        let filtered = "line1";
        let examples = FilterExamples {
            examples: vec![FilterExample {
                name: "basic".to_string(),
                exit_code: 0,
                raw: raw.to_string(),
                filtered: filtered.to_string(),
                raw_line_count: 3,
                filtered_line_count: 1,
                raw_tokens_est: estimate_tokens(raw),
                filtered_tokens_est: estimate_tokens(filtered),
                reduction_pct: reduction_pct(estimate_tokens(raw), estimate_tokens(filtered)),
            }],
            safety: ExamplesSafety {
                passed: true,
                warnings: vec![],
            },
        };

        let json = serde_json::to_string(&examples).unwrap();
        let parsed: FilterExamples = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.examples.len(), 1);
        assert_eq!(parsed.examples[0].name, "basic");
        // 17 bytes -> 4 tokens, 5 bytes -> 1 token.
        assert_eq!(parsed.examples[0].raw_tokens_est, 4);
        assert_eq!(parsed.examples[0].filtered_tokens_est, 1);
        assert!(parsed.safety.passed);
        // Compare the whole value so any field that does not survive the
        // round trip is caught, not just the few checked above.
        assert_eq!(parsed, examples);
    }

    #[test]
    fn deserialize_without_token_fields_defaults_to_zero() {
        let json = r#"{"examples":[{"name":"old","exit_code":0,"raw":"abc","filtered":"a","raw_line_count":1,"filtered_line_count":1}],"safety":{"passed":true,"warnings":[]}}"#;
        let parsed: FilterExamples = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.examples[0].raw_tokens_est, 0);
        assert_eq!(parsed.examples[0].filtered_tokens_est, 0);
        assert!((parsed.examples[0].reduction_pct).abs() < f64::EPSILON);
    }

    #[test]
    fn estimate_tokens_basic() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("abcd"), 1);
        assert_eq!(estimate_tokens("abcdefgh"), 2);
        assert_eq!(estimate_tokens("abc"), 0);
        // Counts UTF-8 bytes, not chars: "éé" is four bytes.
        assert_eq!(estimate_tokens("éé"), 1);
    }

    #[test]
    fn reduction_pct_basic() {
        assert!((reduction_pct(100, 25) - 75.0).abs() < 0.01);
        assert!((reduction_pct(100, 0) - 100.0).abs() < 0.01);
        assert!((reduction_pct(100, 100)).abs() < 0.01);
        assert!((reduction_pct(0, 0)).abs() < 0.01);
        // Zero raw tokens is reported as no reduction, whatever the filtered size.
        assert!((reduction_pct(0, 10)).abs() < 0.01);
        // Filtering that grows the output yields a negative reduction.
        assert!((reduction_pct(100, 150) + 50.0).abs() < 0.01);
    }

    #[test]
    fn warning_dto_from_safety_warning() {
        let warning = SafetyWarning {
            kind: WarningKind::TemplateInjection,
            message: "bad template".to_string(),
            detail: Some("ignore previous instructions".to_string()),
        };
        let dto = SafetyWarningDto::from(&warning);
        assert_eq!(dto.kind, "template_injection");
        assert_eq!(dto.message, "bad template");
        assert_eq!(dto.detail.as_deref(), Some("ignore previous instructions"));
    }

    #[test]
    fn warning_dto_omits_missing_detail() {
        let dto = SafetyWarningDto {
            kind: "template_injection".to_string(),
            message: "bad template".to_string(),
            detail: None,
        };
        let json = serde_json::to_string(&dto).unwrap();
        // `skip_serializing_if` drops the key entirely rather than writing `null`.
        assert_eq!(
            json,
            r#"{"kind":"template_injection","message":"bad template"}"#
        );
        // A missing `detail` key deserializes back to `None`.
        let parsed: SafetyWarningDto = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, dto);
    }
}