// zeph_orchestration/plan_cache.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Plan template caching for the LLM planner.
5//!
6//! Caches completed `TaskGraph` plans as reusable `PlanTemplate` skeletons.
7//! On subsequent semantically similar goals, retrieves the closest template
8//! and uses a lightweight LLM adaptation call instead of full decomposition.
9
10use blake3;
11use serde::{Deserialize, Serialize};
12use sqlx::SqlitePool;
13use zeph_config::PlanCacheConfig;
14use zeph_llm::provider::{LlmProvider, Message, Role};
15
16use super::dag;
17use super::error::OrchestrationError;
18use super::graph::TaskGraph;
19use super::planner::{PlannerResponse, convert_response_pub};
20use zeph_subagent::SubAgentDef;
21
/// Structural skeleton of a single task, stripped of all runtime state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemplateTask {
    // Human-readable title copied from the source TaskNode.
    pub title: String,
    // Full task description copied from the source TaskNode.
    pub description: String,
    // Optional preferred agent name; omitted from JSON when None.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent_hint: Option<String>,
    // Slugified task_ids of prerequisite tasks; omitted from JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub depends_on: Vec<String>,
    // Optional failure-handling strategy name (e.g. "abort", "retry",
    // "skip", "ask" — the values the adaptation prompt advertises).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub failure_strategy: Option<String>,
    /// Stable kebab-case `task_id` assigned during template extraction.
    pub task_id: String,
}
36
/// Reusable plan skeleton extracted from a successfully completed `TaskGraph`.
///
/// Contains only the structural information (task titles, descriptions,
/// dependencies, agent hints) — all runtime state (status, results,
/// `retry_count`, `assigned_agent`, timestamps) is stripped.
///
/// Serialized to JSON for storage in the `plan_cache` `SQLite` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanTemplate {
    /// Normalized goal text used for exact-match fallback.
    pub goal: String,
    /// Structural task skeleton.
    pub tasks: Vec<TemplateTask>,
}
49
50impl PlanTemplate {
51    /// Extract a `PlanTemplate` from a completed `TaskGraph`.
52    ///
53    /// # Errors
54    ///
55    /// Returns `OrchestrationError::PlanningFailed` if the graph has no tasks.
56    pub fn from_task_graph(graph: &TaskGraph) -> Result<Self, OrchestrationError> {
57        if graph.tasks.is_empty() {
58            return Err(OrchestrationError::PlanningFailed(
59                "cannot cache a plan with zero tasks".into(),
60            ));
61        }
62
63        // Build task_id strings indexed by position for depends_on reconstruction.
64        let id_to_slug: Vec<String> = graph
65            .tasks
66            .iter()
67            .map(|n| slugify_title(&n.title, n.id.as_u32()))
68            .collect();
69
70        let tasks = graph
71            .tasks
72            .iter()
73            .enumerate()
74            .map(|(i, node)| TemplateTask {
75                title: node.title.clone(),
76                description: node.description.clone(),
77                agent_hint: node.agent_hint.clone(),
78                depends_on: node
79                    .depends_on
80                    .iter()
81                    .map(|dep| id_to_slug[dep.index()].clone())
82                    .collect(),
83                failure_strategy: node.failure_strategy.map(|fs| fs.to_string()),
84                task_id: id_to_slug[i].clone(),
85            })
86            .collect();
87
88        Ok(Self {
89            goal: normalize_goal(&graph.goal),
90            tasks,
91        })
92    }
93}
94
/// Normalize goal text: trim + collapse internal whitespace + lowercase.
///
/// Used consistently for hash computation and embedding input so that
/// trivially different goal strings (capitalization, extra spaces) map
/// to the same cache entry.
#[must_use]
pub fn normalize_goal(text: &str) -> String {
    // `split_whitespace` trims and collapses whitespace runs in one pass;
    // lowercasing stays per-char (`char::to_lowercase`) so multi-char
    // expansions are preserved without str-level context sensitivity.
    let collapsed = text.split_whitespace().collect::<Vec<_>>().join(" ");
    collapsed.chars().flat_map(char::to_lowercase).collect()
}
120
121/// Compute a BLAKE3 hex hash of a normalized goal string.
122#[must_use]
123pub fn goal_hash(normalized: &str) -> String {
124    blake3::hash(normalized.as_bytes()).to_hex().to_string()
125}
126
/// Convert a task title + index into a stable kebab-case `task_id` for template use.
fn slugify_title(title: &str, idx: u32) -> String {
    // Split on every non-ASCII-alphanumeric character, drop empty fragments,
    // and join the lowercased words with single dashes. Only ASCII
    // alphanumerics survive, so the slug is guaranteed pure ASCII.
    let slug = title
        .split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|word| !word.is_empty())
        .map(str::to_ascii_lowercase)
        .collect::<Vec<_>>()
        .join("-");

    // Cap at 32 bytes (safe: slug is ASCII-only), then drop any dash left
    // dangling by the cut. The index suffix keeps ids unique per template.
    let capped = slug.get(..32).unwrap_or(&slug).trim_end_matches('-');
    if capped.is_empty() {
        format!("task-{idx}")
    } else {
        format!("{capped}-{idx}")
    }
}
154
/// Serialize an `f32` slice to a `Vec<u8>` BLOB using explicit little-endian encoding.
fn embedding_to_blob(embedding: &[f32]) -> Vec<u8> {
    // 4 bytes per component, appended in order.
    let mut blob = Vec::with_capacity(embedding.len() * 4);
    for value in embedding {
        blob.extend_from_slice(&value.to_le_bytes());
    }
    blob
}
159
160/// Deserialize an `f32` slice from a BLOB using chunk-based little-endian decoding.
161///
162/// Returns `None` and logs a warning if the BLOB length is not a multiple of 4.
163/// Does not require aligned memory — safe for `Vec<u8>` returned by `SQLite`.
164fn blob_to_embedding(blob: &[u8]) -> Option<Vec<f32>> {
165    if !blob.len().is_multiple_of(4) {
166        tracing::warn!(
167            len = blob.len(),
168            "plan cache: embedding blob length not a multiple of 4"
169        );
170        return None;
171    }
172    Some(
173        blob.chunks_exact(4)
174            .map(|chunk| f32::from_le_bytes(chunk.try_into().expect("chunk is exactly 4 bytes")))
175            .collect(),
176    )
177}
178
/// Current wall-clock time as whole seconds since the Unix epoch (0 if the
/// clock reports a pre-epoch time).
fn unix_now() -> i64 {
    let secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());
    // u64 -> i64 wrap cannot occur for any realistic timestamp.
    #[allow(clippy::cast_possible_wrap)]
    {
        secs as i64
    }
}
188
/// Error type for plan cache operations.
#[derive(Debug, thiserror::Error)]
pub enum PlanCacheError {
    /// Underlying `SQLite` query or connection failure.
    #[error("database error: {0}")]
    Database(#[from] sqlx::Error),
    /// Template JSON (de)serialization failure.
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// `PlanTemplate::from_task_graph` rejected the graph (e.g. zero tasks).
    #[error("plan template extraction failed: {0}")]
    Extraction(String),
}
199
/// Plan template cache backed by `SQLite` with in-process cosine similarity search.
///
/// Stores embeddings as BLOB columns and computes cosine similarity in-process
/// (same pattern as `ResponseCache`). Graceful degradation: all failures are
/// logged as WARN and never block the planning critical path.
pub struct PlanCache {
    // Shared connection pool; the `plan_cache` table lives here.
    pool: SqlitePool,
    // Tuning knobs read below: enabled, similarity_threshold, max_templates, ttl_days.
    config: PlanCacheConfig,
}
209
impl PlanCache {
    /// Create a new `PlanCache` and invalidate stale embeddings for the given model.
    ///
    /// # Errors
    ///
    /// Returns `PlanCacheError` if the stale embedding invalidation query fails.
    pub async fn new(
        pool: SqlitePool,
        config: PlanCacheConfig,
        current_embedding_model: &str,
    ) -> Result<Self, PlanCacheError> {
        let cache = Self { pool, config };
        cache
            .invalidate_stale_embeddings(current_embedding_model)
            .await?;
        Ok(cache)
    }

    /// NULL-ify embeddings stored under a different model to prevent cross-model false hits.
    ///
    /// # Errors
    ///
    /// Returns `PlanCacheError::Database` on query failure.
    async fn invalidate_stale_embeddings(&self, current_model: &str) -> Result<(), PlanCacheError> {
        // Rows keep their template/goal data; only the embedding columns are
        // cleared, so exact goal_hash upserts in cache_plan still deduplicate.
        let affected = sqlx::query(
            "UPDATE plan_cache SET embedding = NULL, embedding_model = NULL \
             WHERE embedding IS NOT NULL AND embedding_model != ?",
        )
        .bind(current_model)
        .execute(&self.pool)
        .await?
        .rows_affected();

        if affected > 0 {
            tracing::info!(
                rows = affected,
                current_model,
                "plan cache: invalidated stale embeddings for model change"
            );
        }
        Ok(())
    }

    /// Find the most similar cached plan template for the given goal embedding.
    ///
    /// Fetches all rows with matching `embedding_model`, computes cosine similarity
    /// in-process, and returns the best match if it meets `similarity_threshold`.
    ///
    /// Also updates `last_accessed_at` on a hit.
    ///
    /// # Errors
    ///
    /// Returns `PlanCacheError::Database` on query failure or
    /// `PlanCacheError::Serialization` on template JSON deserialization failure.
    pub async fn find_similar(
        &self,
        goal_embedding: &[f32],
        embedding_model: &str,
    ) -> Result<Option<(PlanTemplate, f32)>, PlanCacheError> {
        // Bounded fetch: eviction keeps the table near max_templates rows, so
        // the in-process scan below stays small.
        let rows: Vec<(String, String, Vec<u8>)> = sqlx::query_as(
            "SELECT id, template, embedding FROM plan_cache \
             WHERE embedding IS NOT NULL AND embedding_model = ? \
             ORDER BY last_accessed_at DESC LIMIT ?",
        )
        .bind(embedding_model)
        .bind(self.config.max_templates)
        .fetch_all(&self.pool)
        .await?;

        // Linear argmax over cosine similarity; undecodable blobs are skipped
        // (blob_to_embedding logs and returns None).
        let mut best_score = -1.0_f32;
        let mut best_id: Option<String> = None;
        let mut best_template_json: Option<String> = None;

        for (id, template_json, blob) in rows {
            if let Some(stored) = blob_to_embedding(&blob) {
                let score = zeph_memory::cosine_similarity(goal_embedding, &stored);
                if score > best_score {
                    best_score = score;
                    best_id = Some(id);
                    best_template_json = Some(template_json);
                }
            }
        }

        if best_score >= self.config.similarity_threshold
            && let (Some(id), Some(json)) = (best_id, best_template_json)
        {
            // Update last_accessed_at on hit.
            let now = unix_now();
            // Best-effort bookkeeping: a failed touch only affects LRU order
            // and stats, so it is logged and ignored rather than failing the hit.
            if let Err(e) = sqlx::query(
                "UPDATE plan_cache SET last_accessed_at = ?, adapted_count = adapted_count + 1 \
                 WHERE id = ?",
            )
            .bind(now)
            .bind(&id)
            .execute(&self.pool)
            .await
            {
                tracing::warn!(error = %e, "plan cache: failed to update last_accessed_at");
            }
            let template: PlanTemplate = serde_json::from_str(&json)?;
            return Ok(Some((template, best_score)));
        }

        Ok(None)
    }

    /// Store a completed plan as a reusable template.
    ///
    /// Extracts a `PlanTemplate` from the `TaskGraph`, serializes it to JSON,
    /// and upserts into `SQLite` via `INSERT ... ON CONFLICT(goal_hash) DO UPDATE`.
    /// Deduplication is enforced by the `UNIQUE` constraint on `goal_hash`.
    ///
    /// # Errors
    ///
    /// Returns `PlanCacheError` on extraction, serialization, or database failure.
    pub async fn cache_plan(
        &self,
        graph: &TaskGraph,
        goal_embedding: &[f32],
        embedding_model: &str,
    ) -> Result<(), PlanCacheError> {
        let template = PlanTemplate::from_task_graph(graph)
            .map_err(|e| PlanCacheError::Extraction(e.to_string()))?;

        let normalized = normalize_goal(&graph.goal);
        let hash = goal_hash(&normalized);
        let template_json = serde_json::to_string(&template)?;
        let task_count = i64::try_from(template.tasks.len()).unwrap_or(i64::MAX);
        let now = unix_now();
        let id = uuid::Uuid::new_v4().to_string();
        let blob = embedding_to_blob(goal_embedding);

        // On goal_hash conflict the row is refreshed in place (newer template,
        // embedding, last_accessed_at) and success_count is incremented; the
        // original id and created_at are retained.
        sqlx::query(
            "INSERT INTO plan_cache \
             (id, goal_hash, goal_text, template, task_count, success_count, adapted_count, \
              embedding, embedding_model, created_at, last_accessed_at) \
             VALUES (?, ?, ?, ?, ?, 1, 0, ?, ?, ?, ?) \
             ON CONFLICT(goal_hash) DO UPDATE SET \
               success_count = success_count + 1, \
               template = excluded.template, \
               task_count = excluded.task_count, \
               embedding = excluded.embedding, \
               embedding_model = excluded.embedding_model, \
               last_accessed_at = excluded.last_accessed_at",
        )
        .bind(&id)
        .bind(&hash)
        .bind(&normalized)
        .bind(&template_json)
        .bind(task_count)
        .bind(&blob)
        .bind(embedding_model)
        .bind(now)
        .bind(now)
        .execute(&self.pool)
        .await?;

        // Evict after inserting to keep within size bounds.
        if let Err(e) = self.evict().await {
            tracing::warn!(error = %e, "plan cache: eviction failed after cache_plan");
        }

        Ok(())
    }

    /// Run TTL + size-cap LRU eviction.
    ///
    /// Phase 1: Delete rows where `last_accessed_at < now - ttl_days * 86400`.
    /// Phase 2: If count exceeds `max_templates`, delete the least-recently-accessed rows.
    ///
    /// Returns the total number of rows deleted.
    ///
    /// # Errors
    ///
    /// Returns `PlanCacheError::Database` on query failure.
    pub async fn evict(&self) -> Result<u32, PlanCacheError> {
        let now = unix_now();
        let ttl_secs = i64::from(self.config.ttl_days) * 86_400;
        // saturating_sub guards against underflow with extreme ttl values.
        let cutoff = now.saturating_sub(ttl_secs);

        let ttl_deleted = sqlx::query("DELETE FROM plan_cache WHERE last_accessed_at < ?")
            .bind(cutoff)
            .execute(&self.pool)
            .await?
            .rows_affected();

        // Count remaining rows.
        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM plan_cache")
            .fetch_one(&self.pool)
            .await?;

        let max = i64::from(self.config.max_templates);
        let lru_deleted = if count > max {
            let excess = count - max;
            // Delete exactly the `excess` rows with the oldest last_accessed_at.
            sqlx::query(
                "DELETE FROM plan_cache WHERE id IN \
                 (SELECT id FROM plan_cache ORDER BY last_accessed_at ASC LIMIT ?)",
            )
            .bind(excess)
            .execute(&self.pool)
            .await?
            .rows_affected()
        } else {
            0
        };

        let total = ttl_deleted + lru_deleted;
        if total > 0 {
            tracing::debug!(ttl_deleted, lru_deleted, "plan cache: eviction complete");
        }
        // Saturate rather than panic on the (unrealistic) u32 overflow.
        Ok(u32::try_from(total).unwrap_or(u32::MAX))
    }
}
424
425/// Wrapper that checks the plan cache before calling the planner.
426///
427/// On a cache hit, calls `adapt_plan` with the cached template and the given
428/// `LlmProvider`. Falls back to full `planner.plan()` on any failure.
429///
430/// # Errors
431///
432/// Returns `OrchestrationError` from the planner on full-decomposition fallback.
433#[allow(clippy::too_many_arguments)]
434pub async fn plan_with_cache<P>(
435    planner: &P,
436    plan_cache: Option<&PlanCache>,
437    provider: &impl LlmProvider,
438    embedding: Option<&[f32]>,
439    embedding_model: &str,
440    goal: &str,
441    available_agents: &[SubAgentDef],
442    max_tasks: u32,
443) -> Result<(TaskGraph, Option<(u64, u64)>), OrchestrationError>
444where
445    P: super::planner::Planner,
446{
447    if let (Some(cache), Some(emb)) = (plan_cache, embedding)
448        && cache.config.enabled
449    {
450        match cache.find_similar(emb, embedding_model).await {
451            Ok(Some((template, score))) => {
452                tracing::info!(
453                    similarity = score,
454                    tasks = template.tasks.len(),
455                    "plan cache hit, adapting template"
456                );
457                match adapt_plan(provider, goal, &template, available_agents, max_tasks).await {
458                    Ok(result) => return Ok(result),
459                    Err(e) => {
460                        tracing::warn!(
461                            error = %e,
462                            "plan cache: adaptation failed, falling back to full decomposition"
463                        );
464                    }
465                }
466            }
467            Ok(None) => {
468                tracing::debug!("plan cache miss");
469            }
470            Err(e) => {
471                tracing::warn!(error = %e, "plan cache: find_similar failed, using full decomposition");
472            }
473        }
474    }
475
476    planner.plan(goal, available_agents).await
477}
478
/// Build an adaptation prompt and call the LLM to produce a `TaskGraph` adapted
/// from a cached template for the new goal.
///
/// Uses `LlmProvider::chat_typed` with the same `PlannerResponse` schema as the
/// full planner, so the existing `convert_response + dag::validate` pipeline applies.
///
/// # Errors
///
/// Returns `OrchestrationError::PlanningFailed` if the LLM call fails or the
/// adapted graph fails DAG validation.
async fn adapt_plan(
    provider: &impl LlmProvider,
    goal: &str,
    template: &PlanTemplate,
    available_agents: &[SubAgentDef],
    max_tasks: u32,
) -> Result<(TaskGraph, Option<(u64, u64)>), OrchestrationError> {
    use zeph_subagent::ToolPolicy;

    // One catalog line per agent so the model knows which agent names and
    // tool sets are actually available.
    let agent_catalog = available_agents
        .iter()
        .map(|a| {
            let tools = match &a.tools {
                ToolPolicy::AllowList(list) => list.join(", "),
                ToolPolicy::DenyList(excluded) => {
                    format!("all except: [{}]", excluded.join(", "))
                }
                ToolPolicy::InheritAll => "all".to_string(),
            };
            format!(
                "- name: \"{}\", description: \"{}\", tools: [{}]",
                a.name, a.description, tools
            )
        })
        .collect::<Vec<_>>()
        .join("\n");

    // Only the task skeleton is serialized here; the template's stored goal
    // is quoted separately in the user message below.
    let template_json = serde_json::to_string(&template.tasks)
        .map_err(|e| OrchestrationError::PlanningFailed(e.to_string()))?;

    let system = format!(
        "You are a task planner. A cached plan template exists for a similar goal. \
         Adapt it for the new goal by adjusting task descriptions and adding or removing \
         tasks as needed. Keep the same JSON structure.\n\n\
         Available agents:\n{agent_catalog}\n\n\
         Rules:\n\
         - Each task must have a unique task_id (short, descriptive, kebab-case: [a-z0-9-]).\n\
         - Specify dependencies using task_id strings in depends_on.\n\
         - Do not create more than {max_tasks} tasks.\n\
         - failure_strategy is optional: \"abort\", \"retry\", \"skip\", \"ask\"."
    );

    let user = format!(
        "New goal:\n{goal}\n\nCached template (for similar goal \"{}\"):\n{template_json}\n\n\
         Adapt the template for the new goal. Return JSON: {{\"tasks\": [...]}}",
        template.goal
    );

    let messages = vec![
        Message::from_legacy(Role::System, system),
        Message::from_legacy(Role::User, user),
    ];

    // chat_typed deserializes the reply directly into PlannerResponse.
    let response: PlannerResponse = provider
        .chat_typed(&messages)
        .await
        .map_err(|e| OrchestrationError::PlanningFailed(e.to_string()))?;

    // NOTE(review): assumed to be (prompt, completion) token counts — confirm
    // against the LlmProvider::last_usage contract.
    let usage = provider.last_usage();

    // Reuse the full planner's conversion, then re-validate the DAG so an
    // adapted plan meets the same structural guarantees as a fresh one.
    let graph = convert_response_pub(response, goal, available_agents, max_tasks)?;

    dag::validate(&graph.tasks, max_tasks as usize)?;

    Ok((graph, usage))
}
555
556#[cfg(test)]
557mod tests {
558    use super::super::graph::{TaskId, TaskNode};
559    use super::*;
560    use zeph_memory::sqlite::SqliteStore;
561
    // Fresh in-memory SQLite pool for each test.
    // NOTE(review): assumes `SqliteStore::new` applies the plan_cache schema — confirm.
    async fn test_pool() -> SqlitePool {
        let store = SqliteStore::new(":memory:").await.unwrap();
        store.pool().clone()
    }
566
    // PlanCache over `pool` with default config and a fixed model name.
    async fn test_cache(pool: SqlitePool) -> PlanCache {
        PlanCache::new(pool, PlanCacheConfig::default(), "test-model")
            .await
            .unwrap()
    }
572
573    fn make_graph(goal: &str, tasks: &[(&str, &str, &[u32])]) -> TaskGraph {
574        let mut graph = TaskGraph::new(goal);
575        for (i, (title, desc, deps)) in tasks.iter().enumerate() {
576            let mut node = TaskNode::new(i as u32, *title, *desc);
577            node.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
578            graph.tasks.push(node);
579        }
580        graph
581    }
582
583    // --- normalize_goal tests ---
584
    // Leading/trailing whitespace is dropped and uppercase is folded.
    #[test]
    fn normalize_trims_and_lowercases() {
        assert_eq!(normalize_goal("  Hello World  "), "hello world");
    }
589
    // Runs of internal whitespace collapse to a single space.
    #[test]
    fn normalize_collapses_internal_whitespace() {
        assert_eq!(normalize_goal("hello   world"), "hello world");
    }
594
    // Empty input maps to the empty string (no panic, no padding).
    #[test]
    fn normalize_empty_string() {
        assert_eq!(normalize_goal(""), "");
    }
599
    // Whitespace-only input normalizes to the empty string.
    #[test]
    fn normalize_whitespace_only() {
        assert_eq!(normalize_goal("   "), "");
    }
604
605    // --- goal_hash tests ---
606
    // Same input must always hash identically (cache-key stability).
    #[test]
    fn goal_hash_is_deterministic() {
        let h1 = goal_hash("deploy service");
        let h2 = goal_hash("deploy service");
        assert_eq!(h1, h2);
    }
613
    // Distinct goals must not collide on their cache key.
    #[test]
    fn goal_hash_differs_for_different_goals() {
        assert_ne!(goal_hash("deploy service"), goal_hash("build artifact"));
    }
618
    // The hex digest is never empty, even for short input.
    #[test]
    fn goal_hash_nonempty() {
        assert!(!goal_hash("goal").is_empty());
    }
623
624    // --- PlanTemplate extraction tests ---
625
    // Zero-task graphs are rejected rather than cached as useless templates.
    #[test]
    fn template_from_empty_graph_returns_error() {
        let graph = TaskGraph::new("goal");
        assert!(PlanTemplate::from_task_graph(&graph).is_err());
    }
631
    // Runtime fields (status, retry_count, assigned_agent) must not leak
    // into the structural template.
    #[test]
    fn template_strips_runtime_fields() {
        use crate::graph::TaskStatus;
        let mut graph = make_graph("goal", &[("Fetch data", "Download it", &[])]);
        graph.tasks[0].status = TaskStatus::Completed;
        graph.tasks[0].retry_count = 3;
        graph.tasks[0].assigned_agent = Some("agent-x".to_string());
        let template = PlanTemplate::from_task_graph(&graph).unwrap();
        // Template only has structural data — no TaskStatus, retry_count, etc.
        assert_eq!(template.tasks[0].title, "Fetch data");
        assert_eq!(template.tasks[0].description, "Download it");
    }
644
    // depends_on indices are rewritten as the slugified task_id strings.
    #[test]
    fn template_preserves_dependencies() {
        let graph = make_graph("goal", &[("Task A", "do A", &[]), ("Task B", "do B", &[0])]);
        let template = PlanTemplate::from_task_graph(&graph).unwrap();
        assert_eq!(template.tasks.len(), 2);
        assert!(template.tasks[0].depends_on.is_empty());
        assert_eq!(template.tasks[1].depends_on.len(), 1);
        assert_eq!(template.tasks[1].depends_on[0], template.tasks[0].task_id);
    }
654
    // Templates must survive the JSON round trip used for SQLite storage.
    #[test]
    fn template_serde_roundtrip() {
        let graph = make_graph("goal", &[("Step one", "do step one", &[])]);
        let template = PlanTemplate::from_task_graph(&graph).unwrap();
        let json = serde_json::to_string(&template).unwrap();
        let restored: PlanTemplate = serde_json::from_str(&json).unwrap();
        assert_eq!(template.tasks[0].title, restored.tasks[0].title);
        assert_eq!(template.goal, restored.goal);
    }
664
665    // --- BLOB serialization tests ---
666
    // f32 -> LE bytes -> f32 must be lossless.
    #[test]
    fn embedding_blob_roundtrip() {
        let embedding = vec![1.0_f32, 0.5, 0.25, -1.0];
        let blob = embedding_to_blob(&embedding);
        let restored = blob_to_embedding(&blob).unwrap();
        assert_eq!(embedding, restored);
    }
674
    // Corrupt blob lengths are rejected with None instead of panicking.
    #[test]
    fn blob_to_embedding_odd_length_returns_none() {
        let bad_blob = vec![0u8; 5]; // not a multiple of 4
        assert!(blob_to_embedding(&bad_blob).is_none());
    }
680
681    // --- PlanCache integration tests ---
682
    // An empty table can never produce a similarity hit.
    #[tokio::test]
    async fn cache_miss_on_empty_cache() {
        let pool = test_pool().await;
        let cache = test_cache(pool).await;
        let result = cache
            .find_similar(&[1.0, 0.0, 0.0], "test-model")
            .await
            .unwrap();
        assert!(result.is_none());
    }
693
    // Storing a plan and querying with the identical embedding must hit
    // with cosine similarity ~1.0.
    #[tokio::test]
    async fn cache_store_and_hit() {
        let pool = test_pool().await;
        let mut config = PlanCacheConfig::default();
        config.similarity_threshold = 0.9;
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        let graph = make_graph("deploy service", &[("Build", "build it", &[])]);
        let embedding = vec![1.0_f32, 0.0, 0.0];
        cache
            .cache_plan(&graph, &embedding, "test-model")
            .await
            .unwrap();

        // Same embedding should hit.
        let result = cache
            .find_similar(&[1.0, 0.0, 0.0], "test-model")
            .await
            .unwrap();
        assert!(result.is_some());
        let (template, score) = result.unwrap();
        assert!((score - 1.0).abs() < 1e-5);
        assert_eq!(template.tasks.len(), 1);
    }
718
    // An orthogonal embedding (cosine 0.0) must stay below the 0.9 threshold.
    #[tokio::test]
    async fn cache_miss_on_dissimilar_goal() {
        let pool = test_pool().await;
        let mut config = PlanCacheConfig::default();
        config.similarity_threshold = 0.9;
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        let graph = make_graph("goal a", &[("Task", "do it", &[])]);
        cache
            .cache_plan(&graph, &[1.0_f32, 0.0, 0.0], "test-model")
            .await
            .unwrap();

        // Orthogonal vector — should not hit at threshold 0.9.
        let result = cache
            .find_similar(&[0.0, 1.0, 0.0], "test-model")
            .await
            .unwrap();
        assert!(result.is_none());
    }
739
    // Re-caching the same goal upserts: still one row, success_count bumped to 2.
    #[tokio::test]
    async fn deduplication_increments_success_count() {
        let pool = test_pool().await;
        let cache = test_cache(pool.clone()).await;

        let graph = make_graph("same goal", &[("Task", "do it", &[])]);
        let emb = vec![1.0_f32, 0.0];

        cache.cache_plan(&graph, &emb, "test-model").await.unwrap();
        cache.cache_plan(&graph, &emb, "test-model").await.unwrap();

        // Only one row due to UNIQUE goal_hash.
        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM plan_cache")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 1);

        let success: i64 = sqlx::query_scalar("SELECT success_count FROM plan_cache")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(success, 2);
    }
764
    // With ttl_days = 0, any row last accessed in the past is expired immediately.
    #[tokio::test]
    async fn eviction_removes_ttl_expired_rows() {
        let pool = test_pool().await;
        let mut config = PlanCacheConfig::default();
        // TTL of 0 days means everything is immediately expired.
        config.ttl_days = 0;
        let cache = PlanCache::new(pool.clone(), config, "test-model")
            .await
            .unwrap();

        // Insert a row by bypassing the API to set last_accessed_at in the past.
        let now = unix_now() - 1;
        sqlx::query(
            "INSERT INTO plan_cache \
             (id, goal_hash, goal_text, template, task_count, created_at, last_accessed_at) \
             VALUES (?, ?, ?, ?, ?, ?, ?)",
        )
        .bind("test-id")
        .bind("hash-1")
        .bind("goal")
        .bind("{\"goal\":\"goal\",\"tasks\":[]}")
        .bind(0_i64)
        .bind(now)
        .bind(now)
        .execute(&pool)
        .await
        .unwrap();

        let deleted = cache.evict().await.unwrap();
        assert!(deleted >= 1);

        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM plan_cache")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 0);
    }
802
    // With max_templates = 2 and 3 fresh rows, exactly one (the LRU) is evicted.
    #[tokio::test]
    async fn eviction_lru_when_over_max() {
        let pool = test_pool().await;
        let mut config = PlanCacheConfig::default();
        config.max_templates = 2;
        config.ttl_days = 365;
        let cache = PlanCache::new(pool.clone(), config, "test-model")
            .await
            .unwrap();

        let now = unix_now();
        // Insert 3 rows with different last_accessed_at, all recent enough to survive TTL.
        for i in 0..3_i64 {
            sqlx::query(
                "INSERT INTO plan_cache \
                 (id, goal_hash, goal_text, template, task_count, created_at, last_accessed_at) \
                 VALUES (?, ?, ?, ?, ?, ?, ?)",
            )
            .bind(format!("id-{i}"))
            .bind(format!("hash-{i}"))
            .bind(format!("goal-{i}"))
            .bind("{\"goal\":\"g\",\"tasks\":[]}")
            .bind(0_i64)
            .bind(now)
            .bind(now + i) // i=0 is least recently accessed, i=2 most recent
            .execute(&pool)
            .await
            .unwrap();
        }

        let deleted = cache.evict().await.unwrap();
        assert_eq!(deleted, 1);

        // The row with smallest last_accessed_at (id-0) should be gone.
        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM plan_cache")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 2);
    }
843
    // Opening the cache under a new embedding model NULLs out embeddings
    // recorded under the old model (the row itself survives).
    #[tokio::test]
    async fn stale_embedding_invalidated_on_new() {
        let pool = test_pool().await;
        let now = unix_now();

        // Insert a row with "old-model" embedding.
        let emb = embedding_to_blob(&[1.0_f32, 0.0]);
        sqlx::query(
            "INSERT INTO plan_cache \
             (id, goal_hash, goal_text, template, task_count, embedding, embedding_model, \
              created_at, last_accessed_at) \
             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
        )
        .bind("id-old")
        .bind("hash-old")
        .bind("goal old")
        .bind("{\"goal\":\"g\",\"tasks\":[]}")
        .bind(0_i64)
        .bind(&emb)
        .bind("old-model")
        .bind(now)
        .bind(now)
        .execute(&pool)
        .await
        .unwrap();

        // Constructing cache with "new-model" should invalidate the old embedding.
        let _cache = PlanCache::new(pool.clone(), PlanCacheConfig::default(), "new-model")
            .await
            .unwrap();

        let model: Option<String> =
            sqlx::query_scalar("SELECT embedding_model FROM plan_cache WHERE id = 'id-old'")
                .fetch_one(&pool)
                .await
                .unwrap();
        assert!(model.is_none(), "stale embedding_model should be NULL");

        let emb_col: Option<Vec<u8>> =
            sqlx::query_scalar("SELECT embedding FROM plan_cache WHERE id = 'id-old'")
                .fetch_one(&pool)
                .await
                .unwrap();
        assert!(emb_col.is_none(), "stale embedding should be NULL");
    }
889
890    #[tokio::test]
891    async fn disabled_cache_not_used_in_plan_with_cache() {
892        use zeph_llm::mock::MockProvider;
893
894        let pool = test_pool().await;
895        let config = PlanCacheConfig::default(); // enabled = false
896        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();
897
898        let graph_json = r#"{"tasks": [
899            {"task_id": "t1", "title": "Task", "description": "do it", "depends_on": []}
900        ]}"#
901        .to_string();
902
903        let provider = MockProvider::with_responses(vec![graph_json.clone()]);
904        use crate::planner::LlmPlanner;
905        use zeph_config::OrchestrationConfig;
906        let planner = LlmPlanner::new(
907            MockProvider::with_responses(vec![graph_json]),
908            &OrchestrationConfig::default(),
909        );
910
911        let (graph, _) = plan_with_cache(
912            &planner,
913            Some(&cache),
914            &provider,
915            Some(&[1.0_f32, 0.0]),
916            "test-model",
917            "do something",
918            &[],
919            20,
920        )
921        .await
922        .unwrap();
923
924        assert_eq!(graph.tasks.len(), 1);
925    }
926
927    #[tokio::test]
928    async fn plan_with_cache_with_none_embedding_skips_cache() {
929        use crate::planner::LlmPlanner;
930        use zeph_config::OrchestrationConfig;
931        use zeph_llm::mock::MockProvider;
932
933        let pool = test_pool().await;
934        let mut config = PlanCacheConfig::default();
935        config.enabled = true;
936        config.similarity_threshold = 0.5;
937        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();
938
939        // Pre-populate cache with a similar goal.
940        let graph = make_graph("deploy service", &[("Build", "build it", &[])]);
941        cache
942            .cache_plan(&graph, &[1.0_f32, 0.0], "test-model")
943            .await
944            .unwrap();
945
946        let graph_json = r#"{"tasks": [
947            {"task_id": "fallback-task-0", "title": "Fallback", "description": "planner fallback", "depends_on": []}
948        ]}"#
949        .to_string();
950
951        let provider = MockProvider::with_responses(vec![graph_json.clone()]);
952        let planner = LlmPlanner::new(
953            MockProvider::with_responses(vec![graph_json]),
954            &OrchestrationConfig::default(),
955        );
956
957        // embedding = None → must skip cache and call planner.
958        let (result_graph, _) = plan_with_cache(
959            &planner,
960            Some(&cache),
961            &provider,
962            None, // no embedding provided
963            "test-model",
964            "deploy service",
965            &[],
966            20,
967        )
968        .await
969        .unwrap();
970
971        assert_eq!(result_graph.tasks[0].title, "Fallback");
972    }
973
974    #[tokio::test]
975    async fn adapt_plan_error_fallback_to_full_decomposition() {
976        use crate::planner::LlmPlanner;
977        use zeph_config::OrchestrationConfig;
978        use zeph_llm::mock::MockProvider;
979
980        let pool = test_pool().await;
981        let mut config = PlanCacheConfig::default();
982        config.enabled = true;
983        config.similarity_threshold = 0.5;
984        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();
985
986        // Pre-populate cache with matching embedding.
987        let graph = make_graph("deploy service", &[("Build", "build it", &[])]);
988        cache
989            .cache_plan(&graph, &[1.0_f32, 0.0], "test-model")
990            .await
991            .unwrap();
992
993        // Provider for adapt_plan returns invalid JSON — adaptation fails.
994        let bad_provider = MockProvider::with_responses(vec!["not valid json".to_string()]);
995
996        // Planner (fallback path) returns a valid response.
997        let fallback_json = r#"{"tasks": [
998            {"task_id": "fallback-0", "title": "Fallback Task", "description": "via planner", "depends_on": []}
999        ]}"#
1000        .to_string();
1001        let planner = LlmPlanner::new(
1002            MockProvider::with_responses(vec![fallback_json]),
1003            &OrchestrationConfig::default(),
1004        );
1005
1006        let (result_graph, _) = plan_with_cache(
1007            &planner,
1008            Some(&cache),
1009            &bad_provider, // adapt_plan will fail with this provider
1010            Some(&[1.0_f32, 0.0]),
1011            "test-model",
1012            "deploy service",
1013            &[],
1014            20,
1015        )
1016        .await
1017        .unwrap();
1018
1019        // Must return planner fallback result, not error.
1020        assert_eq!(result_graph.tasks[0].title, "Fallback Task");
1021    }
1022}