1use serde::de::Deserializer;
17use serde::{Deserialize, Serialize};
18use std::collections::BTreeMap;
19use std::fs;
20use std::path::Path;
21use std::time::{SystemTime, UNIX_EPOCH};
22
23#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
24#[serde(rename_all = "lowercase")]
25pub enum Source {
26 Ski,
28 Model,
30}
31
32#[derive(Clone, Copy, Debug, PartialEq, Serialize)]
36pub struct Record {
37 pub source: Source,
38 pub confidence: f32,
39}
40
41impl<'de> Deserialize<'de> for Record {
45 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
46 #[derive(Deserialize)]
47 #[serde(untagged)]
48 enum Repr {
49 Bare(Source),
50 Full {
51 source: Source,
52 #[serde(default)]
53 confidence: f32,
54 },
55 }
56 Ok(match Repr::deserialize(d)? {
57 Repr::Bare(source) => Record {
58 source,
59 confidence: 0.0,
60 },
61 Repr::Full { source, confidence } => Record { source, confidence },
62 })
63 }
64}
65
66#[derive(Clone, Debug, Default, Serialize, Deserialize)]
67pub struct Session {
68 #[serde(default)]
70 pub loaded: BTreeMap<String, Record>,
71 #[serde(default, skip_serializing_if = "String::is_empty")]
77 pub last_prompt: String,
78 #[serde(default, skip_serializing_if = "Vec::is_empty")]
85 pub recent_prompts: Vec<String>,
86 #[serde(default)]
88 pub updated: u64,
89}
90
91impl Session {
92 pub fn load(path: &Path) -> Session {
95 fs::read_to_string(path)
96 .ok()
97 .and_then(|s| serde_json::from_str(&s).ok())
98 .unwrap_or_default()
99 }
100
101 pub fn save(&self, path: &Path) -> anyhow::Result<()> {
112 if let Some(parent) = path.parent() {
113 fs::create_dir_all(parent)?;
114 }
115 let mut snapshot = self.clone();
116 snapshot.updated = now_secs();
117 let json = serde_json::to_string_pretty(&snapshot)?;
118 let tmp = path.with_extension(format!("tmp.{}.{}", std::process::id(), now_nanos()));
119 fs::write(&tmp, json)?;
120 if let Err(e) = fs::rename(&tmp, path) {
121 let _ = fs::remove_file(&tmp);
122 return Err(e.into());
123 }
124 Ok(())
125 }
126
127 pub fn is_loaded(&self, id: &str) -> bool {
128 self.loaded.contains_key(id)
129 }
130
131 pub fn get(&self, id: &str) -> Option<&Record> {
132 self.loaded.get(id)
133 }
134
135 pub fn save_merged(&self, path: &Path) -> anyhow::Result<()> {
157 let disk = Session::load(path);
158 let mut merged = self.clone();
159 for (id, theirs) in disk.loaded {
160 match merged.loaded.get(&id) {
161 None => {
162 merged.loaded.insert(id, theirs);
163 }
164 Some(ours) => {
165 let take_theirs = match (theirs.source, ours.source) {
166 (Source::Model, Source::Ski) => true,
167 (Source::Ski, Source::Model) => false,
168 _ => theirs.confidence > ours.confidence,
169 };
170 if take_theirs {
171 merged.loaded.insert(id, theirs);
172 }
173 }
174 }
175 }
176 merged.save(path)
177 }
178
179 pub fn should_recommend(&self, id: &str, new_conf: f32, high: f32) -> bool {
187 match self.loaded.get(id) {
188 None => true,
189 Some(r) if r.source == Source::Model => false,
190 Some(r) => new_conf >= high && r.confidence < high,
191 }
192 }
193
194 pub fn mark_recommended(&mut self, id: &str, confidence: f32) {
199 match self.loaded.get(id) {
200 Some(r) if r.source == Source::Model => {}
201 _ => {
202 self.loaded.insert(
203 id.to_string(),
204 Record {
205 source: Source::Ski,
206 confidence,
207 },
208 );
209 }
210 }
211 }
212
213 pub fn mark_used(&mut self, id: &str) {
216 let confidence = self.loaded.get(id).map(|r| r.confidence).unwrap_or(0.0);
217 self.loaded.insert(
218 id.to_string(),
219 Record {
220 source: Source::Model,
221 confidence,
222 },
223 );
224 }
225
226 pub fn mark(&mut self, id: &str, source: Source) {
230 match source {
231 Source::Model => self.mark_used(id),
232 Source::Ski => {
233 self.loaded.entry(id.to_string()).or_insert(Record {
234 source: Source::Ski,
235 confidence: 0.0,
236 });
237 }
238 }
239 }
240
241 pub fn push_prompt(&mut self, prompt: &str, max: usize) {
247 let p = prompt.trim();
248 if max == 0 || p.is_empty() {
249 return;
250 }
251 if self.recent_prompts.last().map(String::as_str) == Some(p) {
252 return;
253 }
254 self.recent_prompts.push(p.to_string());
255 let len = self.recent_prompts.len();
256 if len > max {
257 self.recent_prompts.drain(0..len - max);
258 }
259 }
260
261 pub fn clear(&mut self) {
264 self.loaded.clear();
265 self.recent_prompts.clear();
266 }
267}
268
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
275
/// Current wall-clock time as nanoseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch. This file uses
/// it to make temp file and directory names unique.
fn now_nanos() -> u128 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH);
    since_epoch.map_or(0, |elapsed| elapsed.as_nanos())
}
284
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Unique scratch directory under the system temp dir.
    ///
    /// It is removed on drop, so cleanup happens even when an assertion fails.
    /// The directory itself is not created here; `Session::save` creates it.
    struct TempDir(PathBuf);

    impl TempDir {
        fn new(tag: &str) -> TempDir {
            TempDir(std::env::temp_dir().join(format!(
                "ski-session-{tag}-{}-{}",
                std::process::id(),
                now_nanos()
            )))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn mark_and_dedup() {
        let mut s = Session::default();
        assert!(!s.is_loaded("a"));
        s.mark("a", Source::Ski);
        assert!(s.is_loaded("a"));
    }

    #[test]
    fn model_load_is_not_downgraded() {
        let mut s = Session::default();
        s.mark("a", Source::Model);
        s.mark("a", Source::Ski);
        assert_eq!(s.loaded["a"].source, Source::Model);
    }

    #[test]
    fn ski_then_model_upgrades() {
        let mut s = Session::default();
        s.mark("a", Source::Ski);
        s.mark("a", Source::Model);
        assert_eq!(s.loaded["a"].source, Source::Model);
    }

    #[test]
    fn used_skill_is_never_recommended() {
        let mut s = Session::default();
        s.mark_used("a");
        assert!(!s.should_recommend("a", 1.0, 0.80));
    }

    #[test]
    fn unseen_skill_is_recommended() {
        let s = Session::default();
        assert!(s.should_recommend("a", 0.40, 0.80));
    }

    #[test]
    fn repeat_only_on_rise_into_high() {
        let mut s = Session::default();
        s.mark_recommended("a", 0.60);
        assert!(!s.should_recommend("a", 0.70, 0.80));
        assert!(s.should_recommend("a", 0.90, 0.80));
    }

    #[test]
    fn no_repeat_after_high_showing() {
        let mut s = Session::default();
        s.mark_recommended("a", 0.90);
        assert!(!s.should_recommend("a", 0.95, 0.80));
    }

    #[test]
    fn mark_recommended_does_not_downgrade_model() {
        let mut s = Session::default();
        s.mark_used("a");
        s.mark_recommended("a", 0.99);
        assert_eq!(s.loaded["a"].source, Source::Model);
    }

    #[test]
    fn legacy_bare_string_value_still_loads() {
        let json = r#"{"loaded":{"a":"ski","b":"model"},"updated":0}"#;
        let s: Session = serde_json::from_str(json).unwrap();
        assert_eq!(s.loaded["a"].source, Source::Ski);
        assert_eq!(s.loaded["a"].confidence, 0.0);
        assert_eq!(s.loaded["b"].source, Source::Model);
    }

    #[test]
    fn clear_re_arms() {
        let mut s = Session::default();
        s.mark("a", Source::Ski);
        s.push_prompt("set up pytest", 3);
        s.clear();
        assert!(!s.is_loaded("a"));
        assert!(s.recent_prompts.is_empty());
    }

    #[test]
    fn push_prompt_bounds_window_oldest_first() {
        let mut s = Session::default();
        for p in ["one", "two", "three", "four"] {
            s.push_prompt(p, 3);
        }
        assert_eq!(s.recent_prompts, ["two", "three", "four"]);
    }

    #[test]
    fn push_prompt_ignores_blank_and_consecutive_dupes() {
        let mut s = Session::default();
        s.push_prompt(" ", 3);
        s.push_prompt("set up pytest", 3);
        s.push_prompt("set up pytest", 3);
        s.push_prompt("now the other one", 3);
        assert_eq!(s.recent_prompts, ["set up pytest", "now the other one"]);
    }

    #[test]
    fn push_prompt_zero_max_disables_window() {
        let mut s = Session::default();
        s.push_prompt("anything", 0);
        assert!(s.recent_prompts.is_empty());
    }

    #[test]
    fn recent_prompts_absent_when_empty_in_json() {
        let s = Session::default();
        let json = serde_json::to_string(&s).unwrap();
        assert!(!json.contains("recent_prompts"), "got {json}");
    }

    #[test]
    fn source_serializes_lowercase() {
        let json = serde_json::to_string(&Source::Ski).unwrap();
        assert_eq!(json, "\"ski\"");
        let json = serde_json::to_string(&Source::Model).unwrap();
        assert_eq!(json, "\"model\"");
    }

    #[test]
    fn missing_file_is_empty_session() {
        let s = Session::load(Path::new("/nonexistent/ski/session.json"));
        assert!(s.loaded.is_empty());
    }

    #[test]
    fn save_then_load_roundtrips_and_leaves_no_temp() {
        let scratch = TempDir::new("save");
        let path = scratch.path().join("conv.json");
        let mut s = Session::default();
        s.mark("uv-setup", Source::Ski);
        s.save(&path).unwrap();

        let back = Session::load(&path);
        assert_eq!(back.loaded["uv-setup"].source, Source::Ski);
        let leftovers: Vec<_> = fs::read_dir(scratch.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .map(|e| e.file_name())
            .filter(|n| n != "conv.json")
            .collect();
        assert!(leftovers.is_empty(), "temp file left behind: {leftovers:?}");
    }

    #[test]
    fn save_merged_keeps_concurrent_writers_mark() {
        let scratch = TempDir::new("merge");
        let path = scratch.path().join("conv.json");

        // Both writers start from the same (empty) snapshot.
        let hook_snapshot = Session::load(&path);
        let mut observe = Session::load(&path);
        observe.mark_used("xlsx");
        observe.save(&path).unwrap();

        // The hook saves later from its stale snapshot; the merge must keep xlsx.
        let mut hook = hook_snapshot;
        hook.mark_recommended("pdf", 0.9);
        hook.save_merged(&path).unwrap();

        let merged = Session::load(&path);
        assert_eq!(merged.loaded["xlsx"].source, Source::Model, "mark lost");
        assert_eq!(merged.loaded["pdf"].source, Source::Ski);
    }

    #[test]
    fn save_merged_model_beats_ski_and_max_confidence_wins() {
        let scratch = TempDir::new("merge2");
        let path = scratch.path().join("conv.json");

        let mut disk = Session::default();
        disk.mark_used("a");
        disk.mark_recommended("b", 0.9);
        disk.save(&path).unwrap();

        let mut ours = Session::default();
        ours.mark_recommended("a", 0.99);
        ours.mark_recommended("b", 0.6);
        ours.save_merged(&path).unwrap();

        let merged = Session::load(&path);
        assert_eq!(merged.loaded["a"].source, Source::Model);
        assert!(merged.loaded["b"].confidence > 0.8);
    }

    #[test]
    fn plain_save_still_wipes_for_compaction() {
        let scratch = TempDir::new("wipe");
        let path = scratch.path().join("conv.json");
        let mut s = Session::default();
        s.mark_used("a");
        s.save(&path).unwrap();

        let mut rearmed = Session::load(&path);
        rearmed.clear();
        rearmed.save(&path).unwrap();
        assert!(Session::load(&path).loaded.is_empty());
    }

    #[test]
    fn roundtrip_through_json() {
        let mut s = Session::default();
        s.mark("git-attribution", Source::Ski);
        s.mark("uv-setup", Source::Model);
        let text = serde_json::to_string(&s).unwrap();
        let back: Session = serde_json::from_str(&text).unwrap();
        assert_eq!(back.loaded["git-attribution"].source, Source::Ski);
        assert_eq!(back.loaded["uv-setup"].source, Source::Model);
    }
}