Skip to main content

tokmd_context_policy/
lib.rs

1//! Deterministic context/handoff policy helpers.
2
3#![forbid(unsafe_code)]
4
5use tokmd_path::normalize_slashes as normalize_path;
6use tokmd_types::{FileClassification, InclusionPolicy};
7
/// Default maximum fraction of budget a single file may consume.
pub const DEFAULT_MAX_FILE_PCT: f64 = 0.15;
/// Default hard cap for a single file when no explicit cap is provided.
pub const DEFAULT_MAX_FILE_TOKENS: usize = 16_000;
/// Default tokens-per-line threshold for dense blob detection.
pub const DEFAULT_DENSE_THRESHOLD: f64 = 50.0;

// Exact basenames treated as lockfiles; compared against the final
// `/`-separated component of a path.
const LOCKFILES: &[&str] = &[
    "Cargo.lock",
    "package-lock.json",
    "pnpm-lock.yaml",
    "yarn.lock",
    "poetry.lock",
    "Pipfile.lock",
    "go.sum",
    "composer.lock",
    "Gemfile.lock",
];

// Basename suffixes paired with the smart-exclude reason they map to.
const SMART_EXCLUDE_SUFFIXES: &[(&str, &str)] = &[
    (".min.js", "minified"),
    (".min.css", "minified"),
    (".js.map", "sourcemap"),
    (".css.map", "sourcemap"),
];

// "Spine" files that should be prioritized for context. Entries containing
// a '/' are matched as whole-path suffixes; the rest match by basename.
const SPINE_PATTERNS: &[&str] = &[
    "README.md",
    "README",
    "README.rst",
    "README.txt",
    "ROADMAP.md",
    "docs/ROADMAP.md",
    "CONTRIBUTING.md",
    "Cargo.toml",
    "package.json",
    "pyproject.toml",
    "go.mod",
    "docs/architecture.md",
    "docs/design.md",
    "tokmd.toml",
    "cockpit.toml",
];

// Substrings of a basename that indicate machine-generated code
// (protobuf, tree-sitter grammars, Dart codegen, etc.).
const GENERATED_PATTERNS: &[&str] = &[
    "node-types.json",
    "grammar.json",
    ".generated.",
    ".pb.go",
    ".pb.rs",
    "_pb2.py",
    ".g.dart",
    ".freezed.dart",
];

// Directory markers (trailing slash) for vendored third-party code.
const VENDORED_DIRS: &[&str] = &["vendor/", "third_party/", "third-party/", "node_modules/"];
// Directory markers (trailing slash) for test fixtures and golden data.
const FIXTURE_DIRS: &[&str] = &[
    "fixtures/",
    "testdata/",
    "test_data/",
    "__snapshots__/",
    "golden/",
];
71
72/// Returns the smart-exclude reason for a path, if any.
73///
74/// Reasons:
75/// - `lockfile`
76/// - `minified`
77/// - `sourcemap`
78#[must_use]
79pub fn smart_exclude_reason(path: &str) -> Option<&'static str> {
80    let basename = path.rsplit('/').next().unwrap_or(path);
81
82    if LOCKFILES.contains(&basename) {
83        return Some("lockfile");
84    }
85
86    for &(suffix, reason) in SMART_EXCLUDE_SUFFIXES {
87        if basename.ends_with(suffix) {
88            return Some(reason);
89        }
90    }
91
92    None
93}
94
95/// Returns `true` when a path matches a "spine" file that should be prioritized.
96#[must_use]
97pub fn is_spine_file(path: &str) -> bool {
98    let normalized = normalize_path(path);
99    let basename = normalized.rsplit('/').next().unwrap_or(&normalized);
100
101    for &pattern in SPINE_PATTERNS {
102        if pattern.contains('/') {
103            if normalized == pattern || normalized.ends_with(&format!("/{pattern}")) {
104                return true;
105            }
106        } else if basename == pattern {
107            return true;
108        }
109    }
110
111    false
112}
113
114/// Classify a file for context/handoff hygiene policy evaluation.
115#[must_use]
116pub fn classify_file(
117    path: &str,
118    tokens: usize,
119    lines: usize,
120    dense_threshold: f64,
121) -> Vec<FileClassification> {
122    let mut classes = Vec::new();
123    let normalized = normalize_path(path);
124    let basename = normalized.rsplit('/').next().unwrap_or(&normalized);
125
126    if LOCKFILES.contains(&basename) {
127        classes.push(FileClassification::Lockfile);
128    }
129
130    if basename.ends_with(".min.js") || basename.ends_with(".min.css") {
131        classes.push(FileClassification::Minified);
132    }
133
134    if basename.ends_with(".js.map") || basename.ends_with(".css.map") {
135        classes.push(FileClassification::Sourcemap);
136    }
137
138    if GENERATED_PATTERNS
139        .iter()
140        .any(|pat| basename == *pat || basename.contains(pat))
141    {
142        classes.push(FileClassification::Generated);
143    }
144
145    if VENDORED_DIRS
146        .iter()
147        .any(|dir| normalized.contains(dir) || normalized.starts_with(dir.trim_end_matches('/')))
148    {
149        classes.push(FileClassification::Vendored);
150    }
151
152    if FIXTURE_DIRS
153        .iter()
154        .any(|dir| normalized.contains(dir) || normalized.starts_with(dir.trim_end_matches('/')))
155    {
156        classes.push(FileClassification::Fixture);
157    }
158
159    let effective_lines = lines.max(1);
160    let tokens_per_line = tokens as f64 / effective_lines as f64;
161    if tokens_per_line > dense_threshold {
162        classes.push(FileClassification::DataBlob);
163    }
164
165    classes.sort();
166    classes.dedup();
167    classes
168}
169
170/// Compute the maximum tokens a single file may consume.
171#[must_use]
172pub fn compute_file_cap(budget: usize, max_file_pct: f64, max_file_tokens: Option<usize>) -> usize {
173    if budget == usize::MAX {
174        return usize::MAX;
175    }
176
177    let pct_cap = (budget as f64 * max_file_pct) as usize;
178    let hard_cap = max_file_tokens.unwrap_or(DEFAULT_MAX_FILE_TOKENS);
179    pct_cap.min(hard_cap)
180}
181
182/// Assign an inclusion policy based on size and file classifications.
183#[must_use]
184pub fn assign_policy(
185    tokens: usize,
186    file_cap: usize,
187    classifications: &[FileClassification],
188) -> (InclusionPolicy, Option<String>) {
189    if tokens <= file_cap {
190        return (InclusionPolicy::Full, None);
191    }
192
193    let skip_classes = [
194        FileClassification::Generated,
195        FileClassification::DataBlob,
196        FileClassification::Vendored,
197    ];
198
199    if classifications.iter().any(|c| skip_classes.contains(c)) {
200        let class_names: Vec<&str> = classifications.iter().map(classification_name).collect();
201        return (
202            InclusionPolicy::Skip,
203            Some(format!(
204                "{} file exceeds cap ({} > {} tokens)",
205                class_names.join("+"),
206                tokens,
207                file_cap
208            )),
209        );
210    }
211
212    (
213        InclusionPolicy::HeadTail,
214        Some(format!(
215            "file exceeds cap ({} > {} tokens); head+tail included",
216            tokens, file_cap
217        )),
218    )
219}
220
221fn classification_name(classification: &FileClassification) -> &'static str {
222    match classification {
223        FileClassification::Generated => "generated",
224        FileClassification::Fixture => "fixture",
225        FileClassification::Vendored => "vendored",
226        FileClassification::Lockfile => "lockfile",
227        FileClassification::Minified => "minified",
228        FileClassification::DataBlob => "data_blob",
229        FileClassification::Sourcemap => "sourcemap",
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn smart_exclude_reason_detects_lockfiles_and_sourcemaps() {
        let cases = [
            ("Cargo.lock", Some("lockfile")),
            ("dist/app.js.map", Some("sourcemap")),
            ("src/main.rs", None),
        ];
        for (path, expected) in cases {
            assert_eq!(smart_exclude_reason(path), expected, "path: {path}");
        }
    }

    #[test]
    fn is_spine_file_matches_basename_and_document_paths() {
        assert!(is_spine_file("README.md"));
        assert!(is_spine_file("nested/docs/architecture.md"));
        assert!(!is_spine_file("src/main.rs"));
    }

    #[test]
    fn classify_file_detects_generated_and_dense_blob() {
        let classes = classify_file("src/node-types.json", 50_000, 5, 50.0);
        for expected in [FileClassification::Generated, FileClassification::DataBlob] {
            assert!(classes.contains(&expected), "missing {expected:?}");
        }
    }

    #[test]
    fn assign_policy_skips_oversized_generated_files() {
        let (policy, reason) = assign_policy(20_000, 16_000, &[FileClassification::Generated]);
        assert_eq!(policy, InclusionPolicy::Skip);
        assert!(reason.expect("skip reason").contains("generated"));
    }
}
264}