swink_agent/
context_cache.rs1#![forbid(unsafe_code)]
8
9use std::time::Duration;
10
11use serde::{Deserialize, Serialize};
12
/// Tuning knobs for context caching.
///
/// `ttl` and `cache_intervals` drive [`CacheState::advance_turn`]; `min_tokens`
/// is stored here but not consulted anywhere in this module.
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Time-to-live carried by every [`CacheHint::Write`] built from this config.
    pub ttl: Duration,
    /// Minimum-token setting. Not read in this module — presumably a threshold
    /// the caller enforces before requesting hints (TODO confirm).
    pub min_tokens: usize,
    /// How many turns may elapse after a write before the next write is issued.
    /// Values of 0 and 1 both cause every turn to write.
    pub cache_intervals: usize,
}
30
31impl CacheConfig {
32 pub const fn new(ttl: Duration, min_tokens: usize, cache_intervals: usize) -> Self {
34 Self {
35 ttl,
36 min_tokens,
37 cache_intervals,
38 }
39 }
40}
41
42#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
49#[serde(tag = "action", rename_all = "snake_case")]
50pub enum CacheHint {
51 Write {
53 #[serde(with = "duration_secs")]
54 ttl: Duration,
55 },
56 Read,
58}
59
60mod duration_secs {
62 use std::time::Duration;
63
64 use serde::{Deserialize, Deserializer, Serializer};
65
66 pub fn serialize<S: Serializer>(dur: &Duration, s: S) -> Result<S::Ok, S::Error> {
67 s.serialize_u64(dur.as_secs())
68 }
69
70 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
71 let secs = u64::deserialize(d)?;
72 Ok(Duration::from_secs(secs))
73 }
74}
75
/// Per-conversation bookkeeping that decides whether each turn writes or reads
/// the cache.
#[derive(Clone, Debug)]
pub struct CacheState {
    /// Turns counted since the last write. `0` means nothing has been written
    /// yet (fresh or just reset), which forces the next turn to write.
    turns_since_write: usize,
    /// Length of the currently cached prefix. Not updated by `advance_turn`;
    /// callers assign it and `reset` clears it. Units presumably messages or
    /// tokens — TODO confirm against callers.
    pub cached_prefix_len: usize,
}
89
90impl CacheState {
91 pub const fn new() -> Self {
93 Self {
94 turns_since_write: 0,
95 cached_prefix_len: 0,
96 }
97 }
98
99 pub const fn advance_turn(&mut self, config: &CacheConfig) -> CacheHint {
105 if self.turns_since_write == 0 {
106 self.turns_since_write = 1;
108 CacheHint::Write { ttl: config.ttl }
109 } else if self.turns_since_write >= config.cache_intervals {
110 self.turns_since_write = 1;
112 CacheHint::Write { ttl: config.ttl }
113 } else {
114 self.turns_since_write += 1;
115 CacheHint::Read
116 }
117 }
118
119 pub const fn reset(&mut self) {
123 self.turns_since_write = 0;
124 self.cached_prefix_len = 0;
125 }
126}
127
128impl Default for CacheState {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    fn test_config(intervals: usize) -> CacheConfig {
        CacheConfig::new(Duration::from_secs(600), 4096, intervals)
    }

    /// The write hint produced by every `test_config`.
    fn expected_write() -> CacheHint {
        CacheHint::Write {
            ttl: Duration::from_secs(600),
        }
    }

    #[test]
    fn first_turn_emits_write() {
        let mut state = CacheState::new();
        let config = test_config(3);
        assert_eq!(state.advance_turn(&config), expected_write());
    }

    #[test]
    fn subsequent_turns_emit_read() {
        let mut state = CacheState::new();
        let config = test_config(3);
        state.advance_turn(&config);
        assert_eq!(state.advance_turn(&config), CacheHint::Read);
        assert_eq!(state.advance_turn(&config), CacheHint::Read);
    }

    #[test]
    fn refresh_after_cache_intervals() {
        let mut state = CacheState::new();
        let config = test_config(3);
        for _ in 0..3 {
            state.advance_turn(&config);
        }
        assert_eq!(state.advance_turn(&config), expected_write());
    }

    #[test]
    fn write_cycle_repeats_every_cache_intervals_turns() {
        let mut state = CacheState::new();
        let config = test_config(3);
        let hints: Vec<CacheHint> = (0..7).map(|_| state.advance_turn(&config)).collect();
        let expected = vec![
            expected_write(),
            CacheHint::Read,
            CacheHint::Read,
            expected_write(),
            CacheHint::Read,
            CacheHint::Read,
            expected_write(),
        ];
        assert_eq!(hints, expected);
    }

    #[test]
    fn intervals_of_zero_or_one_write_every_turn() {
        for intervals in [0, 1] {
            let mut state = CacheState::new();
            let config = test_config(intervals);
            for _ in 0..4 {
                assert_eq!(state.advance_turn(&config), expected_write());
            }
        }
    }

    #[test]
    fn reset_forces_write_on_next_turn() {
        let mut state = CacheState::new();
        let config = test_config(5);
        state.advance_turn(&config);
        state.advance_turn(&config);
        state.reset();
        assert_eq!(state.advance_turn(&config), expected_write());
    }

    #[test]
    fn cached_prefix_len_tracks_correctly() {
        let mut state = CacheState::new();
        assert_eq!(state.cached_prefix_len, 0);
        state.cached_prefix_len = 5;
        assert_eq!(state.cached_prefix_len, 5);
        state.reset();
        assert_eq!(state.cached_prefix_len, 0);
    }

    /// `min_tokens` is not consulted by `CacheState::advance_turn`, so this
    /// only verifies constructor wiring — it does NOT test hint suppression.
    /// (Previously named `min_tokens_below_threshold_suppresses_hints`, which
    /// over-claimed coverage.)
    #[test]
    fn config_stores_min_tokens() {
        let config = CacheConfig::new(Duration::from_secs(300), 8192, 2);
        assert_eq!(config.ttl, Duration::from_secs(300));
        assert_eq!(config.min_tokens, 8192);
        assert_eq!(config.cache_intervals, 2);
    }

    #[test]
    fn serde_round_trip_write_hint() {
        let hint = expected_write();
        let json = serde_json::to_string(&hint).unwrap();
        let back: CacheHint = serde_json::from_str(&json).unwrap();
        assert_eq!(hint, back);
    }

    #[test]
    fn serde_round_trip_read_hint() {
        let hint = CacheHint::Read;
        let json = serde_json::to_string(&hint).unwrap();
        let back: CacheHint = serde_json::from_str(&json).unwrap();
        assert_eq!(hint, back);
    }

    /// Pins the wire format: internally tagged by `action`, TTL in whole seconds.
    #[test]
    fn serde_wire_format_is_internally_tagged() {
        let write = serde_json::to_value(expected_write()).unwrap();
        assert_eq!(write, serde_json::json!({ "action": "write", "ttl": 600 }));

        let read = serde_json::to_value(CacheHint::Read).unwrap();
        assert_eq!(read, serde_json::json!({ "action": "read" }));
    }
}