1use anyhow::{Context, Result};
14use chrono::Utc;
15use rocksdb::{ColumnFamily, ColumnFamilyDescriptor, Options, WriteBatch, DB};
16use std::path::Path;
17use std::sync::Arc;
18
19use super::types::{ProspectiveTask, ProspectiveTaskId, ProspectiveTaskStatus, ProspectiveTrigger};
20
21const CF_PROSPECTIVE: &str = "prospective";
23const CF_PROSPECTIVE_INDEX: &str = "prospective_index";
25
/// Cosine similarity between two equal-length vectors.
///
/// Returns 0.0 (rather than NaN) for length mismatches, empty inputs, or
/// zero-norm vectors, so callers can compare the result directly against a
/// threshold.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating dot product and both squared norms.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    dot / (norm_a * norm_b)
}
42
43fn migrate_due_key_padding(db: &DB, index_cf: &ColumnFamily) -> Result<usize> {
49 let mut batch = WriteBatch::default();
50 let mut count = 0;
51
52 for item in db.prefix_iterator_cf(index_cf, b"due:") {
53 let (key, value) = item.context("Failed to read due index during migration")?;
54 let key_str = std::str::from_utf8(&key).context("Non-UTF8 key in prospective due index")?;
55
56 let parts: Vec<&str> = key_str.splitn(3, ':').collect();
58 if parts.len() != 3 {
59 continue;
60 }
61
62 if parts[1].len() >= 20 {
64 continue;
65 }
66
67 if let Ok(ts) = parts[1].parse::<i64>() {
68 let new_key = format!("due:{:020}:{}", ts, parts[2]);
69 batch.delete_cf(index_cf, &*key);
70 batch.put_cf(index_cf, new_key.as_bytes(), &*value);
71 count += 1;
72 }
73 }
74
75 if count > 0 {
76 db.write(batch)
77 .context("Failed to write migrated prospective due keys")?;
78 tracing::info!(count, "Migrated prospective due keys to zero-padded format");
79 }
80
81 Ok(count)
82}
83
/// RocksDB-backed store for prospective-memory tasks (reminders).
///
/// Tasks and their secondary indices live in dedicated column families
/// (`prospective` / `prospective_index`) of a database shared with other
/// stores.
pub struct ProspectiveStore {
    // Shared handle to the column-family-partitioned database.
    db: Arc<DB>,
}
89
90impl ProspectiveStore {
91 pub fn column_family_descriptors() -> Vec<ColumnFamilyDescriptor> {
95 let mut opts = Options::default();
96 opts.create_if_missing(true);
97 opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
98
99 vec![
100 ColumnFamilyDescriptor::new(CF_PROSPECTIVE, opts.clone()),
101 ColumnFamilyDescriptor::new(CF_PROSPECTIVE_INDEX, opts),
102 ]
103 }
104
    // Handle to the task CF. Panics if the CF was not opened alongside the
    // DB — that is a setup bug (see `column_family_descriptors`), not a
    // recoverable runtime condition.
    fn tasks_cf(&self) -> &ColumnFamily {
        self.db
            .cf_handle(CF_PROSPECTIVE)
            .expect("prospective CF must exist")
    }
111
    // Handle to the secondary-index CF. Panics if the CF was not opened
    // alongside the DB — a setup bug, not a runtime condition.
    fn index_cf(&self) -> &ColumnFamily {
        self.db
            .cf_handle(CF_PROSPECTIVE_INDEX)
            .expect("prospective_index CF must exist")
    }
118
    /// Create the store on an already-opened shared DB, running one-time
    /// on-disk migrations first (legacy standalone DBs -> CFs, then due-key
    /// zero-padding).
    ///
    /// `storage_path` is the storage root; the legacy layout lived under
    /// `<storage_path>/prospective/{tasks,index}`.
    pub fn new(db: Arc<DB>, storage_path: &Path) -> Result<Self> {
        let prospective_path = storage_path.join("prospective");
        std::fs::create_dir_all(&prospective_path)?;

        // Fold any pre-CF standalone DBs into the shared DB's CFs before the
        // key-format migration so migrated keys are also repadded below.
        Self::migrate_from_separate_dbs(&prospective_path, &db)?;

        let index_cf = db
            .cf_handle(CF_PROSPECTIVE_INDEX)
            .expect("prospective_index CF must exist")
        migrate_due_key_padding(&db, index_cf)?;

        tracing::info!("Prospective memory store initialized");
        Ok(Self { db })
    }
139
    /// One-time migration from the legacy layout where tasks and index lived
    /// in two standalone RocksDB instances (`prospective/tasks`,
    /// `prospective/index`). Each old DB is copied wholesale into the
    /// matching CF of the shared DB, then the old directory is renamed to
    /// `<name>.pre_cf_migration` as a backup.
    ///
    /// Best-effort: failure to open an old DB or rename its directory is
    /// logged and skipped rather than aborting startup.
    fn migrate_from_separate_dbs(prospective_path: &Path, db: &DB) -> Result<()> {
        let old_dirs: &[(&str, &str)] =
            &[("tasks", CF_PROSPECTIVE), ("index", CF_PROSPECTIVE_INDEX)];

        for (old_name, cf_name) in old_dirs {
            let old_dir = prospective_path.join(old_name);
            // No legacy directory means nothing to migrate for this purpose.
            if !old_dir.is_dir() {
                continue;
            }

            let cf = db
                .cf_handle(cf_name)
                .unwrap_or_else(|| panic!("{cf_name} CF must exist"));
            let old_opts = Options::default();
            // Read-only open: the old DB is never written to, only drained.
            match DB::open_for_read_only(&old_opts, &old_dir, false) {
                Ok(old_db) => {
                    let mut batch = WriteBatch::default();
                    let mut count = 0usize;
                    for item in old_db.iterator(rocksdb::IteratorMode::Start) {
                        if let Ok((key, value)) = item {
                            batch.put_cf(cf, &key, &value);
                            count += 1;
                            // Flush every 10k entries so the batch doesn't
                            // grow unbounded on large legacy DBs.
                            if count % 10_000 == 0 {
                                db.write(std::mem::take(&mut batch))?;
                            }
                        }
                    }
                    if !batch.is_empty() {
                        db.write(batch)?;
                    }
                    drop(old_db);
                    tracing::info!(
                        " prospective/{old_name}: migrated {count} entries to {cf_name} CF"
                    );

                    // Keep the old data as a backup; replace any stale backup
                    // from a previous (interrupted) migration attempt.
                    let backup = prospective_path.join(format!("{old_name}.pre_cf_migration"));
                    if backup.exists() {
                        let _ = std::fs::remove_dir_all(&backup);
                    }
                    if let Err(e) = std::fs::rename(&old_dir, &backup) {
                        tracing::warn!("Could not rename old {old_name} dir: {e}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Could not open old {old_name} DB for migration: {e}");
                }
            }
        }
        Ok(())
    }
193
194 pub fn flush(&self) -> Result<()> {
196 use rocksdb::FlushOptions;
197 let mut flush_opts = FlushOptions::default();
198 flush_opts.set_wait(true);
199 for cf_name in &[CF_PROSPECTIVE, CF_PROSPECTIVE_INDEX] {
200 if let Some(cf) = self.db.cf_handle(cf_name) {
201 self.db
202 .flush_cf_opt(cf, &flush_opts)
203 .map_err(|e| anyhow::anyhow!("Failed to flush {cf_name}: {e}"))?;
204 }
205 }
206 Ok(())
207 }
208
    /// Expose the underlying shared DB handle under a stable name, e.g. for
    /// backup or diagnostics tooling that iterates all databases.
    pub fn databases(&self) -> Vec<(&str, &Arc<DB>)> {
        vec![("prospective_shared", &self.db)]
    }
213
214 pub fn store(&self, task: &ProspectiveTask) -> Result<()> {
216 let key = format!("{}:{}", task.user_id, task.id);
217 let value =
219 serde_json::to_vec(task).context("Failed to serialize prospective task to JSON")?;
220
221 self.db
222 .put_cf(self.tasks_cf(), key.as_bytes(), &value)
223 .context("Failed to store prospective task")?;
224
225 self.update_indices(task)?;
227
228 tracing::debug!(
229 task_id = %task.id,
230 user_id = %task.user_id,
231 trigger = ?task.trigger,
232 "Stored prospective task"
233 );
234
235 Ok(())
236 }
237
    /// Write all secondary index entries for `task` in a single batch:
    /// - `user:{user_id}:{task_id}` -> "1"            (per-user listing)
    /// - `status:{status:?}:{user_id}:{task_id}` -> "1" (status scans)
    /// - `due:{ts:020}:{task_id}` -> user_id           (chronological scans;
    ///   zero-padded so lexicographic order == numeric order)
    /// - `context:{keyword}:{user_id}:{task_id}` -> "1" (keyword triggers)
    fn update_indices(&self, task: &ProspectiveTask) -> Result<()> {
        let mut batch = WriteBatch::default();

        let user_key = format!("user:{}:{}", task.user_id, task.id);
        batch.put_cf(self.index_cf(), user_key.as_bytes(), b"1");

        let status_key = format!("status:{:?}:{}:{}", task.status, task.user_id, task.id);
        batch.put_cf(self.index_cf(), status_key.as_bytes(), b"1");

        // Only time-based triggers have a due timestamp; the value stores the
        // owning user id so due scans can filter without loading the task.
        if let Some(due_at) = task.trigger.due_at() {
            let due_key = format!("due:{:020}:{}", due_at.timestamp(), task.id);
            batch.put_cf(self.index_cf(), due_key.as_bytes(), task.user_id.as_bytes());
        }

        // Keywords are lowercased at index time to make trigger matching
        // case-insensitive.
        if let ProspectiveTrigger::OnContext { ref keywords, .. } = task.trigger {
            for keyword in keywords {
                let kw_key = format!(
                    "context:{}:{}:{}",
                    keyword.to_lowercase(),
                    task.user_id,
                    task.id
                );
                batch.put_cf(self.index_cf(), kw_key.as_bytes(), b"1");
            }
        }

        self.db
            .write(batch)
            .context("Failed to update prospective indices")?;

        Ok(())
    }
276
277 pub fn get(
279 &self,
280 user_id: &str,
281 task_id: &ProspectiveTaskId,
282 ) -> Result<Option<ProspectiveTask>> {
283 let key = format!("{}:{}", user_id, task_id);
284
285 match self.db.get_cf(self.tasks_cf(), key.as_bytes())? {
286 Some(value) => {
287 let task: ProspectiveTask = serde_json::from_slice(&value)
288 .context("Failed to deserialize prospective task from JSON")?;
289 Ok(Some(task))
290 }
291 None => Ok(None),
292 }
293 }
294
    /// Overwrite a task and rebuild its index entries in ONE atomic batch.
    ///
    /// The old task's index entries (derived from its previous status and
    /// trigger) are deleted first, then the new entries are written, so a
    /// status or trigger change never leaves stale index keys behind.
    /// Note: unlike `store`, the due-index entry is only re-added while the
    /// task is still `Pending`, so completed/dismissed tasks drop out of the
    /// due scans.
    pub fn update(&self, task: &ProspectiveTask) -> Result<()> {
        let mut batch = WriteBatch::default();

        // Queue deletions for the previous version's index entries, if any.
        if let Some(old_task) = self.get(&task.user_id, &task.id)? {
            let user_key = format!("user:{}:{}", task.user_id, task.id);
            batch.delete_cf(self.index_cf(), user_key.as_bytes());

            // The status key embeds the OLD status — it may differ from the
            // new one.
            let status_key = format!("status:{:?}:{}:{}", old_task.status, task.user_id, task.id);
            batch.delete_cf(self.index_cf(), status_key.as_bytes());

            if let Some(due_at) = old_task.trigger.due_at() {
                let due_key = format!("due:{:020}:{}", due_at.timestamp(), task.id);
                batch.delete_cf(self.index_cf(), due_key.as_bytes());
            }

            if let ProspectiveTrigger::OnContext { ref keywords, .. } = old_task.trigger {
                for keyword in keywords {
                    let kw_key = format!(
                        "context:{}:{}:{}",
                        keyword.to_lowercase(),
                        task.user_id,
                        task.id
                    );
                    batch.delete_cf(self.index_cf(), kw_key.as_bytes());
                }
            }
        }

        // Write the new task record…
        let key = format!("{}:{}", task.user_id, task.id);
        let value =
            serde_json::to_vec(task).context("Failed to serialize prospective task to JSON")?;
        batch.put_cf(self.tasks_cf(), key.as_bytes(), &value);

        // …and its fresh index entries (same schema as `update_indices`).
        let user_key = format!("user:{}:{}", task.user_id, task.id);
        batch.put_cf(self.index_cf(), user_key.as_bytes(), b"1");

        let status_key = format!("status:{:?}:{}:{}", task.status, task.user_id, task.id);
        batch.put_cf(self.index_cf(), status_key.as_bytes(), b"1");

        // Only still-pending tasks stay visible to the due scans.
        if task.status == ProspectiveTaskStatus::Pending {
            if let Some(due_at) = task.trigger.due_at() {
                let due_key = format!("due:{:020}:{}", due_at.timestamp(), task.id);
                batch.put_cf(self.index_cf(), due_key.as_bytes(), task.user_id.as_bytes());
            }
        }

        if let ProspectiveTrigger::OnContext { ref keywords, .. } = task.trigger {
            for keyword in keywords {
                let kw_key = format!(
                    "context:{}:{}:{}",
                    keyword.to_lowercase(),
                    task.user_id,
                    task.id
                );
                batch.put_cf(self.index_cf(), kw_key.as_bytes(), b"1");
            }
        }

        self.db
            .write(batch)
            .context("Failed to atomically update prospective task")?;

        tracing::debug!(
            task_id = %task.id,
            user_id = %task.user_id,
            status = ?task.status,
            "Updated prospective task (atomic)"
        );

        Ok(())
    }
376
    /// Remove a task and all of its index entries in one atomic batch.
    /// Returns `Ok(false)` when the task does not exist.
    pub fn delete(&self, user_id: &str, task_id: &ProspectiveTaskId) -> Result<bool> {
        // Load the task first: its status and trigger determine which index
        // entries have to be cleaned up.
        let task = match self.get(user_id, task_id)? {
            Some(t) => t,
            None => return Ok(false),
        };

        let mut batch = WriteBatch::default();

        // Primary record.
        let key = format!("{}:{}", user_id, task_id);
        batch.delete_cf(self.tasks_cf(), key.as_bytes());

        // Secondary indices (same schema as `update_indices`).
        let user_key = format!("user:{}:{}", user_id, task_id);
        batch.delete_cf(self.index_cf(), user_key.as_bytes());

        let status_key = format!("status:{:?}:{}:{}", task.status, user_id, task_id);
        batch.delete_cf(self.index_cf(), status_key.as_bytes());

        if let Some(due_at) = task.trigger.due_at() {
            let due_key = format!("due:{:020}:{}", due_at.timestamp(), task_id);
            batch.delete_cf(self.index_cf(), due_key.as_bytes());
        }

        if let ProspectiveTrigger::OnContext { ref keywords, .. } = task.trigger {
            for keyword in keywords {
                let kw_key = format!("context:{}:{}:{}", keyword.to_lowercase(), user_id, task_id);
                batch.delete_cf(self.index_cf(), kw_key.as_bytes());
            }
        }

        self.db
            .write(batch)
            .context("Failed to atomically delete prospective task")?;

        tracing::debug!(task_id = %task_id, user_id = %user_id, "Deleted prospective task");

        Ok(true)
    }
417
418 pub fn list_for_user(
420 &self,
421 user_id: &str,
422 status_filter: Option<ProspectiveTaskStatus>,
423 ) -> Result<Vec<ProspectiveTask>> {
424 let prefix = format!("user:{}:", user_id);
425 let mut tasks = Vec::new();
426
427 for item in self
428 .db
429 .prefix_iterator_cf(self.index_cf(), prefix.as_bytes())
430 {
431 let (key, _) = item.context("Failed to read index entry")?;
432 let key_str = String::from_utf8_lossy(&key);
433
434 if let Some(task_id_str) = key_str.strip_prefix(&prefix) {
436 if let Ok(uuid) = uuid::Uuid::parse_str(task_id_str) {
437 let task_id = ProspectiveTaskId(uuid);
438 if let Some(task) = self.get(user_id, &task_id)? {
439 if let Some(filter) = status_filter {
441 if task.status == filter {
442 tasks.push(task);
443 }
444 } else {
445 tasks.push(task);
446 }
447 }
448 }
449 }
450 }
451
452 tasks.sort_by(|a, b| b.created_at.cmp(&a.created_at));
454
455 Ok(tasks)
456 }
457
458 pub fn get_due_tasks(&self, user_id: &str) -> Result<Vec<ProspectiveTask>> {
465 let now = Utc::now();
466 let now_ts = now.timestamp();
467
468 let mut due_tasks = Vec::new();
469
470 for item in self.db.prefix_iterator_cf(self.index_cf(), b"due:") {
473 let (key, value) = item.context("Failed to read due index")?;
474 let key_str = String::from_utf8_lossy(&key);
475
476 let parts: Vec<&str> = key_str.splitn(3, ':').collect();
478 if parts.len() != 3 {
479 continue;
480 }
481
482 let task_ts: i64 = match parts[1].parse() {
483 Ok(ts) => ts,
484 Err(_) => continue,
485 };
486
487 if task_ts > now_ts {
490 break;
491 }
492
493 let stored_user_id = String::from_utf8_lossy(&value);
495 if stored_user_id != user_id {
496 continue;
497 }
498
499 if let Ok(uuid) = uuid::Uuid::parse_str(parts[2]) {
501 let task_id = ProspectiveTaskId(uuid);
502 if let Some(task) = self.get(user_id, &task_id)? {
503 if task.status == ProspectiveTaskStatus::Pending {
504 due_tasks.push(task);
505 }
506 }
507 }
508 }
509
510 due_tasks.sort_by(|a, b| {
512 let priority_cmp = b.priority.cmp(&a.priority);
513 if priority_cmp != std::cmp::Ordering::Equal {
514 return priority_cmp;
515 }
516 let a_due = a.trigger.due_at().unwrap_or(a.created_at);
517 let b_due = b.trigger.due_at().unwrap_or(b.created_at);
518 a_due.cmp(&b_due)
519 });
520
521 Ok(due_tasks)
522 }
523
524 pub fn get_all_due_tasks(&self) -> Result<Vec<(String, ProspectiveTask)>> {
530 let now_ts = Utc::now().timestamp();
531 let mut due_tasks = Vec::new();
532
533 for item in self.db.prefix_iterator_cf(self.index_cf(), b"due:") {
534 let (key, value) = item.context("Failed to read due index")?;
535 let key_str = String::from_utf8_lossy(&key);
536
537 let parts: Vec<&str> = key_str.splitn(3, ':').collect();
538 if parts.len() != 3 {
539 continue;
540 }
541
542 let task_ts: i64 = match parts[1].parse() {
543 Ok(ts) => ts,
544 Err(_) => continue,
545 };
546
547 if task_ts > now_ts {
548 break;
549 }
550
551 let user_id = String::from_utf8_lossy(&value).to_string();
552
553 if let Ok(uuid) = uuid::Uuid::parse_str(parts[2]) {
554 let task_id = ProspectiveTaskId(uuid);
555 if let Some(task) = self.get(&user_id, &task_id)? {
556 if task.status == ProspectiveTaskStatus::Pending {
557 due_tasks.push((user_id, task));
558 }
559 }
560 }
561 }
562
563 Ok(due_tasks)
564 }
565
566 pub fn check_context_triggers(
575 &self,
576 user_id: &str,
577 context: &str,
578 ) -> Result<Vec<ProspectiveTask>> {
579 let context_lower = context.to_lowercase();
580 let mut matches = Vec::new();
581 let mut seen_ids = std::collections::HashSet::new();
582
583 let pending_tasks = self.list_for_user(user_id, Some(ProspectiveTaskStatus::Pending))?;
585
586 for task in pending_tasks {
587 if seen_ids.contains(&task.id.0) {
588 continue;
589 }
590
591 if let ProspectiveTrigger::OnContext { ref keywords, .. } = task.trigger {
592 let matched = keywords
594 .iter()
595 .any(|kw| context_lower.contains(&kw.to_lowercase()));
596 if matched {
597 seen_ids.insert(task.id.0);
598 matches.push(task);
599 }
600 }
601 }
602
603 matches.sort_by(|a, b| b.priority.cmp(&a.priority));
605
606 Ok(matches)
607 }
608
    /// Context-trigger matching with a semantic (embedding) fallback.
    ///
    /// A pending `OnContext` task matches when any of its keywords occurs in
    /// `context` (scored 1.0), or — failing that — when the cosine
    /// similarity between `context_embedding` and the task's embedding
    /// (stored, or computed on demand via `embed_fn`) meets the trigger's
    /// `threshold`. Results carry their score and are sorted by score desc,
    /// then priority desc.
    pub fn check_context_triggers_semantic<F>(
        &self,
        user_id: &str,
        context: &str,
        context_embedding: &[f32],
        embed_fn: F,
    ) -> Result<Vec<(ProspectiveTask, f32)>>
    where
        F: Fn(&str) -> Option<Vec<f32>>,
    {
        let context_lower = context.to_lowercase();
        let mut matches: Vec<(ProspectiveTask, f32)> = Vec::new();
        let mut seen_ids = std::collections::HashSet::new();

        let pending_tasks = self.list_for_user(user_id, Some(ProspectiveTaskStatus::Pending))?;

        for task in pending_tasks {
            // Defensive dedup in case the listing ever yields duplicates.
            if seen_ids.contains(&task.id.0) {
                continue;
            }

            if let ProspectiveTrigger::OnContext {
                ref keywords,
                threshold,
            } = task.trigger
            {
                let keyword_match = keywords
                    .iter()
                    .any(|kw| context_lower.contains(&kw.to_lowercase()));

                // Exact keyword hit: maximum confidence, skip the embedding
                // path entirely.
                if keyword_match {
                    seen_ids.insert(task.id.0);
                    matches.push((task, 1.0)); continue;
                }

                // Borrow the stored embedding when present; otherwise compute
                // one on the fly (owned) via the caller-supplied embedder.
                // Cow lets both cases share the similarity call below.
                let task_emb = task
                    .embedding
                    .as_deref()
                    .map(|e| std::borrow::Cow::Borrowed(e))
                    .or_else(|| embed_fn(&task.content).map(std::borrow::Cow::Owned));

                if let Some(task_embedding) = task_emb {
                    let similarity = cosine_similarity(context_embedding, &task_embedding);
                    if similarity >= threshold {
                        seen_ids.insert(task.id.0);
                        matches.push((task, similarity));
                    }
                }
            }
        }

        matches.sort_by(|a, b| {
            // Primary: similarity score, descending (total_cmp is NaN-safe).
            let score_cmp = b.1.total_cmp(&a.1);
            if score_cmp != std::cmp::Ordering::Equal {
                return score_cmp;
            }
            // Secondary: task priority, descending.
            b.0.priority.cmp(&a.0.priority)
        });

        Ok(matches)
    }
689
    /// Transition a task from `Pending` to triggered. Returns `Ok(true)`
    /// only when the task existed AND was still pending; any other state is
    /// a no-op returning `Ok(false)`.
    pub fn mark_triggered(&self, user_id: &str, task_id: &ProspectiveTaskId) -> Result<bool> {
        if let Some(mut task) = self.get(user_id, task_id)? {
            if task.status == ProspectiveTaskStatus::Pending {
                task.mark_triggered();
                self.update(&task)?;
                return Ok(true);
            }
        }
        Ok(false)
    }
701
    /// Mark a task dismissed regardless of its current status. Returns
    /// `Ok(false)` only when the task does not exist.
    pub fn mark_dismissed(&self, user_id: &str, task_id: &ProspectiveTaskId) -> Result<bool> {
        if let Some(mut task) = self.get(user_id, task_id)? {
            task.mark_dismissed();
            self.update(&task)?;
            return Ok(true);
        }
        Ok(false)
    }
711
712 pub fn pending_count(&self, user_id: &str) -> Result<usize> {
714 let tasks = self.list_for_user(user_id, Some(ProspectiveTaskStatus::Pending))?;
715 Ok(tasks.len())
716 }
717
718 pub fn find_by_prefix(
723 &self,
724 user_id: &str,
725 id_prefix: &str,
726 ) -> Result<Option<ProspectiveTask>> {
727 let prefix_lower = id_prefix.to_lowercase();
728 let tasks = self.list_for_user(user_id, None)?;
729
730 let matches: Vec<_> = tasks
731 .into_iter()
732 .filter(|t| t.id.0.to_string().to_lowercase().starts_with(&prefix_lower))
733 .collect();
734
735 match matches.len() {
736 0 => Ok(None),
737 1 => Ok(Some(matches.into_iter().next().unwrap())),
738 _ => {
739 tracing::warn!(
741 user_id = %user_id,
742 prefix = %id_prefix,
743 matches = matches.len(),
744 "Multiple reminders match prefix, using first"
745 );
746 Ok(Some(matches.into_iter().next().unwrap()))
747 }
748 }
749 }
750
    /// Delete `Dismissed`/`Expired` tasks older than `older_than_days` for a
    /// single user. Age is measured from `dismissed_at`, falling back to
    /// `created_at` when the task was never dismissed. Returns how many
    /// tasks were removed.
    pub fn cleanup_old_tasks(&self, user_id: &str, older_than_days: i64) -> Result<usize> {
        let cutoff = Utc::now() - chrono::Duration::days(older_than_days);
        let mut deleted = 0;

        let tasks = self.list_for_user(user_id, None)?;
        for task in tasks {
            // Only terminal states are eligible for cleanup; pending and
            // triggered tasks are always kept.
            if matches!(
                task.status,
                ProspectiveTaskStatus::Dismissed | ProspectiveTaskStatus::Expired
            ) {
                let check_time = task.dismissed_at.unwrap_or(task.created_at);
                if check_time < cutoff {
                    self.delete(user_id, &task.id)?;
                    deleted += 1;
                }
            }
        }

        if deleted > 0 {
            tracing::info!(user_id = %user_id, deleted = deleted, "Cleaned up old prospective tasks");
        }

        Ok(deleted)
    }
777}
778
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Opens a throwaway shared DB with every CF this store requires,
    // mirroring the production open path.
    fn open_test_db(path: &Path) -> Arc<DB> {
        let shared_path = path.join("shared");
        std::fs::create_dir_all(&shared_path).unwrap();
        let mut opts = Options::default();
        opts.create_if_missing(true);
        opts.create_missing_column_families(true);
        opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
        let cfs = vec![
            ColumnFamilyDescriptor::new("default", opts.clone()),
            ColumnFamilyDescriptor::new(CF_PROSPECTIVE, opts.clone()),
            ColumnFamilyDescriptor::new(CF_PROSPECTIVE_INDEX, opts.clone()),
        ];
        Arc::new(DB::open_cf_descriptors(&opts, &shared_path, cfs).unwrap())
    }

    // The TempDir is returned alongside the store so it stays alive (and the
    // on-disk DB with it) for the duration of the test.
    fn setup_store() -> (TempDir, ProspectiveStore) {
        let temp_dir = TempDir::new().unwrap();
        let db = open_test_db(temp_dir.path());
        let store = ProspectiveStore::new(db, temp_dir.path()).unwrap();
        (temp_dir, store)
    }

    // Round-trip: a stored task comes back intact and starts out Pending.
    #[test]
    fn test_store_and_get() {
        let (_temp, store) = setup_store();

        let task = ProspectiveTask::new(
            "test-user".to_string(),
            "Remember to push code".to_string(),
            ProspectiveTrigger::AfterDuration {
                seconds: 3600,
                from: Utc::now(),
            },
        );

        store.store(&task).unwrap();

        let retrieved = store.get("test-user", &task.id).unwrap();
        assert!(retrieved.is_some());

        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.content, "Remember to push code");
        assert_eq!(retrieved.status, ProspectiveTaskStatus::Pending);
    }

    // Listing is scoped per user: user-a's tasks never leak into user-b's.
    #[test]
    fn test_list_for_user() {
        let (_temp, store) = setup_store();

        let task1 = ProspectiveTask::new(
            "user-a".to_string(),
            "Task 1".to_string(),
            ProspectiveTrigger::AfterDuration {
                seconds: 3600,
                from: Utc::now(),
            },
        );

        let task2 = ProspectiveTask::new(
            "user-a".to_string(),
            "Task 2".to_string(),
            ProspectiveTrigger::AfterDuration {
                seconds: 7200,
                from: Utc::now(),
            },
        );

        let task3 = ProspectiveTask::new(
            "user-b".to_string(),
            "Task 3".to_string(),
            ProspectiveTrigger::AfterDuration {
                seconds: 3600,
                from: Utc::now(),
            },
        );

        store.store(&task1).unwrap();
        store.store(&task2).unwrap();
        store.store(&task3).unwrap();

        let user_a_tasks = store.list_for_user("user-a", None).unwrap();
        assert_eq!(user_a_tasks.len(), 2);

        let user_b_tasks = store.list_for_user("user-b", None).unwrap();
        assert_eq!(user_b_tasks.len(), 1);
    }

    // Only tasks whose due timestamp has passed are returned by the due scan.
    #[test]
    fn test_due_tasks() {
        let (_temp, store) = setup_store();

        // seconds: 0 from a past instant => already due.
        let past = Utc::now() - chrono::Duration::seconds(100);
        let due_task = ProspectiveTask::new(
            "test-user".to_string(),
            "Due task".to_string(),
            ProspectiveTrigger::AfterDuration {
                seconds: 0,
                from: past,
            },
        );

        // Far-future task must NOT appear in the due list.
        let future_task = ProspectiveTask::new(
            "test-user".to_string(),
            "Future task".to_string(),
            ProspectiveTrigger::AfterDuration {
                seconds: 999999,
                from: Utc::now(),
            },
        );

        store.store(&due_task).unwrap();
        store.store(&future_task).unwrap();

        let due = store.get_due_tasks("test-user").unwrap();
        assert_eq!(due.len(), 1);
        assert_eq!(due[0].content, "Due task");
    }

    // Keyword triggers match case-insensitively ("JWT" context vs "jwt"
    // keyword) and unrelated contexts produce no matches.
    #[test]
    fn test_context_trigger() {
        let (_temp, store) = setup_store();

        let task = ProspectiveTask::new(
            "test-user".to_string(),
            "Check auth token".to_string(),
            ProspectiveTrigger::OnContext {
                keywords: vec![
                    "authentication".to_string(),
                    "token".to_string(),
                    "jwt".to_string(),
                ],
                threshold: 0.7,
            },
        );

        store.store(&task).unwrap();

        let matches = store
            .check_context_triggers("test-user", "I need to fix the JWT token expiry")
            .unwrap();
        assert_eq!(matches.len(), 1);

        let no_matches = store
            .check_context_triggers("test-user", "Let's update the database schema")
            .unwrap();
        assert_eq!(no_matches.len(), 0);
    }

    // Legacy unpadded due keys get rewritten to the zero-padded form, the
    // padded keys sort chronologically, and the migration is idempotent.
    #[test]
    fn test_due_key_migration_and_ordering() {
        let temp_dir = TempDir::new().unwrap();
        let db = open_test_db(temp_dir.path());

        let index_cf = db
            .cf_handle(CF_PROSPECTIVE_INDEX)
            .expect("prospective_index CF must exist");

        // Seed two legacy (unpadded) keys: lexically "10" < "9", i.e. wrong
        // chronological order — exactly the defect the migration fixes.
        let task_id_a = uuid::Uuid::new_v4();
        let task_id_b = uuid::Uuid::new_v4();
        db.put_cf(
            index_cf,
            format!("due:9:{}", task_id_a).as_bytes(),
            b"user-1",
        )
        .unwrap();
        db.put_cf(
            index_cf,
            format!("due:10:{}", task_id_b).as_bytes(),
            b"user-1",
        )
        .unwrap();

        let migrated = migrate_due_key_padding(&db, index_cf).unwrap();
        assert_eq!(migrated, 2);

        // Old keys must be gone…
        assert!(db
            .get_cf(index_cf, format!("due:9:{}", task_id_a).as_bytes())
            .unwrap()
            .is_none());
        assert!(db
            .get_cf(index_cf, format!("due:10:{}", task_id_b).as_bytes())
            .unwrap()
            .is_none());

        // …replaced by their padded counterparts.
        let key_a = format!("due:{:020}:{}", 9_i64, task_id_a);
        let key_b = format!("due:{:020}:{}", 10_i64, task_id_b);
        assert!(db.get_cf(index_cf, key_a.as_bytes()).unwrap().is_some());
        assert!(db.get_cf(index_cf, key_b.as_bytes()).unwrap().is_some());

        assert!(
            key_a < key_b,
            "Padded key for ts=9 should sort before ts=10"
        );

        // Second run is a no-op (idempotence).
        let migrated_again = migrate_due_key_padding(&db, index_cf).unwrap();
        assert_eq!(migrated_again, 0);
    }

    // Status transitions set their timestamps: Pending -> Triggered ->
    // Dismissed.
    #[test]
    fn test_mark_triggered_and_dismissed() {
        let (_temp, store) = setup_store();

        let task = ProspectiveTask::new(
            "test-user".to_string(),
            "Test task".to_string(),
            ProspectiveTrigger::AfterDuration {
                seconds: 0,
                from: Utc::now(),
            },
        );

        store.store(&task).unwrap();

        store.mark_triggered("test-user", &task.id).unwrap();
        let updated = store.get("test-user", &task.id).unwrap().unwrap();
        assert_eq!(updated.status, ProspectiveTaskStatus::Triggered);
        assert!(updated.triggered_at.is_some());

        store.mark_dismissed("test-user", &task.id).unwrap();
        let dismissed = store.get("test-user", &task.id).unwrap().unwrap();
        assert_eq!(dismissed.status, ProspectiveTaskStatus::Dismissed);
        assert!(dismissed.dismissed_at.is_some());
    }

    // delete() removes the record and reports true exactly once.
    #[test]
    fn test_delete() {
        let (_temp, store) = setup_store();

        let task = ProspectiveTask::new(
            "test-user".to_string(),
            "To delete".to_string(),
            ProspectiveTrigger::AfterDuration {
                seconds: 3600,
                from: Utc::now(),
            },
        );

        store.store(&task).unwrap();
        assert!(store.get("test-user", &task.id).unwrap().is_some());

        let deleted = store.delete("test-user", &task.id).unwrap();
        assert!(deleted);
        assert!(store.get("test-user", &task.id).unwrap().is_none());
    }
}