1use blake3;
11use serde::{Deserialize, Serialize};
12use sqlx::SqlitePool;
13use zeph_config::PlanCacheConfig;
14use zeph_llm::provider::{LlmProvider, Message, Role};
15
16use super::dag;
17use super::error::OrchestrationError;
18use super::graph::TaskGraph;
19use super::planner::{PlannerResponse, convert_response_pub};
20use zeph_subagent::SubAgentDef;
21
/// A single task in a cached plan template: only the planner-visible fields,
/// with no runtime state (status, retry count, assigned agent).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemplateTask {
    /// Short human-readable task title.
    pub title: String,
    /// Full task description handed to the executing agent.
    pub description: String,
    /// Optional preferred agent name; omitted from JSON when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent_hint: Option<String>,
    /// Slug `task_id`s of prerequisite tasks; omitted from JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub depends_on: Vec<String>,
    /// Optional failure strategy name ("abort", "retry", "skip", "ask");
    /// omitted from JSON when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub failure_strategy: Option<String>,
    /// Stable kebab-case identifier derived from the title via `slugify_title`.
    pub task_id: String,
}
36
/// A cached, reusable plan: the normalized goal text plus its tasks,
/// stripped of execution state so it can be adapted to similar future goals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanTemplate {
    /// Goal text after `normalize_goal` (trimmed, lowercased, whitespace-collapsed).
    pub goal: String,
    /// Template tasks with slug-based dependency edges.
    pub tasks: Vec<TemplateTask>,
}
49
50impl PlanTemplate {
51 pub fn from_task_graph(graph: &TaskGraph) -> Result<Self, OrchestrationError> {
57 if graph.tasks.is_empty() {
58 return Err(OrchestrationError::PlanningFailed(
59 "cannot cache a plan with zero tasks".into(),
60 ));
61 }
62
63 let id_to_slug: Vec<String> = graph
65 .tasks
66 .iter()
67 .map(|n| slugify_title(&n.title, n.id.as_u32()))
68 .collect();
69
70 let tasks = graph
71 .tasks
72 .iter()
73 .enumerate()
74 .map(|(i, node)| TemplateTask {
75 title: node.title.clone(),
76 description: node.description.clone(),
77 agent_hint: node.agent_hint.clone(),
78 depends_on: node
79 .depends_on
80 .iter()
81 .map(|dep| id_to_slug[dep.index()].clone())
82 .collect(),
83 failure_strategy: node.failure_strategy.map(|fs| fs.to_string()),
84 task_id: id_to_slug[i].clone(),
85 })
86 .collect();
87
88 Ok(Self {
89 goal: normalize_goal(&graph.goal),
90 tasks,
91 })
92 }
93}
94
/// Canonicalizes goal text for hashing and storage: trims leading/trailing
/// whitespace, lowercases every character, and collapses each internal run of
/// whitespace to a single ASCII space.
#[must_use]
pub fn normalize_goal(text: &str) -> String {
    let mut result = String::with_capacity(text.len());
    // `split_whitespace` already trims the ends and skips empty runs, so
    // joining the words with single spaces reproduces the collapsed form.
    for (i, word) in text.split_whitespace().enumerate() {
        if i > 0 {
            result.push(' ');
        }
        // Per-char lowercasing (a char may lowercase to several chars).
        result.extend(word.chars().flat_map(char::to_lowercase));
    }
    result
}
120
/// Returns the BLAKE3 hash of an already-normalized goal string, as lowercase
/// hex. Used as the dedup key for the `goal_hash` column of `plan_cache`.
#[must_use]
pub fn goal_hash(normalized: &str) -> String {
    blake3::hash(normalized.as_bytes()).to_hex().to_string()
}
126
/// Derives a stable kebab-case identifier from a task title, suffixed with the
/// numeric task id to guarantee uniqueness (e.g. "Fetch data" + 0 → "fetch-data-0").
/// Falls back to "task-{idx}" when the title yields no usable characters.
fn slugify_title(title: &str, idx: u32) -> String {
    // Lowercase ASCII alphanumerics pass through; everything else becomes '-'.
    let mut raw = String::with_capacity(title.len());
    for c in title.chars() {
        if c.is_ascii_alphanumeric() {
            raw.push(c.to_ascii_lowercase());
        } else {
            raw.push('-');
        }
    }

    // Collapse runs of '-' by dropping empty segments and re-joining.
    let segments: Vec<&str> = raw.split('-').filter(|seg| !seg.is_empty()).collect();
    let slug = segments.join("-");

    // Cap at 32 bytes; the slug is pure ASCII, so byte slicing is char-safe.
    let capped = if slug.len() > 32 { &slug[..32] } else { slug.as_str() };
    // Truncation may have left a trailing '-'; strip it.
    let capped = capped.trim_end_matches('-');

    if capped.is_empty() {
        format!("task-{idx}")
    } else {
        format!("{capped}-{idx}")
    }
}
154
/// Serializes an embedding vector to a flat little-endian byte blob
/// (4 bytes per f32) for storage in the SQLite `embedding` column.
fn embedding_to_blob(embedding: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(embedding.len() * 4);
    for value in embedding {
        blob.extend_from_slice(&value.to_le_bytes());
    }
    blob
}
159
160fn blob_to_embedding(blob: &[u8]) -> Option<Vec<f32>> {
165 if !blob.len().is_multiple_of(4) {
166 tracing::warn!(
167 len = blob.len(),
168 "plan cache: embedding blob length not a multiple of 4"
169 );
170 return None;
171 }
172 Some(
173 blob.chunks_exact(4)
174 .map(|chunk| f32::from_le_bytes(chunk.try_into().expect("chunk is exactly 4 bytes")))
175 .collect(),
176 )
177}
178
/// Current Unix time in whole seconds, as `i64` for SQLite column binding.
/// A clock before the epoch yields 0 rather than an error.
fn unix_now() -> i64 {
    let elapsed = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    // u64 -> i64 wrap is unreachable for realistic timestamps.
    #[allow(clippy::cast_possible_wrap)]
    let secs = elapsed.as_secs() as i64;
    secs
}
188
/// Errors produced by [`PlanCache`] operations.
#[derive(Debug, thiserror::Error)]
pub enum PlanCacheError {
    /// Underlying SQLite/sqlx failure.
    #[error("database error: {0}")]
    Database(#[from] sqlx::Error),
    /// JSON (de)serialization of a `PlanTemplate` failed.
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// Converting a `TaskGraph` into a `PlanTemplate` failed (e.g. empty graph).
    #[error("plan template extraction failed: {0}")]
    Extraction(String),
}
199
/// SQLite-backed cache of successful plans, looked up by goal-embedding
/// cosine similarity and deduplicated by a hash of the normalized goal text.
pub struct PlanCache {
    // Connection pool for the `plan_cache` table.
    pool: SqlitePool,
    // Tuning knobs used here: `enabled`, `similarity_threshold`,
    // `ttl_days`, `max_templates`.
    config: PlanCacheConfig,
}
209
impl PlanCache {
    /// Creates a cache over `pool` and immediately NULLs any stored embeddings
    /// that were produced by a model other than `current_embedding_model`.
    ///
    /// # Errors
    /// Returns `PlanCacheError::Database` if the invalidation query fails.
    pub async fn new(
        pool: SqlitePool,
        config: PlanCacheConfig,
        current_embedding_model: &str,
    ) -> Result<Self, PlanCacheError> {
        let cache = Self { pool, config };
        cache
            .invalidate_stale_embeddings(current_embedding_model)
            .await?;
        Ok(cache)
    }

    /// Clears `embedding`/`embedding_model` on rows tagged with a different
    /// model, so cosine similarity never compares vectors from two models.
    async fn invalidate_stale_embeddings(&self, current_model: &str) -> Result<(), PlanCacheError> {
        let affected = sqlx::query(
            "UPDATE plan_cache SET embedding = NULL, embedding_model = NULL \
             WHERE embedding IS NOT NULL AND embedding_model != ?",
        )
        .bind(current_model)
        .execute(&self.pool)
        .await?
        .rows_affected();

        if affected > 0 {
            tracing::info!(
                rows = affected,
                current_model,
                "plan cache: invalidated stale embeddings for model change"
            );
        }
        Ok(())
    }

    /// Finds the cached template whose stored goal embedding is most similar
    /// (cosine) to `goal_embedding`, considering only rows embedded with
    /// `embedding_model`. Returns the template and its score when the best
    /// score meets `config.similarity_threshold`, otherwise `None`.
    ///
    /// On a hit, bumps the row's `last_accessed_at` and `adapted_count`
    /// best-effort: a failed bump is logged, not propagated.
    ///
    /// # Errors
    /// Returns `Database` on query failure or `Serialization` if the stored
    /// template JSON cannot be decoded.
    pub async fn find_similar(
        &self,
        goal_embedding: &[f32],
        embedding_model: &str,
    ) -> Result<Option<(PlanTemplate, f32)>, PlanCacheError> {
        // The candidate scan is capped at the `max_templates` most recently
        // used rows; eviction keeps the table near that size anyway.
        let rows: Vec<(String, String, Vec<u8>)> = sqlx::query_as(
            "SELECT id, template, embedding FROM plan_cache \
             WHERE embedding IS NOT NULL AND embedding_model = ? \
             ORDER BY last_accessed_at DESC LIMIT ?",
        )
        .bind(embedding_model)
        .bind(self.config.max_templates)
        .fetch_all(&self.pool)
        .await?;

        // Linear scan for the best cosine score; cosine ranges [-1, 1], so
        // -1.0 is a safe initial sentinel.
        let mut best_score = -1.0_f32;
        let mut best_id: Option<String> = None;
        let mut best_template_json: Option<String> = None;

        for (id, template_json, blob) in rows {
            // Rows with corrupt blobs are skipped (blob_to_embedding warns).
            if let Some(stored) = blob_to_embedding(&blob) {
                let score = zeph_memory::cosine_similarity(goal_embedding, &stored);
                if score > best_score {
                    best_score = score;
                    best_id = Some(id);
                    best_template_json = Some(template_json);
                }
            }
        }

        if best_score >= self.config.similarity_threshold
            && let (Some(id), Some(json)) = (best_id, best_template_json)
        {
            let now = unix_now();
            // Best-effort LRU/stats bump; a failure here should not hide a hit.
            if let Err(e) = sqlx::query(
                "UPDATE plan_cache SET last_accessed_at = ?, adapted_count = adapted_count + 1 \
                 WHERE id = ?",
            )
            .bind(now)
            .bind(&id)
            .execute(&self.pool)
            .await
            {
                tracing::warn!(error = %e, "plan cache: failed to update last_accessed_at");
            }
            let template: PlanTemplate = serde_json::from_str(&json)?;
            return Ok(Some((template, best_score)));
        }

        Ok(None)
    }

    /// Stores a successful plan, deduplicating on the hash of the normalized
    /// goal: a conflicting row is refreshed (template, embedding, timestamps)
    /// and its `success_count` incremented instead of inserting a duplicate.
    /// Eviction runs afterwards, best-effort.
    ///
    /// # Errors
    /// Returns `Extraction` if the graph is empty, `Serialization` if the
    /// template cannot be encoded, or `Database` on insert failure.
    pub async fn cache_plan(
        &self,
        graph: &TaskGraph,
        goal_embedding: &[f32],
        embedding_model: &str,
    ) -> Result<(), PlanCacheError> {
        let template = PlanTemplate::from_task_graph(graph)
            .map_err(|e| PlanCacheError::Extraction(e.to_string()))?;

        let normalized = normalize_goal(&graph.goal);
        let hash = goal_hash(&normalized);
        let template_json = serde_json::to_string(&template)?;
        let task_count = i64::try_from(template.tasks.len()).unwrap_or(i64::MAX);
        let now = unix_now();
        let id = uuid::Uuid::new_v4().to_string();
        let blob = embedding_to_blob(goal_embedding);

        sqlx::query(
            "INSERT INTO plan_cache \
             (id, goal_hash, goal_text, template, task_count, success_count, adapted_count, \
              embedding, embedding_model, created_at, last_accessed_at) \
             VALUES (?, ?, ?, ?, ?, 1, 0, ?, ?, ?, ?) \
             ON CONFLICT(goal_hash) DO UPDATE SET \
             success_count = success_count + 1, \
             template = excluded.template, \
             task_count = excluded.task_count, \
             embedding = excluded.embedding, \
             embedding_model = excluded.embedding_model, \
             last_accessed_at = excluded.last_accessed_at",
        )
        .bind(&id)
        .bind(&hash)
        .bind(&normalized)
        .bind(&template_json)
        .bind(task_count)
        .bind(&blob)
        .bind(embedding_model)
        .bind(now)
        .bind(now)
        .execute(&self.pool)
        .await?;

        // Eviction failure must not fail the write that already succeeded.
        if let Err(e) = self.evict().await {
            tracing::warn!(error = %e, "plan cache: eviction failed after cache_plan");
        }

        Ok(())
    }

    /// Evicts rows in two passes: TTL first (last accessed more than
    /// `ttl_days` ago), then LRU down to `max_templates`. Returns the total
    /// number of rows deleted.
    ///
    /// # Errors
    /// Returns `Database` if any delete or count query fails.
    pub async fn evict(&self) -> Result<u32, PlanCacheError> {
        let now = unix_now();
        let ttl_secs = i64::from(self.config.ttl_days) * 86_400;
        let cutoff = now.saturating_sub(ttl_secs);

        let ttl_deleted = sqlx::query("DELETE FROM plan_cache WHERE last_accessed_at < ?")
            .bind(cutoff)
            .execute(&self.pool)
            .await?
            .rows_affected();

        // Count AFTER the TTL pass so the LRU pass only removes true excess.
        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM plan_cache")
            .fetch_one(&self.pool)
            .await?;

        let max = i64::from(self.config.max_templates);
        let lru_deleted = if count > max {
            let excess = count - max;
            sqlx::query(
                "DELETE FROM plan_cache WHERE id IN \
                 (SELECT id FROM plan_cache ORDER BY last_accessed_at ASC LIMIT ?)",
            )
            .bind(excess)
            .execute(&self.pool)
            .await?
            .rows_affected()
        } else {
            0
        };

        let total = ttl_deleted + lru_deleted;
        if total > 0 {
            tracing::debug!(ttl_deleted, lru_deleted, "plan cache: eviction complete");
        }
        Ok(u32::try_from(total).unwrap_or(u32::MAX))
    }
}
424
425#[allow(clippy::too_many_arguments)]
434pub async fn plan_with_cache<P>(
435 planner: &P,
436 plan_cache: Option<&PlanCache>,
437 provider: &impl LlmProvider,
438 embedding: Option<&[f32]>,
439 embedding_model: &str,
440 goal: &str,
441 available_agents: &[SubAgentDef],
442 max_tasks: u32,
443) -> Result<(TaskGraph, Option<(u64, u64)>), OrchestrationError>
444where
445 P: super::planner::Planner,
446{
447 if let (Some(cache), Some(emb)) = (plan_cache, embedding)
448 && cache.config.enabled
449 {
450 match cache.find_similar(emb, embedding_model).await {
451 Ok(Some((template, score))) => {
452 tracing::info!(
453 similarity = score,
454 tasks = template.tasks.len(),
455 "plan cache hit, adapting template"
456 );
457 match adapt_plan(provider, goal, &template, available_agents, max_tasks).await {
458 Ok(result) => return Ok(result),
459 Err(e) => {
460 tracing::warn!(
461 error = %e,
462 "plan cache: adaptation failed, falling back to full decomposition"
463 );
464 }
465 }
466 }
467 Ok(None) => {
468 tracing::debug!("plan cache miss");
469 }
470 Err(e) => {
471 tracing::warn!(error = %e, "plan cache: find_similar failed, using full decomposition");
472 }
473 }
474 }
475
476 planner.plan(goal, available_agents).await
477}
478
/// Asks the LLM to adapt a cached template's tasks to a new goal, then
/// converts and validates the response exactly like a fresh planning pass.
///
/// Returns the adapted graph plus the provider's reported token usage
/// (prompt, completion), when available.
///
/// # Errors
/// Returns `OrchestrationError::PlanningFailed` if template serialization or
/// the LLM call fails; propagates conversion and DAG-validation errors.
async fn adapt_plan(
    provider: &impl LlmProvider,
    goal: &str,
    template: &PlanTemplate,
    available_agents: &[SubAgentDef],
    max_tasks: u32,
) -> Result<(TaskGraph, Option<(u64, u64)>), OrchestrationError> {
    use zeph_subagent::ToolPolicy;

    // Render the roster of agents the model may assign tasks to, one line per
    // agent with its tool policy spelled out.
    let agent_catalog = available_agents
        .iter()
        .map(|a| {
            let tools = match &a.tools {
                ToolPolicy::AllowList(list) => list.join(", "),
                ToolPolicy::DenyList(excluded) => {
                    format!("all except: [{}]", excluded.join(", "))
                }
                ToolPolicy::InheritAll => "all".to_string(),
            };
            format!(
                "- name: \"{}\", description: \"{}\", tools: [{}]",
                a.name, a.description, tools
            )
        })
        .collect::<Vec<_>>()
        .join("\n");

    // Only the task list is shown to the model; the template's goal is quoted
    // separately in the user message.
    let template_json = serde_json::to_string(&template.tasks)
        .map_err(|e| OrchestrationError::PlanningFailed(e.to_string()))?;

    let system = format!(
        "You are a task planner. A cached plan template exists for a similar goal. \
         Adapt it for the new goal by adjusting task descriptions and adding or removing \
         tasks as needed. Keep the same JSON structure.\n\n\
         Available agents:\n{agent_catalog}\n\n\
         Rules:\n\
         - Each task must have a unique task_id (short, descriptive, kebab-case: [a-z0-9-]).\n\
         - Specify dependencies using task_id strings in depends_on.\n\
         - Do not create more than {max_tasks} tasks.\n\
         - failure_strategy is optional: \"abort\", \"retry\", \"skip\", \"ask\"."
    );

    let user = format!(
        "New goal:\n{goal}\n\nCached template (for similar goal \"{}\"):\n{template_json}\n\n\
         Adapt the template for the new goal. Return JSON: {{\"tasks\": [...]}}",
        template.goal
    );

    let messages = vec![
        Message::from_legacy(Role::System, system),
        Message::from_legacy(Role::User, user),
    ];

    // Typed chat: the provider must return JSON deserializable as PlannerResponse.
    let response: PlannerResponse = provider
        .chat_typed(&messages)
        .await
        .map_err(|e| OrchestrationError::PlanningFailed(e.to_string()))?;

    let usage = provider.last_usage();

    let graph = convert_response_pub(response, goal, available_agents, max_tasks)?;

    // Reject cycles / oversized plans the same way a fresh plan would be.
    dag::validate(&graph.tasks, max_tasks as usize)?;

    Ok((graph, usage))
}
555
#[cfg(test)]
mod tests {
    use super::super::graph::{TaskId, TaskNode};
    use super::*;
    use zeph_memory::sqlite::SqliteStore;

    // In-memory SQLite pool with the store's schema applied.
    async fn test_pool() -> SqlitePool {
        let store = SqliteStore::new(":memory:").await.unwrap();
        store.pool().clone()
    }

    // Cache with default config, registered under model "test-model".
    async fn test_cache(pool: SqlitePool) -> PlanCache {
        PlanCache::new(pool, PlanCacheConfig::default(), "test-model")
            .await
            .unwrap()
    }

    // Builds a TaskGraph from (title, description, dependency-indices) triples;
    // task ids follow slice position.
    fn make_graph(goal: &str, tasks: &[(&str, &str, &[u32])]) -> TaskGraph {
        let mut graph = TaskGraph::new(goal);
        for (i, (title, desc, deps)) in tasks.iter().enumerate() {
            let mut node = TaskNode::new(i as u32, *title, *desc);
            node.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
            graph.tasks.push(node);
        }
        graph
    }

    // --- normalize_goal -----------------------------------------------------

    #[test]
    fn normalize_trims_and_lowercases() {
        assert_eq!(normalize_goal("  Hello World  "), "hello world");
    }

    #[test]
    fn normalize_collapses_internal_whitespace() {
        assert_eq!(normalize_goal("hello    world"), "hello world");
    }

    #[test]
    fn normalize_empty_string() {
        assert_eq!(normalize_goal(""), "");
    }

    #[test]
    fn normalize_whitespace_only() {
        assert_eq!(normalize_goal("   "), "");
    }

    // --- goal_hash ----------------------------------------------------------

    #[test]
    fn goal_hash_is_deterministic() {
        let h1 = goal_hash("deploy service");
        let h2 = goal_hash("deploy service");
        assert_eq!(h1, h2);
    }

    #[test]
    fn goal_hash_differs_for_different_goals() {
        assert_ne!(goal_hash("deploy service"), goal_hash("build artifact"));
    }

    #[test]
    fn goal_hash_nonempty() {
        assert!(!goal_hash("goal").is_empty());
    }

    // --- PlanTemplate extraction -------------------------------------------

    #[test]
    fn template_from_empty_graph_returns_error() {
        let graph = TaskGraph::new("goal");
        assert!(PlanTemplate::from_task_graph(&graph).is_err());
    }

    #[test]
    fn template_strips_runtime_fields() {
        use crate::graph::TaskStatus;
        let mut graph = make_graph("goal", &[("Fetch data", "Download it", &[])]);
        // Runtime-only state must not survive templating.
        graph.tasks[0].status = TaskStatus::Completed;
        graph.tasks[0].retry_count = 3;
        graph.tasks[0].assigned_agent = Some("agent-x".to_string());
        let template = PlanTemplate::from_task_graph(&graph).unwrap();
        assert_eq!(template.tasks[0].title, "Fetch data");
        assert_eq!(template.tasks[0].description, "Download it");
    }

    #[test]
    fn template_preserves_dependencies() {
        let graph = make_graph("goal", &[("Task A", "do A", &[]), ("Task B", "do B", &[0])]);
        let template = PlanTemplate::from_task_graph(&graph).unwrap();
        assert_eq!(template.tasks.len(), 2);
        assert!(template.tasks[0].depends_on.is_empty());
        assert_eq!(template.tasks[1].depends_on.len(), 1);
        // Dependency edges are rewritten from numeric ids to slug task_ids.
        assert_eq!(template.tasks[1].depends_on[0], template.tasks[0].task_id);
    }

    #[test]
    fn template_serde_roundtrip() {
        let graph = make_graph("goal", &[("Step one", "do step one", &[])]);
        let template = PlanTemplate::from_task_graph(&graph).unwrap();
        let json = serde_json::to_string(&template).unwrap();
        let restored: PlanTemplate = serde_json::from_str(&json).unwrap();
        assert_eq!(template.tasks[0].title, restored.tasks[0].title);
        assert_eq!(template.goal, restored.goal);
    }

    // --- embedding blob codec ----------------------------------------------

    #[test]
    fn embedding_blob_roundtrip() {
        let embedding = vec![1.0_f32, 0.5, 0.25, -1.0];
        let blob = embedding_to_blob(&embedding);
        let restored = blob_to_embedding(&blob).unwrap();
        assert_eq!(embedding, restored);
    }

    #[test]
    fn blob_to_embedding_odd_length_returns_none() {
        let bad_blob = vec![0u8; 5];
        assert!(blob_to_embedding(&bad_blob).is_none());
    }

    // --- PlanCache behavior -------------------------------------------------

    #[tokio::test]
    async fn cache_miss_on_empty_cache() {
        let pool = test_pool().await;
        let cache = test_cache(pool).await;
        let result = cache
            .find_similar(&[1.0, 0.0, 0.0], "test-model")
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_store_and_hit() {
        let pool = test_pool().await;
        let mut config = PlanCacheConfig::default();
        config.similarity_threshold = 0.9;
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        let graph = make_graph("deploy service", &[("Build", "build it", &[])]);
        let embedding = vec![1.0_f32, 0.0, 0.0];
        cache
            .cache_plan(&graph, &embedding, "test-model")
            .await
            .unwrap();

        // Identical embedding => cosine similarity ~1.0, above the threshold.
        let result = cache
            .find_similar(&[1.0, 0.0, 0.0], "test-model")
            .await
            .unwrap();
        assert!(result.is_some());
        let (template, score) = result.unwrap();
        assert!((score - 1.0).abs() < 1e-5);
        assert_eq!(template.tasks.len(), 1);
    }

    #[tokio::test]
    async fn cache_miss_on_dissimilar_goal() {
        let pool = test_pool().await;
        let mut config = PlanCacheConfig::default();
        config.similarity_threshold = 0.9;
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        let graph = make_graph("goal a", &[("Task", "do it", &[])]);
        cache
            .cache_plan(&graph, &[1.0_f32, 0.0, 0.0], "test-model")
            .await
            .unwrap();

        // Orthogonal embedding => cosine similarity 0.0, below 0.9.
        let result = cache
            .find_similar(&[0.0, 1.0, 0.0], "test-model")
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn deduplication_increments_success_count() {
        let pool = test_pool().await;
        let cache = test_cache(pool.clone()).await;

        let graph = make_graph("same goal", &[("Task", "do it", &[])]);
        let emb = vec![1.0_f32, 0.0];

        // Same normalized goal twice => ON CONFLICT(goal_hash) upsert path.
        cache.cache_plan(&graph, &emb, "test-model").await.unwrap();
        cache.cache_plan(&graph, &emb, "test-model").await.unwrap();

        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM plan_cache")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 1);

        let success: i64 = sqlx::query_scalar("SELECT success_count FROM plan_cache")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(success, 2);
    }

    #[tokio::test]
    async fn eviction_removes_ttl_expired_rows() {
        let pool = test_pool().await;
        let mut config = PlanCacheConfig::default();
        // ttl_days = 0 means anything accessed before "now" is expired.
        config.ttl_days = 0;
        let cache = PlanCache::new(pool.clone(), config, "test-model")
            .await
            .unwrap();

        // Insert a row whose last access is one second in the past.
        let now = unix_now() - 1;
        sqlx::query(
            "INSERT INTO plan_cache \
             (id, goal_hash, goal_text, template, task_count, created_at, last_accessed_at) \
             VALUES (?, ?, ?, ?, ?, ?, ?)",
        )
        .bind("test-id")
        .bind("hash-1")
        .bind("goal")
        .bind("{\"goal\":\"goal\",\"tasks\":[]}")
        .bind(0_i64)
        .bind(now)
        .bind(now)
        .execute(&pool)
        .await
        .unwrap();

        let deleted = cache.evict().await.unwrap();
        assert!(deleted >= 1);

        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM plan_cache")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn eviction_lru_when_over_max() {
        let pool = test_pool().await;
        let mut config = PlanCacheConfig::default();
        config.max_templates = 2;
        config.ttl_days = 365;
        let cache = PlanCache::new(pool.clone(), config, "test-model")
            .await
            .unwrap();

        let now = unix_now();
        // Three rows with strictly increasing last_accessed_at: id-0 is LRU.
        for i in 0..3_i64 {
            sqlx::query(
                "INSERT INTO plan_cache \
                 (id, goal_hash, goal_text, template, task_count, created_at, last_accessed_at) \
                 VALUES (?, ?, ?, ?, ?, ?, ?)",
            )
            .bind(format!("id-{i}"))
            .bind(format!("hash-{i}"))
            .bind(format!("goal-{i}"))
            .bind("{\"goal\":\"g\",\"tasks\":[]}")
            .bind(0_i64)
            .bind(now)
            .bind(now + i)
            .execute(&pool)
            .await
            .unwrap();
        }

        let deleted = cache.evict().await.unwrap();
        assert_eq!(deleted, 1);

        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM plan_cache")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn stale_embedding_invalidated_on_new() {
        let pool = test_pool().await;
        let now = unix_now();

        // Seed a row embedded with "old-model" before the cache is constructed.
        let emb = embedding_to_blob(&[1.0_f32, 0.0]);
        sqlx::query(
            "INSERT INTO plan_cache \
             (id, goal_hash, goal_text, template, task_count, embedding, embedding_model, \
              created_at, last_accessed_at) \
             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
        )
        .bind("id-old")
        .bind("hash-old")
        .bind("goal old")
        .bind("{\"goal\":\"g\",\"tasks\":[]}")
        .bind(0_i64)
        .bind(&emb)
        .bind("old-model")
        .bind(now)
        .bind(now)
        .execute(&pool)
        .await
        .unwrap();

        // Constructing with a different model must NULL the stale embedding.
        let _cache = PlanCache::new(pool.clone(), PlanCacheConfig::default(), "new-model")
            .await
            .unwrap();

        let model: Option<String> =
            sqlx::query_scalar("SELECT embedding_model FROM plan_cache WHERE id = 'id-old'")
                .fetch_one(&pool)
                .await
                .unwrap();
        assert!(model.is_none(), "stale embedding_model should be NULL");

        let emb_col: Option<Vec<u8>> =
            sqlx::query_scalar("SELECT embedding FROM plan_cache WHERE id = 'id-old'")
                .fetch_one(&pool)
                .await
                .unwrap();
        assert!(emb_col.is_none(), "stale embedding should be NULL");
    }

    // --- plan_with_cache integration ----------------------------------------

    #[tokio::test]
    async fn disabled_cache_not_used_in_plan_with_cache() {
        use zeph_llm::mock::MockProvider;

        let pool = test_pool().await;
        // NOTE(review): relies on PlanCacheConfig::default() having
        // enabled = false, so the cache branch is skipped entirely.
        let config = PlanCacheConfig::default();
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        let graph_json = r#"{"tasks": [
            {"task_id": "t1", "title": "Task", "description": "do it", "depends_on": []}
        ]}"#
        .to_string();

        let provider = MockProvider::with_responses(vec![graph_json.clone()]);
        use crate::planner::LlmPlanner;
        use zeph_config::OrchestrationConfig;
        let planner = LlmPlanner::new(
            MockProvider::with_responses(vec![graph_json]),
            &OrchestrationConfig::default(),
        );

        let (graph, _) = plan_with_cache(
            &planner,
            Some(&cache),
            &provider,
            Some(&[1.0_f32, 0.0]),
            "test-model",
            "do something",
            &[],
            20,
        )
        .await
        .unwrap();

        assert_eq!(graph.tasks.len(), 1);
    }

    #[tokio::test]
    async fn plan_with_cache_with_none_embedding_skips_cache() {
        use crate::planner::LlmPlanner;
        use zeph_config::OrchestrationConfig;
        use zeph_llm::mock::MockProvider;

        let pool = test_pool().await;
        let mut config = PlanCacheConfig::default();
        config.enabled = true;
        config.similarity_threshold = 0.5;
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        // Seed a cached plan that WOULD match if an embedding were supplied.
        let graph = make_graph("deploy service", &[("Build", "build it", &[])]);
        cache
            .cache_plan(&graph, &[1.0_f32, 0.0], "test-model")
            .await
            .unwrap();

        let graph_json = r#"{"tasks": [
            {"task_id": "fallback-task-0", "title": "Fallback", "description": "planner fallback", "depends_on": []}
        ]}"#
        .to_string();

        let provider = MockProvider::with_responses(vec![graph_json.clone()]);
        let planner = LlmPlanner::new(
            MockProvider::with_responses(vec![graph_json]),
            &OrchestrationConfig::default(),
        );

        // embedding = None => cache skipped, planner output used.
        let (result_graph, _) = plan_with_cache(
            &planner,
            Some(&cache),
            &provider,
            None,
            "test-model",
            "deploy service",
            &[],
            20,
        )
        .await
        .unwrap();

        assert_eq!(result_graph.tasks[0].title, "Fallback");
    }

    #[tokio::test]
    async fn adapt_plan_error_fallback_to_full_decomposition() {
        use crate::planner::LlmPlanner;
        use zeph_config::OrchestrationConfig;
        use zeph_llm::mock::MockProvider;

        let pool = test_pool().await;
        let mut config = PlanCacheConfig::default();
        config.enabled = true;
        config.similarity_threshold = 0.5;
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        // Cached plan guarantees a cache hit for the matching embedding.
        let graph = make_graph("deploy service", &[("Build", "build it", &[])]);
        cache
            .cache_plan(&graph, &[1.0_f32, 0.0], "test-model")
            .await
            .unwrap();

        // The adaptation provider returns garbage, forcing the fallback path.
        let bad_provider = MockProvider::with_responses(vec!["not valid json".to_string()]);

        let fallback_json = r#"{"tasks": [
            {"task_id": "fallback-0", "title": "Fallback Task", "description": "via planner", "depends_on": []}
        ]}"#
        .to_string();
        let planner = LlmPlanner::new(
            MockProvider::with_responses(vec![fallback_json]),
            &OrchestrationConfig::default(),
        );

        let (result_graph, _) = plan_with_cache(
            &planner,
            Some(&cache),
            &bad_provider,
            Some(&[1.0_f32, 0.0]),
            "test-model",
            "deploy service",
            &[],
            20,
        )
        .await
        .unwrap();

        assert_eq!(result_graph.tasks[0].title, "Fallback Task");
    }
}