// shodh_memory/memory/prospective.rs
1//! Prospective Memory - Future intentions and reminders (SHO-116)
2//!
3//! Implements the "remembering to remember" capability:
4//! - Time-based triggers (at specific time, after duration)
5//! - Context-based triggers (keyword match, semantic similarity)
6//!
7//! Architecture:
8//! - ProspectiveTask stored in "prospective" column family of shared RocksDB
9//! - Secondary indices in "prospective_index" column family
10//! - Memory with ExperienceType::Intention created for semantic integration
11//! - Uses Hebbian learning for decay (same as regular memories)
12
13use anyhow::{Context, Result};
14use chrono::Utc;
15use rocksdb::{ColumnFamily, ColumnFamilyDescriptor, Options, WriteBatch, DB};
16use std::path::Path;
17use std::sync::Arc;
18
19use super::types::{ProspectiveTask, ProspectiveTaskId, ProspectiveTaskStatus, ProspectiveTrigger};
20
/// Column family for main task storage (key = `{user_id}:{task_id}`, value = JSON task)
const CF_PROSPECTIVE: &str = "prospective";
/// Column family for secondary indices (due dates, status, keyword lookups).
///
/// Key formats written by `update_indices`/`update`:
/// - `user:{user_id}:{task_id}` -> "1"
/// - `status:{status:?}:{user_id}:{task_id}` -> "1"
/// - `due:{timestamp:020}:{task_id}` -> user_id (zero-padded for lex ordering)
/// - `context:{keyword}:{user_id}:{task_id}` -> "1"
const CF_PROSPECTIVE_INDEX: &str = "prospective_index";
25
/// Compute the cosine similarity of two embedding vectors.
///
/// Returns 0.0 for mismatched lengths, empty inputs, or zero-magnitude
/// vectors, so callers never see NaN/inf from the division.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Accumulate dot product and squared magnitudes in a single pass.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    dot / (sq_a.sqrt() * sq_b.sqrt())
}
42
43/// Migrate unpadded `due:{ts}:{id}` keys to zero-padded `due:{:020}:{id}` format.
44///
45/// Prior versions wrote bare timestamps (e.g. `due:1739404800:uuid`), which break
46/// lexicographic ordering (`"9" > "10"`). Zero-padding to 20 digits ensures
47/// lex order = chronological order, enabling early-termination scans.
48fn migrate_due_key_padding(db: &DB, index_cf: &ColumnFamily) -> Result<usize> {
49    let mut batch = WriteBatch::default();
50    let mut count = 0;
51
52    for item in db.prefix_iterator_cf(index_cf, b"due:") {
53        let (key, value) = item.context("Failed to read due index during migration")?;
54        let key_str = std::str::from_utf8(&key).context("Non-UTF8 key in prospective due index")?;
55
56        // Key format: due:{timestamp}:{task_id}
57        let parts: Vec<&str> = key_str.splitn(3, ':').collect();
58        if parts.len() != 3 {
59            continue;
60        }
61
62        // Already padded — nothing to do
63        if parts[1].len() >= 20 {
64            continue;
65        }
66
67        if let Ok(ts) = parts[1].parse::<i64>() {
68            let new_key = format!("due:{:020}:{}", ts, parts[2]);
69            batch.delete_cf(index_cf, &*key);
70            batch.put_cf(index_cf, new_key.as_bytes(), &*value);
71            count += 1;
72        }
73    }
74
75    if count > 0 {
76        db.write(batch)
77            .context("Failed to write migrated prospective due keys")?;
78        tracing::info!(count, "Migrated prospective due keys to zero-padded format");
79    }
80
81    Ok(count)
82}
83
/// Storage and query engine for prospective memory (reminders)
///
/// All state lives in two column families of a shared RocksDB instance;
/// this struct owns no caches, so it is as cheap to clone handles to as
/// the `Arc` it wraps.
pub struct ProspectiveStore {
    /// Shared RocksDB instance with "prospective" and "prospective_index" column families
    db: Arc<DB>,
}
89
90impl ProspectiveStore {
91    /// Return the column family descriptors needed by ProspectiveStore.
92    ///
93    /// Call this when opening the shared DB so the CFs are registered.
94    pub fn column_family_descriptors() -> Vec<ColumnFamilyDescriptor> {
95        let mut opts = Options::default();
96        opts.create_if_missing(true);
97        opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
98
99        vec![
100            ColumnFamilyDescriptor::new(CF_PROSPECTIVE, opts.clone()),
101            ColumnFamilyDescriptor::new(CF_PROSPECTIVE_INDEX, opts),
102        ]
103    }
104
    /// CF accessor for the main task storage column family.
    ///
    /// Panics if the shared DB was opened without the descriptors from
    /// [`Self::column_family_descriptors`] — a setup bug, not a runtime
    /// condition.
    fn tasks_cf(&self) -> &ColumnFamily {
        self.db
            .cf_handle(CF_PROSPECTIVE)
            .expect("prospective CF must exist")
    }
111
    /// CF accessor for the secondary index column family.
    ///
    /// Panics if the shared DB was opened without the descriptors from
    /// [`Self::column_family_descriptors`] — a setup bug, not a runtime
    /// condition.
    fn index_cf(&self) -> &ColumnFamily {
        self.db
            .cf_handle(CF_PROSPECTIVE_INDEX)
            .expect("prospective_index CF must exist")
    }
118
    /// Create a new prospective store backed by the given shared DB.
    ///
    /// The DB must have been opened with the column families returned by
    /// [`column_family_descriptors()`].  On first run after the migration,
    /// data from the old separate `tasks/` and `index/` sub-DBs is copied
    /// into the corresponding CFs and the old directories are renamed.
    pub fn new(db: Arc<DB>, storage_path: &Path) -> Result<Self> {
        // The legacy `prospective/` directory is where the pre-CF sub-DBs
        // lived and where migration backups are written.
        let prospective_path = storage_path.join("prospective");
        std::fs::create_dir_all(&prospective_path)?;

        // One-time copy of data from the legacy separate RocksDB instances.
        Self::migrate_from_separate_dbs(&prospective_path, &db)?;

        // Must run after the data migration: rewrites any pre-padding
        // `due:{ts}:{id}` keys so the due index sorts chronologically,
        // which the early-termination due scans rely on.
        let index_cf = db
            .cf_handle(CF_PROSPECTIVE_INDEX)
            .expect("prospective_index CF must exist");
        migrate_due_key_padding(&db, index_cf)?;

        tracing::info!("Prospective memory store initialized");
        Ok(Self { db })
    }
139
    /// One-time migration: copy data from legacy separate RocksDB instances
    /// (`prospective/tasks/` and `prospective/index/`) into the shared DB's
    /// column families, then rename the old directories so we don't re-migrate.
    ///
    /// Best-effort: a legacy DB that cannot be opened, or a directory that
    /// cannot be renamed, is logged and skipped rather than failing startup.
    fn migrate_from_separate_dbs(prospective_path: &Path, db: &DB) -> Result<()> {
        // (legacy sub-directory name, destination column family)
        let old_dirs: &[(&str, &str)] =
            &[("tasks", CF_PROSPECTIVE), ("index", CF_PROSPECTIVE_INDEX)];

        for (old_name, cf_name) in old_dirs {
            let old_dir = prospective_path.join(old_name);
            // No legacy directory => already migrated (or a fresh install).
            if !old_dir.is_dir() {
                continue;
            }

            let cf = db
                .cf_handle(cf_name)
                .unwrap_or_else(|| panic!("{cf_name} CF must exist"));
            let old_opts = Options::default();
            // Read-only open: the legacy DB is never written again.
            match DB::open_for_read_only(&old_opts, &old_dir, false) {
                Ok(old_db) => {
                    let mut batch = WriteBatch::default();
                    let mut count = 0usize;
                    for item in old_db.iterator(rocksdb::IteratorMode::Start) {
                        if let Ok((key, value)) = item {
                            batch.put_cf(cf, &key, &value);
                            count += 1;
                            // Flush every 10k entries to bound batch memory.
                            if count % 10_000 == 0 {
                                db.write(std::mem::take(&mut batch))?;
                            }
                        }
                    }
                    if !batch.is_empty() {
                        db.write(batch)?;
                    }
                    drop(old_db);
                    tracing::info!(
                        "  prospective/{old_name}: migrated {count} entries to {cf_name} CF"
                    );

                    // Rename (not delete) the legacy dir: keeps a backup and
                    // stops this migration from running on the next start.
                    let backup = prospective_path.join(format!("{old_name}.pre_cf_migration"));
                    if backup.exists() {
                        let _ = std::fs::remove_dir_all(&backup);
                    }
                    if let Err(e) = std::fs::rename(&old_dir, &backup) {
                        tracing::warn!("Could not rename old {old_name} dir: {e}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Could not open old {old_name} DB for migration: {e}");
                }
            }
        }
        Ok(())
    }
193
194    /// Flush all column families to disk (critical for graceful shutdown)
195    pub fn flush(&self) -> Result<()> {
196        use rocksdb::FlushOptions;
197        let mut flush_opts = FlushOptions::default();
198        flush_opts.set_wait(true);
199        for cf_name in &[CF_PROSPECTIVE, CF_PROSPECTIVE_INDEX] {
200            if let Some(cf) = self.db.cf_handle(cf_name) {
201                self.db
202                    .flush_cf_opt(cf, &flush_opts)
203                    .map_err(|e| anyhow::anyhow!("Failed to flush {cf_name}: {e}"))?;
204            }
205        }
206        Ok(())
207    }
208
    /// Get references to all RocksDB databases for backup
    ///
    /// Returns `(label, db)` pairs; a single shared DB holds both
    /// prospective column families.
    pub fn databases(&self) -> Vec<(&str, &Arc<DB>)> {
        vec![("prospective_shared", &self.db)]
    }
213
214    /// Store a new prospective task
215    pub fn store(&self, task: &ProspectiveTask) -> Result<()> {
216        let key = format!("{}:{}", task.user_id, task.id);
217        // Use JSON instead of bincode - handles tagged enums properly and is human-readable
218        let value =
219            serde_json::to_vec(task).context("Failed to serialize prospective task to JSON")?;
220
221        self.db
222            .put_cf(self.tasks_cf(), key.as_bytes(), &value)
223            .context("Failed to store prospective task")?;
224
225        // Update indices
226        self.update_indices(task)?;
227
228        tracing::debug!(
229            task_id = %task.id,
230            user_id = %task.user_id,
231            trigger = ?task.trigger,
232            "Stored prospective task"
233        );
234
235        Ok(())
236    }
237
238    /// Update secondary indices for efficient queries
239    fn update_indices(&self, task: &ProspectiveTask) -> Result<()> {
240        let mut batch = WriteBatch::default();
241
242        // Index by user (for listing user's reminders)
243        let user_key = format!("user:{}:{}", task.user_id, task.id);
244        batch.put_cf(self.index_cf(), user_key.as_bytes(), b"1");
245
246        // Index by status
247        let status_key = format!("status:{:?}:{}:{}", task.status, task.user_id, task.id);
248        batch.put_cf(self.index_cf(), status_key.as_bytes(), b"1");
249
250        // Index by due time (for time-based trigger queries)
251        // Zero-padded to 20 digits for correct lexicographic ordering
252        if let Some(due_at) = task.trigger.due_at() {
253            let due_key = format!("due:{:020}:{}", due_at.timestamp(), task.id);
254            batch.put_cf(self.index_cf(), due_key.as_bytes(), task.user_id.as_bytes());
255        }
256
257        // Index context triggers by keywords
258        if let ProspectiveTrigger::OnContext { ref keywords, .. } = task.trigger {
259            for keyword in keywords {
260                let kw_key = format!(
261                    "context:{}:{}:{}",
262                    keyword.to_lowercase(),
263                    task.user_id,
264                    task.id
265                );
266                batch.put_cf(self.index_cf(), kw_key.as_bytes(), b"1");
267            }
268        }
269
270        self.db
271            .write(batch)
272            .context("Failed to update prospective indices")?;
273
274        Ok(())
275    }
276
277    /// Get a task by ID
278    pub fn get(
279        &self,
280        user_id: &str,
281        task_id: &ProspectiveTaskId,
282    ) -> Result<Option<ProspectiveTask>> {
283        let key = format!("{}:{}", user_id, task_id);
284
285        match self.db.get_cf(self.tasks_cf(), key.as_bytes())? {
286            Some(value) => {
287                let task: ProspectiveTask = serde_json::from_slice(&value)
288                    .context("Failed to deserialize prospective task from JSON")?;
289                Ok(Some(task))
290            }
291            None => Ok(None),
292        }
293    }
294
    /// Update a task (e.g., mark as triggered/dismissed)
    ///
    /// Atomic: reads old task, builds a single WriteBatch containing old-index
    /// deletes + new task write + new index writes, then commits once.
    ///
    /// NOTE(review): the read-modify-write is not guarded against a concurrent
    /// update of the same task — assumed single-writer per task; confirm.
    pub fn update(&self, task: &ProspectiveTask) -> Result<()> {
        let mut batch = WriteBatch::default();

        // 1. Remove old indices (need to read old task to know what to delete)
        // Status/due/context keys derive from `old_task` because they may
        // differ from the new task's values; the user key never changes.
        if let Some(old_task) = self.get(&task.user_id, &task.id)? {
            let user_key = format!("user:{}:{}", task.user_id, task.id);
            batch.delete_cf(self.index_cf(), user_key.as_bytes());

            let status_key = format!("status:{:?}:{}:{}", old_task.status, task.user_id, task.id);
            batch.delete_cf(self.index_cf(), status_key.as_bytes());

            if let Some(due_at) = old_task.trigger.due_at() {
                let due_key = format!("due:{:020}:{}", due_at.timestamp(), task.id);
                batch.delete_cf(self.index_cf(), due_key.as_bytes());
            }

            if let ProspectiveTrigger::OnContext { ref keywords, .. } = old_task.trigger {
                for keyword in keywords {
                    let kw_key = format!(
                        "context:{}:{}:{}",
                        keyword.to_lowercase(),
                        task.user_id,
                        task.id
                    );
                    batch.delete_cf(self.index_cf(), kw_key.as_bytes());
                }
            }
        }

        // 2. Write new task data
        let key = format!("{}:{}", task.user_id, task.id);
        let value =
            serde_json::to_vec(task).context("Failed to serialize prospective task to JSON")?;
        batch.put_cf(self.tasks_cf(), key.as_bytes(), &value);

        // 3. Write new indices
        let user_key = format!("user:{}:{}", task.user_id, task.id);
        batch.put_cf(self.index_cf(), user_key.as_bytes(), b"1");

        let status_key = format!("status:{:?}:{}:{}", task.status, task.user_id, task.id);
        batch.put_cf(self.index_cf(), status_key.as_bytes(), b"1");

        // Only index due time for Pending tasks — triggered/dismissed tasks
        // don't need to appear in due scans (fixes M13: stale due index bloat)
        if task.status == ProspectiveTaskStatus::Pending {
            if let Some(due_at) = task.trigger.due_at() {
                let due_key = format!("due:{:020}:{}", due_at.timestamp(), task.id);
                batch.put_cf(self.index_cf(), due_key.as_bytes(), task.user_id.as_bytes());
            }
        }

        if let ProspectiveTrigger::OnContext { ref keywords, .. } = task.trigger {
            for keyword in keywords {
                let kw_key = format!(
                    "context:{}:{}:{}",
                    keyword.to_lowercase(),
                    task.user_id,
                    task.id
                );
                batch.put_cf(self.index_cf(), kw_key.as_bytes(), b"1");
            }
        }

        // 4. Single atomic commit
        self.db
            .write(batch)
            .context("Failed to atomically update prospective task")?;

        tracing::debug!(
            task_id = %task.id,
            user_id = %task.user_id,
            status = ?task.status,
            "Updated prospective task (atomic)"
        );

        Ok(())
    }
376
    /// Delete a task (atomic: removes task + all indices in single WriteBatch)
    ///
    /// Returns `Ok(false)` when the task does not exist; `Ok(true)` once the
    /// task record and every derived index entry have been removed.
    pub fn delete(&self, user_id: &str, task_id: &ProspectiveTaskId) -> Result<bool> {
        // The stored task is needed to reconstruct its index keys.
        let task = match self.get(user_id, task_id)? {
            Some(t) => t,
            None => return Ok(false),
        };

        let mut batch = WriteBatch::default();

        // Delete task data
        let key = format!("{}:{}", user_id, task_id);
        batch.delete_cf(self.tasks_cf(), key.as_bytes());

        // Delete all indices
        let user_key = format!("user:{}:{}", user_id, task_id);
        batch.delete_cf(self.index_cf(), user_key.as_bytes());

        let status_key = format!("status:{:?}:{}:{}", task.status, user_id, task_id);
        batch.delete_cf(self.index_cf(), status_key.as_bytes());

        // Harmless no-op if the due entry was already removed (e.g. by
        // `update()` when the task left the Pending state).
        if let Some(due_at) = task.trigger.due_at() {
            let due_key = format!("due:{:020}:{}", due_at.timestamp(), task_id);
            batch.delete_cf(self.index_cf(), due_key.as_bytes());
        }

        if let ProspectiveTrigger::OnContext { ref keywords, .. } = task.trigger {
            for keyword in keywords {
                let kw_key = format!("context:{}:{}:{}", keyword.to_lowercase(), user_id, task_id);
                batch.delete_cf(self.index_cf(), kw_key.as_bytes());
            }
        }

        self.db
            .write(batch)
            .context("Failed to atomically delete prospective task")?;

        tracing::debug!(task_id = %task_id, user_id = %user_id, "Deleted prospective task");

        Ok(true)
    }
417
418    /// List all tasks for a user, optionally filtered by status
419    pub fn list_for_user(
420        &self,
421        user_id: &str,
422        status_filter: Option<ProspectiveTaskStatus>,
423    ) -> Result<Vec<ProspectiveTask>> {
424        let prefix = format!("user:{}:", user_id);
425        let mut tasks = Vec::new();
426
427        for item in self
428            .db
429            .prefix_iterator_cf(self.index_cf(), prefix.as_bytes())
430        {
431            let (key, _) = item.context("Failed to read index entry")?;
432            let key_str = String::from_utf8_lossy(&key);
433
434            // Extract task_id from key: user:{user_id}:{task_id}
435            if let Some(task_id_str) = key_str.strip_prefix(&prefix) {
436                if let Ok(uuid) = uuid::Uuid::parse_str(task_id_str) {
437                    let task_id = ProspectiveTaskId(uuid);
438                    if let Some(task) = self.get(user_id, &task_id)? {
439                        // Apply status filter if specified
440                        if let Some(filter) = status_filter {
441                            if task.status == filter {
442                                tasks.push(task);
443                            }
444                        } else {
445                            tasks.push(task);
446                        }
447                    }
448                }
449            }
450        }
451
452        // Sort by created_at descending (newest first)
453        tasks.sort_by(|a, b| b.created_at.cmp(&a.created_at));
454
455        Ok(tasks)
456    }
457
458    /// Get all due time-based tasks for a user
459    ///
460    /// Returns tasks where:
461    /// - Trigger is time-based (AtTime or AfterDuration)
462    /// - Trigger time <= now
463    /// - Status is Pending
464    pub fn get_due_tasks(&self, user_id: &str) -> Result<Vec<ProspectiveTask>> {
465        let now = Utc::now();
466        let now_ts = now.timestamp();
467
468        let mut due_tasks = Vec::new();
469
470        // Scan due index for tasks with due_time <= now
471        // Key format: due:{timestamp}:{task_id}
472        for item in self.db.prefix_iterator_cf(self.index_cf(), b"due:") {
473            let (key, value) = item.context("Failed to read due index")?;
474            let key_str = String::from_utf8_lossy(&key);
475
476            // Parse key: due:{timestamp}:{task_id}
477            let parts: Vec<&str> = key_str.splitn(3, ':').collect();
478            if parts.len() != 3 {
479                continue;
480            }
481
482            let task_ts: i64 = match parts[1].parse() {
483                Ok(ts) => ts,
484                Err(_) => continue,
485            };
486
487            // With zero-padded keys, lexicographic order = chronological order.
488            // All remaining keys are also in the future — stop scanning.
489            if task_ts > now_ts {
490                break;
491            }
492
493            // Check user matches
494            let stored_user_id = String::from_utf8_lossy(&value);
495            if stored_user_id != user_id {
496                continue;
497            }
498
499            // Get task and check status
500            if let Ok(uuid) = uuid::Uuid::parse_str(parts[2]) {
501                let task_id = ProspectiveTaskId(uuid);
502                if let Some(task) = self.get(user_id, &task_id)? {
503                    if task.status == ProspectiveTaskStatus::Pending {
504                        due_tasks.push(task);
505                    }
506                }
507            }
508        }
509
510        // Sort by priority (higher first) then by due time (earliest first)
511        due_tasks.sort_by(|a, b| {
512            let priority_cmp = b.priority.cmp(&a.priority);
513            if priority_cmp != std::cmp::Ordering::Equal {
514                return priority_cmp;
515            }
516            let a_due = a.trigger.due_at().unwrap_or(a.created_at);
517            let b_due = b.trigger.due_at().unwrap_or(b.created_at);
518            a_due.cmp(&b_due)
519        });
520
521        Ok(due_tasks)
522    }
523
524    /// Scan ALL users for due reminders (used by active reminder scheduler).
525    ///
526    /// Returns `(user_id, task)` pairs for all pending tasks whose due time has passed.
527    /// Leverages the zero-padded `due:{timestamp}:{task_id}` index for efficient scanning:
528    /// lexicographic order = chronological, so we stop at the first future timestamp.
529    pub fn get_all_due_tasks(&self) -> Result<Vec<(String, ProspectiveTask)>> {
530        let now_ts = Utc::now().timestamp();
531        let mut due_tasks = Vec::new();
532
533        for item in self.db.prefix_iterator_cf(self.index_cf(), b"due:") {
534            let (key, value) = item.context("Failed to read due index")?;
535            let key_str = String::from_utf8_lossy(&key);
536
537            let parts: Vec<&str> = key_str.splitn(3, ':').collect();
538            if parts.len() != 3 {
539                continue;
540            }
541
542            let task_ts: i64 = match parts[1].parse() {
543                Ok(ts) => ts,
544                Err(_) => continue,
545            };
546
547            if task_ts > now_ts {
548                break;
549            }
550
551            let user_id = String::from_utf8_lossy(&value).to_string();
552
553            if let Ok(uuid) = uuid::Uuid::parse_str(parts[2]) {
554                let task_id = ProspectiveTaskId(uuid);
555                if let Some(task) = self.get(&user_id, &task_id)? {
556                    if task.status == ProspectiveTaskStatus::Pending {
557                        due_tasks.push((user_id, task));
558                    }
559                }
560            }
561        }
562
563        Ok(due_tasks)
564    }
565
566    /// Check for context-triggered reminders based on text content (keyword match only)
567    ///
568    /// Returns tasks where:
569    /// - Trigger is OnContext
570    /// - Any keyword matches the context text
571    /// - Status is Pending
572    ///
573    /// For semantic matching, use `check_context_triggers_semantic` instead.
574    pub fn check_context_triggers(
575        &self,
576        user_id: &str,
577        context: &str,
578    ) -> Result<Vec<ProspectiveTask>> {
579        let context_lower = context.to_lowercase();
580        let mut matches = Vec::new();
581        let mut seen_ids = std::collections::HashSet::new();
582
583        // Get all pending context-based tasks for user
584        let pending_tasks = self.list_for_user(user_id, Some(ProspectiveTaskStatus::Pending))?;
585
586        for task in pending_tasks {
587            if seen_ids.contains(&task.id.0) {
588                continue;
589            }
590
591            if let ProspectiveTrigger::OnContext { ref keywords, .. } = task.trigger {
592                // Check if any keyword matches
593                let matched = keywords
594                    .iter()
595                    .any(|kw| context_lower.contains(&kw.to_lowercase()));
596                if matched {
597                    seen_ids.insert(task.id.0);
598                    matches.push(task);
599                }
600            }
601        }
602
603        // Sort by priority
604        matches.sort_by(|a, b| b.priority.cmp(&a.priority));
605
606        Ok(matches)
607    }
608
    /// Check for context-triggered reminders using both keyword AND semantic matching
    ///
    /// Returns tasks that either:
    /// - Have keyword matches in the context (score = 1.0), OR
    /// - Have semantic similarity above their threshold
    ///
    /// # Arguments
    /// * `user_id` - User to check reminders for
    /// * `context` - Current context text (for keyword matching)
    /// * `context_embedding` - Precomputed embedding of the context
    /// * `embed_fn` - Closure to compute embedding for task content
    ///
    /// # Returns
    /// Vector of (task, score) tuples sorted by score (highest first)
    pub fn check_context_triggers_semantic<F>(
        &self,
        user_id: &str,
        context: &str,
        context_embedding: &[f32],
        embed_fn: F,
    ) -> Result<Vec<(ProspectiveTask, f32)>>
    where
        F: Fn(&str) -> Option<Vec<f32>>,
    {
        let context_lower = context.to_lowercase();
        let mut matches: Vec<(ProspectiveTask, f32)> = Vec::new();
        let mut seen_ids = std::collections::HashSet::new();

        // Get all pending context-based tasks for user
        let pending_tasks = self.list_for_user(user_id, Some(ProspectiveTaskStatus::Pending))?;

        for task in pending_tasks {
            if seen_ids.contains(&task.id.0) {
                continue;
            }

            // `threshold` is the per-task similarity cutoff carried by the trigger.
            if let ProspectiveTrigger::OnContext {
                ref keywords,
                threshold,
            } = task.trigger
            {
                // 1. Check keyword matches first (fast path)
                let keyword_match = keywords
                    .iter()
                    .any(|kw| context_lower.contains(&kw.to_lowercase()));

                if keyword_match {
                    seen_ids.insert(task.id.0);
                    matches.push((task, 1.0)); // Perfect score for keyword match
                    continue;
                }

                // 2. Try semantic matching (prefer cached embedding, fallback to embed_fn)
                // Cow lets us borrow the cached embedding without cloning, yet
                // own a freshly computed one when the cache is empty.
                let task_emb = task
                    .embedding
                    .as_deref()
                    .map(|e| std::borrow::Cow::Borrowed(e))
                    .or_else(|| embed_fn(&task.content).map(std::borrow::Cow::Owned));

                if let Some(task_embedding) = task_emb {
                    let similarity = cosine_similarity(context_embedding, &task_embedding);
                    if similarity >= threshold {
                        seen_ids.insert(task.id.0);
                        matches.push((task, similarity));
                    }
                }
            }
        }

        // Sort by score (highest first), then by priority
        // total_cmp gives a total order over f32 (NaN-safe, no panic).
        matches.sort_by(|a, b| {
            let score_cmp = b.1.total_cmp(&a.1);
            if score_cmp != std::cmp::Ordering::Equal {
                return score_cmp;
            }
            b.0.priority.cmp(&a.0.priority)
        });

        Ok(matches)
    }
689
690    /// Mark a task as triggered
691    pub fn mark_triggered(&self, user_id: &str, task_id: &ProspectiveTaskId) -> Result<bool> {
692        if let Some(mut task) = self.get(user_id, task_id)? {
693            if task.status == ProspectiveTaskStatus::Pending {
694                task.mark_triggered();
695                self.update(&task)?;
696                return Ok(true);
697            }
698        }
699        Ok(false)
700    }
701
702    /// Mark a task as dismissed
703    pub fn mark_dismissed(&self, user_id: &str, task_id: &ProspectiveTaskId) -> Result<bool> {
704        if let Some(mut task) = self.get(user_id, task_id)? {
705            task.mark_dismissed();
706            self.update(&task)?;
707            return Ok(true);
708        }
709        Ok(false)
710    }
711
712    /// Get count of pending tasks for a user
713    pub fn pending_count(&self, user_id: &str) -> Result<usize> {
714        let tasks = self.list_for_user(user_id, Some(ProspectiveTaskStatus::Pending))?;
715        Ok(tasks.len())
716    }
717
718    /// Find a task by ID prefix (for short ID lookups)
719    ///
720    /// Allows users to dismiss reminders using short IDs like "d8cdc580"
721    /// instead of full UUIDs like "d8cdc580-bf96-403a-85c5-57098c7b1786"
722    pub fn find_by_prefix(
723        &self,
724        user_id: &str,
725        id_prefix: &str,
726    ) -> Result<Option<ProspectiveTask>> {
727        let prefix_lower = id_prefix.to_lowercase();
728        let tasks = self.list_for_user(user_id, None)?;
729
730        let matches: Vec<_> = tasks
731            .into_iter()
732            .filter(|t| t.id.0.to_string().to_lowercase().starts_with(&prefix_lower))
733            .collect();
734
735        match matches.len() {
736            0 => Ok(None),
737            1 => Ok(Some(matches.into_iter().next().unwrap())),
738            _ => {
739                // Multiple matches - return the first one but log warning
740                tracing::warn!(
741                    user_id = %user_id,
742                    prefix = %id_prefix,
743                    matches = matches.len(),
744                    "Multiple reminders match prefix, using first"
745                );
746                Ok(Some(matches.into_iter().next().unwrap()))
747            }
748        }
749    }
750
751    /// Cleanup expired/dismissed tasks older than given days
752    pub fn cleanup_old_tasks(&self, user_id: &str, older_than_days: i64) -> Result<usize> {
753        let cutoff = Utc::now() - chrono::Duration::days(older_than_days);
754        let mut deleted = 0;
755
756        let tasks = self.list_for_user(user_id, None)?;
757        for task in tasks {
758            // Only cleanup dismissed or expired tasks
759            if matches!(
760                task.status,
761                ProspectiveTaskStatus::Dismissed | ProspectiveTaskStatus::Expired
762            ) {
763                let check_time = task.dismissed_at.unwrap_or(task.created_at);
764                if check_time < cutoff {
765                    self.delete(user_id, &task.id)?;
766                    deleted += 1;
767                }
768            }
769        }
770
771        if deleted > 0 {
772            tracing::info!(user_id = %user_id, deleted = deleted, "Cleaned up old prospective tasks");
773        }
774
775        Ok(deleted)
776    }
777}
778
779#[cfg(test)]
780mod tests {
781    use super::*;
782    use tempfile::TempDir;
783
    /// Open a fresh shared RocksDB with both prospective column families
    /// registered, mirroring how the production shared DB is opened.
    fn open_test_db(path: &Path) -> Arc<DB> {
        let shared_path = path.join("shared");
        std::fs::create_dir_all(&shared_path).unwrap();
        let mut opts = Options::default();
        opts.create_if_missing(true);
        opts.create_missing_column_families(true);
        opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
        let cfs = vec![
            ColumnFamilyDescriptor::new("default", opts.clone()),
            ColumnFamilyDescriptor::new(CF_PROSPECTIVE, opts.clone()),
            ColumnFamilyDescriptor::new(CF_PROSPECTIVE_INDEX, opts.clone()),
        ];
        Arc::new(DB::open_cf_descriptors(&opts, &shared_path, cfs).unwrap())
    }
798
    /// Build a ProspectiveStore over a throwaway temp directory.
    ///
    /// The TempDir is returned alongside the store so the caller keeps it
    /// alive for the test's duration (dropping it deletes the directory).
    fn setup_store() -> (TempDir, ProspectiveStore) {
        let temp_dir = TempDir::new().unwrap();
        let db = open_test_db(temp_dir.path());
        let store = ProspectiveStore::new(db, temp_dir.path()).unwrap();
        (temp_dir, store)
    }
805
806    #[test]
807    fn test_store_and_get() {
808        let (_temp, store) = setup_store();
809
810        let task = ProspectiveTask::new(
811            "test-user".to_string(),
812            "Remember to push code".to_string(),
813            ProspectiveTrigger::AfterDuration {
814                seconds: 3600,
815                from: Utc::now(),
816            },
817        );
818
819        store.store(&task).unwrap();
820
821        let retrieved = store.get("test-user", &task.id).unwrap();
822        assert!(retrieved.is_some());
823
824        let retrieved = retrieved.unwrap();
825        assert_eq!(retrieved.content, "Remember to push code");
826        assert_eq!(retrieved.status, ProspectiveTaskStatus::Pending);
827    }
828
829    #[test]
830    fn test_list_for_user() {
831        let (_temp, store) = setup_store();
832
833        // Create tasks for two users
834        let task1 = ProspectiveTask::new(
835            "user-a".to_string(),
836            "Task 1".to_string(),
837            ProspectiveTrigger::AfterDuration {
838                seconds: 3600,
839                from: Utc::now(),
840            },
841        );
842
843        let task2 = ProspectiveTask::new(
844            "user-a".to_string(),
845            "Task 2".to_string(),
846            ProspectiveTrigger::AfterDuration {
847                seconds: 7200,
848                from: Utc::now(),
849            },
850        );
851
852        let task3 = ProspectiveTask::new(
853            "user-b".to_string(),
854            "Task 3".to_string(),
855            ProspectiveTrigger::AfterDuration {
856                seconds: 3600,
857                from: Utc::now(),
858            },
859        );
860
861        store.store(&task1).unwrap();
862        store.store(&task2).unwrap();
863        store.store(&task3).unwrap();
864
865        let user_a_tasks = store.list_for_user("user-a", None).unwrap();
866        assert_eq!(user_a_tasks.len(), 2);
867
868        let user_b_tasks = store.list_for_user("user-b", None).unwrap();
869        assert_eq!(user_b_tasks.len(), 1);
870    }
871
872    #[test]
873    fn test_due_tasks() {
874        let (_temp, store) = setup_store();
875
876        // Task that's already due (0 seconds from past)
877        let past = Utc::now() - chrono::Duration::seconds(100);
878        let due_task = ProspectiveTask::new(
879            "test-user".to_string(),
880            "Due task".to_string(),
881            ProspectiveTrigger::AfterDuration {
882                seconds: 0,
883                from: past,
884            },
885        );
886
887        // Task that's not due yet
888        let future_task = ProspectiveTask::new(
889            "test-user".to_string(),
890            "Future task".to_string(),
891            ProspectiveTrigger::AfterDuration {
892                seconds: 999999,
893                from: Utc::now(),
894            },
895        );
896
897        store.store(&due_task).unwrap();
898        store.store(&future_task).unwrap();
899
900        let due = store.get_due_tasks("test-user").unwrap();
901        assert_eq!(due.len(), 1);
902        assert_eq!(due[0].content, "Due task");
903    }
904
905    #[test]
906    fn test_context_trigger() {
907        let (_temp, store) = setup_store();
908
909        let task = ProspectiveTask::new(
910            "test-user".to_string(),
911            "Check auth token".to_string(),
912            ProspectiveTrigger::OnContext {
913                keywords: vec![
914                    "authentication".to_string(),
915                    "token".to_string(),
916                    "jwt".to_string(),
917                ],
918                threshold: 0.7,
919            },
920        );
921
922        store.store(&task).unwrap();
923
924        // Should match on keyword
925        let matches = store
926            .check_context_triggers("test-user", "I need to fix the JWT token expiry")
927            .unwrap();
928        assert_eq!(matches.len(), 1);
929
930        // Should not match
931        let no_matches = store
932            .check_context_triggers("test-user", "Let's update the database schema")
933            .unwrap();
934        assert_eq!(no_matches.len(), 0);
935    }
936
937    #[test]
938    fn test_due_key_migration_and_ordering() {
939        let temp_dir = TempDir::new().unwrap();
940        let db = open_test_db(temp_dir.path());
941
942        let index_cf = db
943            .cf_handle(CF_PROSPECTIVE_INDEX)
944            .expect("prospective_index CF must exist");
945
946        // Write unpadded keys simulating old format
947        let task_id_a = uuid::Uuid::new_v4();
948        let task_id_b = uuid::Uuid::new_v4();
949        // ts_a = 9 (1 digit), ts_b = 10 (2 digits)
950        // Without padding: "due:9:..." > "due:10:..." lexicographically (wrong)
951        db.put_cf(
952            index_cf,
953            format!("due:9:{}", task_id_a).as_bytes(),
954            b"user-1",
955        )
956        .unwrap();
957        db.put_cf(
958            index_cf,
959            format!("due:10:{}", task_id_b).as_bytes(),
960            b"user-1",
961        )
962        .unwrap();
963
964        // Run migration
965        let migrated = migrate_due_key_padding(&db, index_cf).unwrap();
966        assert_eq!(migrated, 2);
967
968        // Verify old keys are gone
969        assert!(db
970            .get_cf(index_cf, format!("due:9:{}", task_id_a).as_bytes())
971            .unwrap()
972            .is_none());
973        assert!(db
974            .get_cf(index_cf, format!("due:10:{}", task_id_b).as_bytes())
975            .unwrap()
976            .is_none());
977
978        // Verify new padded keys exist
979        let key_a = format!("due:{:020}:{}", 9_i64, task_id_a);
980        let key_b = format!("due:{:020}:{}", 10_i64, task_id_b);
981        assert!(db.get_cf(index_cf, key_a.as_bytes()).unwrap().is_some());
982        assert!(db.get_cf(index_cf, key_b.as_bytes()).unwrap().is_some());
983
984        // Verify lexicographic order is now correct: 9 < 10
985        assert!(
986            key_a < key_b,
987            "Padded key for ts=9 should sort before ts=10"
988        );
989
990        // Re-running migration should be a no-op
991        let migrated_again = migrate_due_key_padding(&db, index_cf).unwrap();
992        assert_eq!(migrated_again, 0);
993    }
994
995    #[test]
996    fn test_mark_triggered_and_dismissed() {
997        let (_temp, store) = setup_store();
998
999        let task = ProspectiveTask::new(
1000            "test-user".to_string(),
1001            "Test task".to_string(),
1002            ProspectiveTrigger::AfterDuration {
1003                seconds: 0,
1004                from: Utc::now(),
1005            },
1006        );
1007
1008        store.store(&task).unwrap();
1009
1010        // Mark as triggered
1011        store.mark_triggered("test-user", &task.id).unwrap();
1012        let updated = store.get("test-user", &task.id).unwrap().unwrap();
1013        assert_eq!(updated.status, ProspectiveTaskStatus::Triggered);
1014        assert!(updated.triggered_at.is_some());
1015
1016        // Mark as dismissed
1017        store.mark_dismissed("test-user", &task.id).unwrap();
1018        let dismissed = store.get("test-user", &task.id).unwrap().unwrap();
1019        assert_eq!(dismissed.status, ProspectiveTaskStatus::Dismissed);
1020        assert!(dismissed.dismissed_at.is_some());
1021    }
1022
1023    #[test]
1024    fn test_delete() {
1025        let (_temp, store) = setup_store();
1026
1027        let task = ProspectiveTask::new(
1028            "test-user".to_string(),
1029            "To delete".to_string(),
1030            ProspectiveTrigger::AfterDuration {
1031                seconds: 3600,
1032                from: Utc::now(),
1033            },
1034        );
1035
1036        store.store(&task).unwrap();
1037        assert!(store.get("test-user", &task.id).unwrap().is_some());
1038
1039        let deleted = store.delete("test-user", &task.id).unwrap();
1040        assert!(deleted);
1041        assert!(store.get("test-user", &task.id).unwrap().is_none());
1042    }
1043}