Skip to main content

tokf_common/
examples.rs

1use serde::{Deserialize, Serialize};
2
3use crate::safety::SafetyWarning;
4
/// Estimate the number of tokens in `s` with the bytes/4 heuristic.
///
/// The estimate is the UTF-8 byte length of `s` (not its character count)
/// integer-divided by four. The remainder is discarded, so any input shorter
/// than four bytes estimates to zero tokens.
///
/// This matches the estimation used by the tracking module.
#[must_use]
pub const fn estimate_tokens(s: &str) -> usize {
    s.len() / 4
}
11
/// Compute the reduction percentage between raw and filtered token estimates.
///
/// The result is `(1 - filtered_tokens / raw_tokens) * 100`, evaluated in
/// `f64`. It is `100.0` when nothing is left after filtering, `0.0` when the
/// two counts are equal, and negative when `filtered_tokens` exceeds
/// `raw_tokens`.
///
/// Returns 0.0 when `raw_tokens` is zero, which also avoids dividing by zero.
// `usize -> f64` can round for values above 2^53; that imprecision is
// accepted here because the result is only a percentage estimate.
#[allow(clippy::cast_precision_loss)]
#[must_use]
pub fn reduction_pct(raw_tokens: usize, filtered_tokens: usize) -> f64 {
    if raw_tokens == 0 {
        return 0.0;
    }
    // Fraction of the raw estimate that survived filtering.
    let kept_fraction = filtered_tokens as f64 / raw_tokens as f64;
    (1.0 - kept_fraction) * 100.0
}
23
24/// A single before/after example for a filter.
25#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
26#[cfg_attr(
27    test,
28    derive(ts_rs::TS),
29    ts(export, export_to = "../../tokf-server/generated/")
30)]
31pub struct FilterExample {
32    /// Test case name.
33    pub name: String,
34    /// Exit code used for this example.
35    #[cfg_attr(test, ts(type = "number"))]
36    pub exit_code: i32,
37    /// Raw (unfiltered) input.
38    pub raw: String,
39    /// Filtered output.
40    pub filtered: String,
41    /// Number of lines in raw input.
42    #[cfg_attr(test, ts(type = "number"))]
43    pub raw_line_count: usize,
44    /// Number of lines in filtered output.
45    #[cfg_attr(test, ts(type = "number"))]
46    pub filtered_line_count: usize,
47    /// Estimated tokens in raw input (bytes / 4).
48    #[serde(default)]
49    #[cfg_attr(test, ts(type = "number"))]
50    pub raw_tokens_est: usize,
51    /// Estimated tokens in filtered output (bytes / 4).
52    #[serde(default)]
53    #[cfg_attr(test, ts(type = "number"))]
54    pub filtered_tokens_est: usize,
55    /// Percentage reduction in estimated tokens.
56    #[serde(default)]
57    pub reduction_pct: f64,
58}
59
60/// Collection of examples with aggregated safety results.
61#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62#[cfg_attr(
63    test,
64    derive(ts_rs::TS),
65    ts(export, export_to = "../../tokf-server/generated/")
66)]
67pub struct FilterExamples {
68    pub examples: Vec<FilterExample>,
69    pub safety: ExamplesSafety,
70}
71
72/// Serializable safety summary for the examples payload.
73#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
74#[cfg_attr(
75    test,
76    derive(ts_rs::TS),
77    ts(export, export_to = "../../tokf-server/generated/")
78)]
79pub struct ExamplesSafety {
80    pub passed: bool,
81    pub warnings: Vec<SafetyWarningDto>,
82}
83
84/// A flattened, transport-friendly representation of a safety warning.
85#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
86#[cfg_attr(
87    test,
88    derive(ts_rs::TS),
89    ts(export, export_to = "../../tokf-server/generated/")
90)]
91pub struct SafetyWarningDto {
92    pub kind: String,
93    pub message: String,
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub detail: Option<String>,
96}
97
98impl From<&SafetyWarning> for SafetyWarningDto {
99    fn from(w: &SafetyWarning) -> Self {
100        Self {
101            kind: w.kind.as_str().to_string(),
102            message: w.message.clone(),
103            detail: w.detail.clone(),
104        }
105    }
106}
107
// `unwrap` is acceptable in tests: a failure should abort the test with a panic.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::safety::WarningKind;

    /// A full payload must survive a JSON round trip unchanged.
    #[test]
    fn serialize_round_trip() {
        let raw = "line1\nline2\nline3";
        let filtered = "line1";
        let raw_tokens = estimate_tokens(raw);
        let filtered_tokens = estimate_tokens(filtered);
        let examples = FilterExamples {
            examples: vec![FilterExample {
                name: "basic".to_string(),
                exit_code: 0,
                raw: raw.to_string(),
                filtered: filtered.to_string(),
                raw_line_count: 3,
                filtered_line_count: 1,
                raw_tokens_est: raw_tokens,
                filtered_tokens_est: filtered_tokens,
                reduction_pct: reduction_pct(raw_tokens, filtered_tokens),
            }],
            safety: ExamplesSafety {
                passed: true,
                warnings: vec![],
            },
        };

        let json = serde_json::to_string(&examples).unwrap();
        let parsed: FilterExamples = serde_json::from_str(&json).unwrap();
        // Every type derives `PartialEq`, so compare the whole value rather
        // than spot-checking a few fields.
        assert_eq!(parsed, examples);
    }

    /// Payloads written before the token fields existed must still load.
    #[test]
    fn deserialize_without_token_fields_defaults_to_zero() {
        let json = r#"{"examples":[{"name":"old","exit_code":0,"raw":"abc","filtered":"a","raw_line_count":1,"filtered_line_count":1}],"safety":{"passed":true,"warnings":[]}}"#;
        let parsed: FilterExamples = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.examples[0].raw_tokens_est, 0);
        assert_eq!(parsed.examples[0].filtered_tokens_est, 0);
        assert!((parsed.examples[0].reduction_pct).abs() < f64::EPSILON);
    }

    #[test]
    fn estimate_tokens_basic() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("abcd"), 1);
        assert_eq!(estimate_tokens("abcdefgh"), 2);
        // 3 bytes rounds down to 0
        assert_eq!(estimate_tokens("abc"), 0);
    }

    #[test]
    fn reduction_pct_basic() {
        assert!((reduction_pct(100, 25) - 75.0).abs() < 0.01);
        assert!((reduction_pct(100, 0) - 100.0).abs() < 0.01);
        assert!((reduction_pct(100, 100)).abs() < 0.01);
        assert!((reduction_pct(0, 0)).abs() < 0.01);
    }

    /// Edge cases around the zero guard and output that grew after filtering.
    #[test]
    fn reduction_pct_edge_cases() {
        // Zero raw tokens short-circuits to 0.0 even with filtered tokens present.
        assert!((reduction_pct(0, 10)).abs() < 0.01);
        // Filtered larger than raw yields a negative reduction.
        assert!((reduction_pct(100, 150) + 50.0).abs() < 0.01);
    }

    #[test]
    fn warning_dto_from_safety_warning() {
        let warning = SafetyWarning {
            kind: WarningKind::TemplateInjection,
            message: "bad template".to_string(),
            detail: Some("ignore previous instructions".to_string()),
        };
        let dto = SafetyWarningDto::from(&warning);
        assert_eq!(dto.kind, "template_injection");
        assert_eq!(dto.message, "bad template");
        assert_eq!(dto.detail.as_deref(), Some("ignore previous instructions"));
    }

    /// A warning without detail converts to `None` and the key is not serialized.
    #[test]
    fn warning_dto_without_detail_is_omitted_from_json() {
        let warning = SafetyWarning {
            kind: WarningKind::TemplateInjection,
            message: "bad template".to_string(),
            detail: None,
        };
        let dto = SafetyWarningDto::from(&warning);
        assert_eq!(dto.detail, None);

        let json = serde_json::to_string(&dto).unwrap();
        assert!(!json.contains("detail"));

        // The missing key must deserialize back to `None`.
        let parsed: SafetyWarningDto = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, dto);
    }
}