1use blake3;
11use serde::{Deserialize, Serialize};
12use zeph_config::PlanCacheConfig;
13use zeph_db::DbPool;
14#[allow(unused_imports)]
15use zeph_db::sql;
16use zeph_llm::provider::{LlmProvider, Message, Role};
17
18use super::dag;
19use super::error::OrchestrationError;
20use super::graph::TaskGraph;
21use super::planner::{PlannerResponse, convert_response_pub};
22use zeph_subagent::SubAgentDef;
23
/// A single task inside a cached [`PlanTemplate`].
///
/// Only planning-time data is kept; runtime state (status, retry count,
/// assigned agent) is deliberately absent so a template stays reusable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemplateTask {
    /// Human-readable task title, copied verbatim from the task graph.
    pub title: String,
    /// Description of what the task should accomplish.
    pub description: String,
    /// Optional preferred agent name; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent_hint: Option<String>,
    /// `task_id` slugs of the tasks this one depends on; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub depends_on: Vec<String>,
    /// Optional failure strategy name ("abort", "retry", "skip", "ask");
    /// omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub failure_strategy: Option<String>,
    /// Stable string id: the slugified title plus a numeric suffix.
    pub task_id: String,
}
38
/// A reusable plan extracted from a [`TaskGraph`], suitable for caching
/// and later adaptation to a similar goal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanTemplate {
    /// Normalized goal text (trimmed, lowercased, whitespace-collapsed).
    pub goal: String,
    /// Template tasks in original graph order.
    pub tasks: Vec<TemplateTask>,
}
51
52impl PlanTemplate {
53 pub fn from_task_graph(graph: &TaskGraph) -> Result<Self, OrchestrationError> {
59 if graph.tasks.is_empty() {
60 return Err(OrchestrationError::PlanningFailed(
61 "cannot cache a plan with zero tasks".into(),
62 ));
63 }
64
65 let id_to_slug: Vec<String> = graph
67 .tasks
68 .iter()
69 .map(|n| slugify_title(&n.title, n.id.as_u32()))
70 .collect();
71
72 let tasks = graph
73 .tasks
74 .iter()
75 .enumerate()
76 .map(|(i, node)| TemplateTask {
77 title: node.title.clone(),
78 description: node.description.clone(),
79 agent_hint: node.agent_hint.clone(),
80 depends_on: node
81 .depends_on
82 .iter()
83 .map(|dep| id_to_slug[dep.index()].clone())
84 .collect(),
85 failure_strategy: node.failure_strategy.map(|fs| fs.to_string()),
86 task_id: id_to_slug[i].clone(),
87 })
88 .collect();
89
90 Ok(Self {
91 goal: normalize_goal(&graph.goal),
92 tasks,
93 })
94 }
95}
96
/// Canonicalizes goal text for hashing and storage: trims the ends,
/// lowercases every character, and collapses each internal run of
/// whitespace into a single space.
#[must_use]
pub fn normalize_goal(text: &str) -> String {
    let mut normalized = String::with_capacity(text.len());
    for word in text.split_whitespace() {
        if !normalized.is_empty() {
            normalized.push(' ');
        }
        // Lowercase char-by-char: each character is mapped independently,
        // with no position-sensitive mappings.
        normalized.extend(word.chars().flat_map(char::to_lowercase));
    }
    normalized
}
122
123#[must_use]
125pub fn goal_hash(normalized: &str) -> String {
126 blake3::hash(normalized.as_bytes()).to_hex().to_string()
127}
128
/// Builds a stable, human-readable id from a task title: lowercase ASCII
/// alphanumeric segments joined by single hyphens, capped at 32 bytes,
/// with the numeric task index appended for uniqueness.
fn slugify_title(title: &str, idx: u32) -> String {
    // Every non-ASCII-alphanumeric character acts as a separator; empty
    // segments (from runs of separators) are dropped.
    let joined = title
        .split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|segment| !segment.is_empty())
        .map(str::to_ascii_lowercase)
        .collect::<Vec<_>>()
        .join("-");

    // The slug is pure ASCII, so byte-slicing at 32 cannot split a char.
    let capped = if joined.len() > 32 {
        &joined[..32]
    } else {
        joined.as_str()
    };
    // A cut may leave a dangling hyphen at the boundary.
    let capped = capped.trim_end_matches('-');

    if capped.is_empty() {
        format!("task-{idx}")
    } else {
        format!("{capped}-{idx}")
    }
}
156
/// Serializes an embedding vector as raw bytes: four little-endian bytes
/// per `f32` component, in order.
fn embedding_to_blob(embedding: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(embedding.len() * 4);
    for component in embedding {
        blob.extend_from_slice(&component.to_le_bytes());
    }
    blob
}
161
162fn blob_to_embedding(blob: &[u8]) -> Option<Vec<f32>> {
167 if !blob.len().is_multiple_of(4) {
168 tracing::warn!(
169 len = blob.len(),
170 "plan cache: embedding blob length not a multiple of 4"
171 );
172 return None;
173 }
174 Some(
175 blob.chunks_exact(4)
176 .map(|chunk| f32::from_le_bytes(chunk.try_into().expect("chunk is exactly 4 bytes")))
177 .collect(),
178 )
179}
180
/// Current wall-clock time as whole seconds since the Unix epoch.
/// Yields 0 if the system clock reads earlier than the epoch.
fn unix_now() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_secs());
    #[allow(clippy::cast_possible_wrap)]
    {
        secs as i64
    }
}
190
/// Errors produced by [`PlanCache`] operations.
#[derive(Debug, thiserror::Error)]
pub enum PlanCacheError {
    /// An underlying SQL query failed.
    #[error("database error: {0}")]
    Database(#[from] zeph_db::SqlxError),
    /// Template (de)serialization to or from JSON failed.
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// Converting a task graph into a [`PlanTemplate`] failed.
    #[error("plan template extraction failed: {0}")]
    Extraction(String),
}
201
/// Database-backed cache of plan templates, keyed by goal hash and looked
/// up by embedding similarity.
///
/// Entries are evicted by a TTL on `last_accessed_at` and an LRU cap of
/// `max_templates` rows (both from [`PlanCacheConfig`]).
pub struct PlanCache {
    // Connection pool for the `plan_cache` table.
    pool: DbPool,
    // Tuning knobs: enabled flag, similarity threshold, TTL, max rows.
    config: PlanCacheConfig,
}
211
impl PlanCache {
    /// Opens the cache over `pool` and immediately invalidates any stored
    /// embeddings that were produced by a model other than
    /// `current_embedding_model`.
    ///
    /// # Errors
    /// Returns [`PlanCacheError::Database`] if the invalidation query fails.
    pub async fn new(
        pool: DbPool,
        config: PlanCacheConfig,
        current_embedding_model: &str,
    ) -> Result<Self, PlanCacheError> {
        let cache = Self { pool, config };
        cache
            .invalidate_stale_embeddings(current_embedding_model)
            .await?;
        Ok(cache)
    }

    /// NULLs out embedding columns stored under a different embedding model,
    /// so similarity search never compares vectors from incompatible spaces.
    async fn invalidate_stale_embeddings(&self, current_model: &str) -> Result<(), PlanCacheError> {
        let affected = zeph_db::query(sql!(
            "UPDATE plan_cache SET embedding = NULL, embedding_model = NULL \
             WHERE embedding IS NOT NULL AND embedding_model != ?"
        ))
        .bind(current_model)
        .execute(&self.pool)
        .await?
        .rows_affected();

        if affected > 0 {
            tracing::info!(
                rows = affected,
                current_model,
                "plan cache: invalidated stale embeddings for model change"
            );
        }
        Ok(())
    }

    /// Returns the best-matching cached template and its cosine similarity,
    /// or `None` when no stored embedding clears the configured threshold.
    ///
    /// On a hit, `last_accessed_at` is refreshed and `adapted_count` is
    /// incremented best-effort: a failure there is logged, not propagated.
    ///
    /// # Errors
    /// Returns [`PlanCacheError::Database`] on query failure, or
    /// [`PlanCacheError::Serialization`] if the stored template JSON is
    /// invalid.
    pub async fn find_similar(
        &self,
        goal_embedding: &[f32],
        embedding_model: &str,
    ) -> Result<Option<(PlanTemplate, f32)>, PlanCacheError> {
        // Candidate set: most-recently-used rows that carry an embedding
        // from the same model, capped at max_templates.
        let rows: Vec<(String, String, Vec<u8>)> = zeph_db::query_as(sql!(
            "SELECT id, template, embedding FROM plan_cache \
             WHERE embedding IS NOT NULL AND embedding_model = ? \
             ORDER BY last_accessed_at DESC LIMIT ?"
        ))
        .bind(embedding_model)
        .bind(self.config.max_templates)
        .fetch_all(&self.pool)
        .await?;

        // Linear scan for the highest cosine similarity. Rows whose blob
        // fails to decode are skipped.
        let mut best_score = -1.0_f32;
        let mut best_id: Option<String> = None;
        let mut best_template_json: Option<String> = None;

        for (id, template_json, blob) in rows {
            if let Some(stored) = blob_to_embedding(&blob) {
                let score = zeph_memory::cosine_similarity(goal_embedding, &stored);
                if score > best_score {
                    best_score = score;
                    best_id = Some(id);
                    best_template_json = Some(template_json);
                }
            }
        }

        if best_score >= self.config.similarity_threshold
            && let (Some(id), Some(json)) = (best_id, best_template_json)
        {
            let now = unix_now();
            // Touch the row so LRU eviction favors frequently reused plans.
            if let Err(e) = zeph_db::query(sql!(
                "UPDATE plan_cache SET last_accessed_at = ?, adapted_count = adapted_count + 1 \
                 WHERE id = ?"
            ))
            .bind(now)
            .bind(&id)
            .execute(&self.pool)
            .await
            {
                tracing::warn!(error = %e, "plan cache: failed to update last_accessed_at");
            }
            let template: PlanTemplate = serde_json::from_str(&json)?;
            return Ok(Some((template, best_score)));
        }

        Ok(None)
    }

    /// Stores (or refreshes) the plan extracted from `graph`, keyed by the
    /// hash of its normalized goal.
    ///
    /// An existing row with the same `goal_hash` is updated in place and
    /// its `success_count` incremented. Eviction runs afterwards on a
    /// best-effort basis (failures are logged, not propagated).
    ///
    /// # Errors
    /// Returns [`PlanCacheError::Extraction`] for an empty graph,
    /// [`PlanCacheError::Serialization`] if the template cannot be encoded,
    /// or [`PlanCacheError::Database`] on insert failure.
    pub async fn cache_plan(
        &self,
        graph: &TaskGraph,
        goal_embedding: &[f32],
        embedding_model: &str,
    ) -> Result<(), PlanCacheError> {
        let template = PlanTemplate::from_task_graph(graph)
            .map_err(|e| PlanCacheError::Extraction(e.to_string()))?;

        let normalized = normalize_goal(&graph.goal);
        let hash = goal_hash(&normalized);
        let template_json = serde_json::to_string(&template)?;
        let task_count = i64::try_from(template.tasks.len()).unwrap_or(i64::MAX);
        let now = unix_now();
        let id = uuid::Uuid::new_v4().to_string();
        let blob = embedding_to_blob(goal_embedding);

        // Upsert on goal_hash: fresh rows start with success_count = 1;
        // duplicates bump the counter and refresh the template/embedding.
        zeph_db::query(sql!(
            "INSERT INTO plan_cache \
             (id, goal_hash, goal_text, template, task_count, success_count, adapted_count, \
             embedding, embedding_model, created_at, last_accessed_at) \
             VALUES (?, ?, ?, ?, ?, 1, 0, ?, ?, ?, ?) \
             ON CONFLICT(goal_hash) DO UPDATE SET \
             success_count = success_count + 1, \
             template = excluded.template, \
             task_count = excluded.task_count, \
             embedding = excluded.embedding, \
             embedding_model = excluded.embedding_model, \
             last_accessed_at = excluded.last_accessed_at"
        ))
        .bind(&id)
        .bind(&hash)
        .bind(&normalized)
        .bind(&template_json)
        .bind(task_count)
        .bind(&blob)
        .bind(embedding_model)
        .bind(now)
        .bind(now)
        .execute(&self.pool)
        .await?;

        if let Err(e) = self.evict().await {
            tracing::warn!(error = %e, "plan cache: eviction failed after cache_plan");
        }

        Ok(())
    }

    /// Removes expired and excess cache rows; returns the total removed.
    ///
    /// Two passes: (1) delete rows whose `last_accessed_at` predates the
    /// TTL cutoff; (2) if more than `max_templates` rows remain, delete the
    /// least-recently-accessed surplus.
    ///
    /// # Errors
    /// Returns [`PlanCacheError::Database`] if any query fails.
    pub async fn evict(&self) -> Result<u32, PlanCacheError> {
        let now = unix_now();
        let ttl_secs = i64::from(self.config.ttl_days) * 86_400;
        let cutoff = now.saturating_sub(ttl_secs);

        let ttl_deleted = zeph_db::query(sql!("DELETE FROM plan_cache WHERE last_accessed_at < ?"))
            .bind(cutoff)
            .execute(&self.pool)
            .await?
            .rows_affected();

        // Count AFTER the TTL pass so LRU only removes the true surplus.
        let count: i64 = zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM plan_cache"))
            .fetch_one(&self.pool)
            .await?;

        let max = i64::from(self.config.max_templates);
        let lru_deleted = if count > max {
            let excess = count - max;
            zeph_db::query(sql!(
                "DELETE FROM plan_cache WHERE id IN \
                 (SELECT id FROM plan_cache ORDER BY last_accessed_at ASC LIMIT ?)"
            ))
            .bind(excess)
            .execute(&self.pool)
            .await?
            .rows_affected()
        } else {
            0
        };

        let total = ttl_deleted + lru_deleted;
        if total > 0 {
            tracing::debug!(ttl_deleted, lru_deleted, "plan cache: eviction complete");
        }
        Ok(u32::try_from(total).unwrap_or(u32::MAX))
    }
}
426
427#[allow(clippy::too_many_arguments)]
436pub async fn plan_with_cache<P>(
437 planner: &P,
438 plan_cache: Option<&PlanCache>,
439 provider: &impl LlmProvider,
440 embedding: Option<&[f32]>,
441 embedding_model: &str,
442 goal: &str,
443 available_agents: &[SubAgentDef],
444 max_tasks: u32,
445) -> Result<(TaskGraph, Option<(u64, u64)>), OrchestrationError>
446where
447 P: super::planner::Planner,
448{
449 if let (Some(cache), Some(emb)) = (plan_cache, embedding)
450 && cache.config.enabled
451 {
452 match cache.find_similar(emb, embedding_model).await {
453 Ok(Some((template, score))) => {
454 tracing::info!(
455 similarity = score,
456 tasks = template.tasks.len(),
457 "plan cache hit, adapting template"
458 );
459 match adapt_plan(provider, goal, &template, available_agents, max_tasks).await {
460 Ok(result) => return Ok(result),
461 Err(e) => {
462 tracing::warn!(
463 error = %e,
464 "plan cache: adaptation failed, falling back to full decomposition"
465 );
466 }
467 }
468 }
469 Ok(None) => {
470 tracing::debug!("plan cache miss");
471 }
472 Err(e) => {
473 tracing::warn!(error = %e, "plan cache: find_similar failed, using full decomposition");
474 }
475 }
476 }
477
478 planner.plan(goal, available_agents).await
479}
480
/// Adapts a cached [`PlanTemplate`] to a new goal via a single LLM call,
/// then converts and validates the result as a [`TaskGraph`].
///
/// Returns the graph plus the provider's token-usage counts, if reported.
///
/// # Errors
/// Returns [`OrchestrationError::PlanningFailed`] when the template cannot
/// be serialized or the LLM call fails; propagates conversion and DAG
/// validation errors from `convert_response_pub` / `dag::validate`.
async fn adapt_plan(
    provider: &impl LlmProvider,
    goal: &str,
    template: &PlanTemplate,
    available_agents: &[SubAgentDef],
    max_tasks: u32,
) -> Result<(TaskGraph, Option<(u64, u64)>), OrchestrationError> {
    use zeph_subagent::ToolPolicy;

    // One catalog line per agent: name, description, and its tool policy
    // rendered as a readable list.
    let agent_catalog = available_agents
        .iter()
        .map(|a| {
            let tools = match &a.tools {
                ToolPolicy::AllowList(list) => list.join(", "),
                ToolPolicy::DenyList(excluded) => {
                    format!("all except: [{}]", excluded.join(", "))
                }
                ToolPolicy::InheritAll => "all".to_string(),
            };
            format!(
                "- name: \"{}\", description: \"{}\", tools: [{}]",
                a.name, a.description, tools
            )
        })
        .collect::<Vec<_>>()
        .join("\n");

    // Only the task list is sent; the template's stored goal is shown
    // separately in the user message.
    let template_json = serde_json::to_string(&template.tasks)
        .map_err(|e| OrchestrationError::PlanningFailed(e.to_string()))?;

    let system = format!(
        "You are a task planner. A cached plan template exists for a similar goal. \
         Adapt it for the new goal by adjusting task descriptions and adding or removing \
         tasks as needed. Keep the same JSON structure.\n\n\
         Available agents:\n{agent_catalog}\n\n\
         Rules:\n\
         - Each task must have a unique task_id (short, descriptive, kebab-case: [a-z0-9-]).\n\
         - Specify dependencies using task_id strings in depends_on.\n\
         - Do not create more than {max_tasks} tasks.\n\
         - failure_strategy is optional: \"abort\", \"retry\", \"skip\", \"ask\"."
    );

    let user = format!(
        "New goal:\n{goal}\n\nCached template (for similar goal \"{}\"):\n{template_json}\n\n\
         Adapt the template for the new goal. Return JSON: {{\"tasks\": [...]}}",
        template.goal
    );

    let messages = vec![
        Message::from_legacy(Role::System, system),
        Message::from_legacy(Role::User, user),
    ];

    // Structured response: the provider parses the JSON into PlannerResponse.
    let response: PlannerResponse = provider
        .chat_typed(&messages)
        .await
        .map_err(|e| OrchestrationError::PlanningFailed(e.to_string()))?;

    let usage = provider.last_usage();

    let graph = convert_response_pub(response, goal, available_agents, max_tasks)?;

    // Reject cycles / over-limit plans before handing the graph back.
    dag::validate(&graph.tasks, max_tasks as usize)?;

    Ok((graph, usage))
}
557
#[cfg(test)]
mod tests {
    use super::super::graph::{TaskId, TaskNode};
    use super::*;
    use zeph_memory::store::SqliteStore;

    /// Fresh in-memory SQLite pool for each test.
    async fn test_pool() -> DbPool {
        let store = SqliteStore::new(":memory:").await.unwrap();
        store.pool().clone()
    }

    /// Cache with the default config over the given pool.
    async fn test_cache(pool: DbPool) -> PlanCache {
        PlanCache::new(pool, PlanCacheConfig::default(), "test-model")
            .await
            .unwrap()
    }

    /// Builds a `TaskGraph` from `(title, description, deps)` triples,
    /// assigning sequential numeric ids.
    fn make_graph(goal: &str, tasks: &[(&str, &str, &[u32])]) -> TaskGraph {
        let mut graph = TaskGraph::new(goal);
        for (i, (title, desc, deps)) in tasks.iter().enumerate() {
            #[allow(clippy::cast_possible_truncation)]
            let mut node = TaskNode::new(i as u32, *title, *desc);
            node.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
            graph.tasks.push(node);
        }
        graph
    }

    // --- normalize_goal ---

    #[test]
    fn normalize_trims_and_lowercases() {
        assert_eq!(normalize_goal("  Hello World  "), "hello world");
    }

    #[test]
    fn normalize_collapses_internal_whitespace() {
        assert_eq!(normalize_goal("hello   world"), "hello world");
    }

    #[test]
    fn normalize_empty_string() {
        assert_eq!(normalize_goal(""), "");
    }

    #[test]
    fn normalize_whitespace_only() {
        assert_eq!(normalize_goal("   "), "");
    }

    // --- goal_hash ---

    #[test]
    fn goal_hash_is_deterministic() {
        let h1 = goal_hash("deploy service");
        let h2 = goal_hash("deploy service");
        assert_eq!(h1, h2);
    }

    #[test]
    fn goal_hash_differs_for_different_goals() {
        assert_ne!(goal_hash("deploy service"), goal_hash("build artifact"));
    }

    #[test]
    fn goal_hash_nonempty() {
        assert!(!goal_hash("goal").is_empty());
    }

    // --- PlanTemplate extraction ---

    #[test]
    fn template_from_empty_graph_returns_error() {
        let graph = TaskGraph::new("goal");
        assert!(PlanTemplate::from_task_graph(&graph).is_err());
    }

    #[test]
    fn template_strips_runtime_fields() {
        use crate::graph::TaskStatus;
        let mut graph = make_graph("goal", &[("Fetch data", "Download it", &[])]);
        graph.tasks[0].status = TaskStatus::Completed;
        graph.tasks[0].retry_count = 3;
        graph.tasks[0].assigned_agent = Some("agent-x".to_string());
        let template = PlanTemplate::from_task_graph(&graph).unwrap();
        assert_eq!(template.tasks[0].title, "Fetch data");
        assert_eq!(template.tasks[0].description, "Download it");
    }

    #[test]
    fn template_preserves_dependencies() {
        let graph = make_graph("goal", &[("Task A", "do A", &[]), ("Task B", "do B", &[0])]);
        let template = PlanTemplate::from_task_graph(&graph).unwrap();
        assert_eq!(template.tasks.len(), 2);
        assert!(template.tasks[0].depends_on.is_empty());
        assert_eq!(template.tasks[1].depends_on.len(), 1);
        assert_eq!(template.tasks[1].depends_on[0], template.tasks[0].task_id);
    }

    #[test]
    fn template_serde_roundtrip() {
        let graph = make_graph("goal", &[("Step one", "do step one", &[])]);
        let template = PlanTemplate::from_task_graph(&graph).unwrap();
        let json = serde_json::to_string(&template).unwrap();
        let restored: PlanTemplate = serde_json::from_str(&json).unwrap();
        assert_eq!(template.tasks[0].title, restored.tasks[0].title);
        assert_eq!(template.goal, restored.goal);
    }

    // --- embedding blob codec ---

    #[test]
    fn embedding_blob_roundtrip() {
        let embedding = vec![1.0_f32, 0.5, 0.25, -1.0];
        let blob = embedding_to_blob(&embedding);
        let restored = blob_to_embedding(&blob).unwrap();
        assert_eq!(embedding, restored);
    }

    #[test]
    fn blob_to_embedding_odd_length_returns_none() {
        let bad_blob = vec![0u8; 5];
        assert!(blob_to_embedding(&bad_blob).is_none());
    }

    // --- PlanCache lookup / storage ---

    #[tokio::test]
    async fn cache_miss_on_empty_cache() {
        let pool = test_pool().await;
        let cache = test_cache(pool).await;
        let result = cache
            .find_similar(&[1.0, 0.0, 0.0], "test-model")
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_store_and_hit() {
        let pool = test_pool().await;
        let config = PlanCacheConfig {
            similarity_threshold: 0.9,
            ..PlanCacheConfig::default()
        };
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        let graph = make_graph("deploy service", &[("Build", "build it", &[])]);
        let embedding = vec![1.0_f32, 0.0, 0.0];
        cache
            .cache_plan(&graph, &embedding, "test-model")
            .await
            .unwrap();

        // Identical embedding => cosine similarity 1.0, clears the threshold.
        let result = cache
            .find_similar(&[1.0, 0.0, 0.0], "test-model")
            .await
            .unwrap();
        assert!(result.is_some());
        let (template, score) = result.unwrap();
        assert!((score - 1.0).abs() < 1e-5);
        assert_eq!(template.tasks.len(), 1);
    }

    #[tokio::test]
    async fn cache_miss_on_dissimilar_goal() {
        let pool = test_pool().await;
        let config = PlanCacheConfig {
            similarity_threshold: 0.9,
            ..PlanCacheConfig::default()
        };
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        let graph = make_graph("goal a", &[("Task", "do it", &[])]);
        cache
            .cache_plan(&graph, &[1.0_f32, 0.0, 0.0], "test-model")
            .await
            .unwrap();

        // Orthogonal embedding => similarity 0.0, below the 0.9 threshold.
        let result = cache
            .find_similar(&[0.0, 1.0, 0.0], "test-model")
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn deduplication_increments_success_count() {
        let pool = test_pool().await;
        let cache = test_cache(pool.clone()).await;

        let graph = make_graph("same goal", &[("Task", "do it", &[])]);
        let emb = vec![1.0_f32, 0.0];

        cache.cache_plan(&graph, &emb, "test-model").await.unwrap();
        cache.cache_plan(&graph, &emb, "test-model").await.unwrap();

        // Same goal hash upserts: still one row, success_count bumped.
        let count: i64 = zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM plan_cache"))
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 1);

        let success: i64 = zeph_db::query_scalar(sql!("SELECT success_count FROM plan_cache"))
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(success, 2);
    }

    // --- eviction ---

    #[tokio::test]
    async fn eviction_removes_ttl_expired_rows() {
        let pool = test_pool().await;
        // ttl_days = 0 makes any row older than "now" expired.
        let config = PlanCacheConfig {
            ttl_days: 0,
            ..PlanCacheConfig::default()
        };
        let cache = PlanCache::new(pool.clone(), config, "test-model")
            .await
            .unwrap();

        let now = unix_now() - 1;
        zeph_db::query(sql!(
            "INSERT INTO plan_cache \
             (id, goal_hash, goal_text, template, task_count, created_at, last_accessed_at) \
             VALUES (?, ?, ?, ?, ?, ?, ?)"
        ))
        .bind("test-id")
        .bind("hash-1")
        .bind("goal")
        .bind("{\"goal\":\"goal\",\"tasks\":[]}")
        .bind(0_i64)
        .bind(now)
        .bind(now)
        .execute(&pool)
        .await
        .unwrap();

        let deleted = cache.evict().await.unwrap();
        assert!(deleted >= 1);

        let count: i64 = zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM plan_cache"))
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn eviction_lru_when_over_max() {
        let pool = test_pool().await;
        let config = PlanCacheConfig {
            max_templates: 2,
            ttl_days: 365,
            ..PlanCacheConfig::default()
        };
        let cache = PlanCache::new(pool.clone(), config, "test-model")
            .await
            .unwrap();

        let now = unix_now();
        // Three rows with staggered access times; the oldest must go.
        for i in 0..3_i64 {
            zeph_db::query(sql!(
                "INSERT INTO plan_cache \
                 (id, goal_hash, goal_text, template, task_count, created_at, last_accessed_at) \
                 VALUES (?, ?, ?, ?, ?, ?, ?)"
            ))
            .bind(format!("id-{i}"))
            .bind(format!("hash-{i}"))
            .bind(format!("goal-{i}"))
            .bind("{\"goal\":\"g\",\"tasks\":[]}")
            .bind(0_i64)
            .bind(now)
            .bind(now + i)
            .execute(&pool)
            .await
            .unwrap();
        }

        let deleted = cache.evict().await.unwrap();
        assert_eq!(deleted, 1);

        let count: i64 = zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM plan_cache"))
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn stale_embedding_invalidated_on_new() {
        let pool = test_pool().await;
        let now = unix_now();

        // Seed a row whose embedding was produced by a different model.
        let emb = embedding_to_blob(&[1.0_f32, 0.0]);
        zeph_db::query(sql!(
            "INSERT INTO plan_cache \
             (id, goal_hash, goal_text, template, task_count, embedding, embedding_model, \
             created_at, last_accessed_at) \
             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"
        ))
        .bind("id-old")
        .bind("hash-old")
        .bind("goal old")
        .bind("{\"goal\":\"g\",\"tasks\":[]}")
        .bind(0_i64)
        .bind(&emb)
        .bind("old-model")
        .bind(now)
        .bind(now)
        .execute(&pool)
        .await
        .unwrap();

        // Constructing the cache with a new model name triggers invalidation.
        let _cache = PlanCache::new(pool.clone(), PlanCacheConfig::default(), "new-model")
            .await
            .unwrap();

        let model: Option<String> = zeph_db::query_scalar(sql!(
            "SELECT embedding_model FROM plan_cache WHERE id = 'id-old'"
        ))
        .fetch_one(&pool)
        .await
        .unwrap();
        assert!(model.is_none(), "stale embedding_model should be NULL");

        let emb_col: Option<Vec<u8>> =
            zeph_db::query_scalar(sql!("SELECT embedding FROM plan_cache WHERE id = 'id-old'"))
                .fetch_one(&pool)
                .await
                .unwrap();
        assert!(emb_col.is_none(), "stale embedding should be NULL");
    }

    // --- plan_with_cache integration ---

    #[tokio::test]
    async fn disabled_cache_not_used_in_plan_with_cache() {
        use crate::planner::LlmPlanner;
        use zeph_config::OrchestrationConfig;
        use zeph_llm::mock::MockProvider;

        let pool = test_pool().await;
        // Relies on the default config leaving the cache disabled.
        let config = PlanCacheConfig::default();
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        let graph_json = r#"{"tasks": [
            {"task_id": "t1", "title": "Task", "description": "do it", "depends_on": []}
        ]}"#
        .to_string();

        let provider = MockProvider::with_responses(vec![graph_json.clone()]);
        let planner = LlmPlanner::new(
            MockProvider::with_responses(vec![graph_json]),
            &OrchestrationConfig::default(),
        );

        let (graph, _) = plan_with_cache(
            &planner,
            Some(&cache),
            &provider,
            Some(&[1.0_f32, 0.0]),
            "test-model",
            "do something",
            &[],
            20,
        )
        .await
        .unwrap();

        assert_eq!(graph.tasks.len(), 1);
    }

    #[tokio::test]
    async fn plan_with_cache_with_none_embedding_skips_cache() {
        use crate::planner::LlmPlanner;
        use zeph_config::OrchestrationConfig;
        use zeph_llm::mock::MockProvider;

        let pool = test_pool().await;
        let config = PlanCacheConfig {
            enabled: true,
            similarity_threshold: 0.5,
            ..PlanCacheConfig::default()
        };
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        // A cached plan exists, but without an embedding it cannot be found.
        let graph = make_graph("deploy service", &[("Build", "build it", &[])]);
        cache
            .cache_plan(&graph, &[1.0_f32, 0.0], "test-model")
            .await
            .unwrap();

        let graph_json = r#"{"tasks": [
            {"task_id": "fallback-task-0", "title": "Fallback", "description": "planner fallback", "depends_on": []}
        ]}"#
        .to_string();

        let provider = MockProvider::with_responses(vec![graph_json.clone()]);
        let planner = LlmPlanner::new(
            MockProvider::with_responses(vec![graph_json]),
            &OrchestrationConfig::default(),
        );

        let (result_graph, _) = plan_with_cache(
            &planner,
            Some(&cache),
            &provider,
            None, // no embedding: cache path must be skipped
            "test-model",
            "deploy service",
            &[],
            20,
        )
        .await
        .unwrap();

        assert_eq!(result_graph.tasks[0].title, "Fallback");
    }

    #[tokio::test]
    async fn adapt_plan_error_fallback_to_full_decomposition() {
        use crate::planner::LlmPlanner;
        use zeph_config::OrchestrationConfig;
        use zeph_llm::mock::MockProvider;

        let pool = test_pool().await;
        let config = PlanCacheConfig {
            enabled: true,
            similarity_threshold: 0.5,
            ..PlanCacheConfig::default()
        };
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        // Seed a cache hit so adaptation is attempted first.
        let graph = make_graph("deploy service", &[("Build", "build it", &[])]);
        cache
            .cache_plan(&graph, &[1.0_f32, 0.0], "test-model")
            .await
            .unwrap();

        // Adaptation provider returns unparseable output, forcing fallback.
        let bad_provider = MockProvider::with_responses(vec!["not valid json".to_string()]);

        let fallback_json = r#"{"tasks": [
            {"task_id": "fallback-0", "title": "Fallback Task", "description": "via planner", "depends_on": []}
        ]}"#
        .to_string();
        let planner = LlmPlanner::new(
            MockProvider::with_responses(vec![fallback_json]),
            &OrchestrationConfig::default(),
        );

        let (result_graph, _) = plan_with_cache(
            &planner,
            Some(&cache),
            &bad_provider,
            Some(&[1.0_f32, 0.0]),
            "test-model",
            "deploy service",
            &[],
            20,
        )
        .await
        .unwrap();

        assert_eq!(result_graph.tasks[0].title, "Fallback Task");
    }
}