1use blake3;
11use serde::{Deserialize, Serialize};
12use zeph_config::PlanCacheConfig;
13use zeph_db::DbPool;
14#[allow(unused_imports)]
15use zeph_db::sql;
16use zeph_llm::provider::{LlmProvider, Message, Role};
17
18use super::dag;
19use super::error::OrchestrationError;
20use super::graph::TaskGraph;
21use super::planner::{PlannerResponse, convert_response_pub};
22use zeph_subagent::SubAgentDef;
23
/// A single task within a cached [`PlanTemplate`].
///
/// Holds only the planning-relevant shape of a task; runtime fields such
/// as status, retry counts, and assigned agents are stripped when the
/// template is extracted (see `PlanTemplate::from_task_graph`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemplateTask {
    /// Human-readable task title.
    pub title: String,
    /// Description of the work the task performs.
    pub description: String,
    /// Optional preferred agent for this task.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent_hint: Option<String>,
    /// `task_id` slugs of tasks that must complete before this one.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub depends_on: Vec<String>,
    /// Optional failure-strategy name, stored as its string form
    /// (e.g. "abort", "retry", "skip", "ask" per the adaptation prompt).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub failure_strategy: Option<String>,
    /// Stable slug identifier derived from the title and task index.
    pub task_id: String,
}
48
/// A reusable plan extracted from a successful [`TaskGraph`], stored in
/// the plan cache and later adapted for similar goals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanTemplate {
    /// Goal text, normalized via `normalize_goal`.
    pub goal: String,
    /// Template tasks in original graph order.
    pub tasks: Vec<TemplateTask>,
}
73
74impl PlanTemplate {
75 pub fn from_task_graph(graph: &TaskGraph) -> Result<Self, OrchestrationError> {
81 if graph.tasks.is_empty() {
82 return Err(OrchestrationError::PlanningFailed(
83 "cannot cache a plan with zero tasks".into(),
84 ));
85 }
86
87 let id_to_slug: Vec<String> = graph
89 .tasks
90 .iter()
91 .map(|n| slugify_title(&n.title, n.id.as_u32()))
92 .collect();
93
94 let tasks = graph
95 .tasks
96 .iter()
97 .enumerate()
98 .map(|(i, node)| TemplateTask {
99 title: node.title.clone(),
100 description: node.description.clone(),
101 agent_hint: node.agent_hint.clone(),
102 depends_on: node
103 .depends_on
104 .iter()
105 .map(|dep| id_to_slug[dep.index()].clone())
106 .collect(),
107 failure_strategy: node.failure_strategy.map(|fs| fs.to_string()),
108 task_id: id_to_slug[i].clone(),
109 })
110 .collect();
111
112 Ok(Self {
113 goal: normalize_goal(&graph.goal),
114 tasks,
115 })
116 }
117}
118
/// Normalizes a goal string for hashing and deduplication: trims leading
/// and trailing whitespace, collapses internal whitespace runs to single
/// spaces, and lowercases (Unicode-aware, so multi-char lowercase
/// expansions are preserved).
#[must_use]
pub fn normalize_goal(text: &str) -> String {
    // `split_whitespace` both trims and collapses runs; lowercasing each
    // word matches per-char `to_lowercase` expansion exactly.
    let words: Vec<String> = text.split_whitespace().map(str::to_lowercase).collect();
    words.join(" ")
}
144
145#[must_use]
147pub fn goal_hash(normalized: &str) -> String {
148 blake3::hash(normalized.as_bytes()).to_hex().to_string()
149}
150
/// Converts a task title into a kebab-case slug suffixed with the task
/// index, e.g. `"Fetch data"` + `0` → `"fetch-data-0"`.
fn slugify_title(title: &str, idx: u32) -> String {
    // Lowercase ASCII alphanumerics; every other char becomes a separator.
    let mapped: String = title
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() {
                c.to_ascii_lowercase()
            } else {
                '-'
            }
        })
        .collect();

    // Collapse separator runs by keeping only non-empty segments.
    let mut slug = String::with_capacity(mapped.len());
    for segment in mapped.split('-').filter(|s| !s.is_empty()) {
        if !slug.is_empty() {
            slug.push('-');
        }
        slug.push_str(segment);
    }

    // Cap to 32 bytes; safe to byte-slice because `slug` is pure ASCII
    // by construction. Then drop any trailing separator the cut exposed.
    let capped: &str = if slug.len() > 32 { &slug[..32] } else { &slug };
    let capped = capped.trim_end_matches('-');

    if capped.is_empty() {
        format!("task-{idx}")
    } else {
        format!("{capped}-{idx}")
    }
}
178
/// Serializes an embedding vector as a flat little-endian `f32` byte blob
/// for storage in the database.
fn embedding_to_blob(embedding: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(embedding.len() * 4);
    for value in embedding {
        blob.extend_from_slice(&value.to_le_bytes());
    }
    blob
}
183
184fn blob_to_embedding(blob: &[u8]) -> Option<Vec<f32>> {
189 if !blob.len().is_multiple_of(4) {
190 tracing::warn!(
191 len = blob.len(),
192 "plan cache: embedding blob length not a multiple of 4"
193 );
194 return None;
195 }
196 Some(
197 blob.chunks_exact(4)
198 .map(|chunk| f32::from_le_bytes(chunk.try_into().expect("chunk is exactly 4 bytes")))
199 .collect(),
200 )
201}
202
/// Current Unix timestamp in seconds; 0 if the system clock reads before
/// the epoch (via `unwrap_or_default` on the duration).
fn unix_now() -> i64 {
    let secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    // Seconds since 1970 fit comfortably in i64; the wrap lint is moot.
    #[allow(clippy::cast_possible_wrap)]
    let stamp = secs as i64;
    stamp
}
212
/// Errors produced by [`PlanCache`] operations.
#[derive(Debug, thiserror::Error)]
pub enum PlanCacheError {
    /// An underlying database query failed.
    #[error("database error: {0}")]
    Database(#[from] zeph_db::SqlxError),
    /// Template (de)serialization to or from JSON failed.
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// Converting a task graph into a [`PlanTemplate`] failed.
    #[error("plan template extraction failed: {0}")]
    Extraction(String),
}
229
/// Persistent cache of successful plans, keyed by goal hash for
/// deduplication and searched by embedding cosine similarity for reuse.
pub struct PlanCache {
    // Shared database handle used by all cache queries.
    pool: DbPool,
    // Tunables read by this module: `enabled`, `similarity_threshold`,
    // `ttl_days`, and `max_templates`.
    config: PlanCacheConfig,
}
239
impl PlanCache {
    /// Creates a cache handle and immediately invalidates any embeddings
    /// that were produced by a model other than `current_embedding_model`.
    ///
    /// # Errors
    /// Returns a database error if the invalidation query fails.
    pub async fn new(
        pool: DbPool,
        config: PlanCacheConfig,
        current_embedding_model: &str,
    ) -> Result<Self, PlanCacheError> {
        let cache = Self { pool, config };
        cache
            .invalidate_stale_embeddings(current_embedding_model)
            .await?;
        Ok(cache)
    }

    /// Nulls out embeddings stored under a different model name — vectors
    /// from different embedding models are not comparable, so similarity
    /// search must ignore them until they are regenerated.
    async fn invalidate_stale_embeddings(&self, current_model: &str) -> Result<(), PlanCacheError> {
        let affected = zeph_db::query(sql!(
            "UPDATE plan_cache SET embedding = NULL, embedding_model = NULL \
             WHERE embedding IS NOT NULL AND embedding_model != ?"
        ))
        .bind(current_model)
        .execute(&self.pool)
        .await?
        .rows_affected();

        if affected > 0 {
            tracing::info!(
                rows = affected,
                current_model,
                "plan cache: invalidated stale embeddings for model change"
            );
        }
        Ok(())
    }

    /// Finds the cached template most similar to `goal_embedding`.
    ///
    /// Scans up to `max_templates` most-recently-accessed rows for the
    /// given model, computes cosine similarity in-process, and returns the
    /// best match with its score when it clears `similarity_threshold`.
    /// A hit also bumps `last_accessed_at` / `adapted_count` best-effort
    /// (failures there are logged, not propagated).
    ///
    /// # Errors
    /// Returns database errors, or a JSON error if the stored template
    /// fails to deserialize.
    pub async fn find_similar(
        &self,
        goal_embedding: &[f32],
        embedding_model: &str,
    ) -> Result<Option<(PlanTemplate, f32)>, PlanCacheError> {
        let rows: Vec<(String, String, Vec<u8>)> = zeph_db::query_as(sql!(
            "SELECT id, template, embedding FROM plan_cache \
             WHERE embedding IS NOT NULL AND embedding_model = ? \
             ORDER BY last_accessed_at DESC LIMIT ?"
        ))
        .bind(embedding_model)
        .bind(self.config.max_templates)
        .fetch_all(&self.pool)
        .await?;

        // Linear scan is fine: the candidate set is capped at
        // `max_templates` rows by the query above.
        let mut best_score = -1.0_f32;
        let mut best_id: Option<String> = None;
        let mut best_template_json: Option<String> = None;

        for (id, template_json, blob) in rows {
            // Corrupt blobs are skipped (blob_to_embedding logs a warning).
            if let Some(stored) = blob_to_embedding(&blob) {
                let score = zeph_common::math::cosine_similarity(goal_embedding, &stored);
                if score > best_score {
                    best_score = score;
                    best_id = Some(id);
                    best_template_json = Some(template_json);
                }
            }
        }

        if best_score >= self.config.similarity_threshold
            && let (Some(id), Some(json)) = (best_id, best_template_json)
        {
            let now = unix_now();
            // Best-effort bookkeeping: a failed stats update must not turn
            // a cache hit into an error.
            if let Err(e) = zeph_db::query(sql!(
                "UPDATE plan_cache SET last_accessed_at = ?, adapted_count = adapted_count + 1 \
                 WHERE id = ?"
            ))
            .bind(now)
            .bind(&id)
            .execute(&self.pool)
            .await
            {
                tracing::warn!(error = %e, "plan cache: failed to update last_accessed_at");
            }
            let template: PlanTemplate = serde_json::from_str(&json)?;
            return Ok(Some((template, best_score)));
        }

        Ok(None)
    }

    /// Stores (or refreshes) the template extracted from a successful plan.
    ///
    /// Upserts on `goal_hash`: a repeat of the same normalized goal
    /// increments `success_count` and refreshes the template, embedding,
    /// and access time instead of inserting a duplicate row. Eviction runs
    /// afterwards best-effort (failures are logged, not propagated).
    ///
    /// # Errors
    /// Returns extraction, serialization, or database errors.
    pub async fn cache_plan(
        &self,
        graph: &TaskGraph,
        goal_embedding: &[f32],
        embedding_model: &str,
    ) -> Result<(), PlanCacheError> {
        let template = PlanTemplate::from_task_graph(graph)
            .map_err(|e| PlanCacheError::Extraction(e.to_string()))?;

        let normalized = normalize_goal(&graph.goal);
        let hash = goal_hash(&normalized);
        let template_json = serde_json::to_string(&template)?;
        let task_count = i64::try_from(template.tasks.len()).unwrap_or(i64::MAX);
        let now = unix_now();
        let id = uuid::Uuid::new_v4().to_string();
        let blob = embedding_to_blob(goal_embedding);

        zeph_db::query(sql!(
            "INSERT INTO plan_cache \
             (id, goal_hash, goal_text, template, task_count, success_count, adapted_count, \
             embedding, embedding_model, created_at, last_accessed_at) \
             VALUES (?, ?, ?, ?, ?, 1, 0, ?, ?, ?, ?) \
             ON CONFLICT(goal_hash) DO UPDATE SET \
             success_count = success_count + 1, \
             template = excluded.template, \
             task_count = excluded.task_count, \
             embedding = excluded.embedding, \
             embedding_model = excluded.embedding_model, \
             last_accessed_at = excluded.last_accessed_at"
        ))
        .bind(&id)
        .bind(&hash)
        .bind(&normalized)
        .bind(&template_json)
        .bind(task_count)
        .bind(&blob)
        .bind(embedding_model)
        .bind(now)
        .bind(now)
        .execute(&self.pool)
        .await?;

        if let Err(e) = self.evict().await {
            tracing::warn!(error = %e, "plan cache: eviction failed after cache_plan");
        }

        Ok(())
    }

    /// Removes expired and excess rows; returns how many were deleted
    /// (saturated to `u32::MAX`).
    ///
    /// Two passes: TTL deletion of rows not accessed within `ttl_days`,
    /// then LRU trimming of the oldest-accessed rows down to
    /// `max_templates`.
    ///
    /// # Errors
    /// Returns database errors from either pass.
    pub async fn evict(&self) -> Result<u32, PlanCacheError> {
        let now = unix_now();
        let ttl_secs = i64::from(self.config.ttl_days) * 86_400;
        let cutoff = now.saturating_sub(ttl_secs);

        let ttl_deleted = zeph_db::query(sql!("DELETE FROM plan_cache WHERE last_accessed_at < ?"))
            .bind(cutoff)
            .execute(&self.pool)
            .await?
            .rows_affected();

        let count: i64 = zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM plan_cache"))
            .fetch_one(&self.pool)
            .await?;

        let max = i64::from(self.config.max_templates);
        let lru_deleted = if count > max {
            let excess = count - max;
            // Least-recently-accessed rows go first.
            zeph_db::query(sql!(
                "DELETE FROM plan_cache WHERE id IN \
                 (SELECT id FROM plan_cache ORDER BY last_accessed_at ASC LIMIT ?)"
            ))
            .bind(excess)
            .execute(&self.pool)
            .await?
            .rows_affected()
        } else {
            0
        };

        let total = ttl_deleted + lru_deleted;
        if total > 0 {
            tracing::debug!(ttl_deleted, lru_deleted, "plan cache: eviction complete");
        }
        Ok(u32::try_from(total).unwrap_or(u32::MAX))
    }
}
454
455#[allow(clippy::too_many_arguments)] pub async fn plan_with_cache<P>(
465 planner: &P,
466 plan_cache: Option<&PlanCache>,
467 provider: &impl LlmProvider,
468 embedding: Option<&[f32]>,
469 embedding_model: &str,
470 goal: &str,
471 available_agents: &[SubAgentDef],
472 max_tasks: u32,
473) -> Result<(TaskGraph, Option<(u64, u64)>), OrchestrationError>
474where
475 P: super::planner::Planner,
476{
477 if let (Some(cache), Some(emb)) = (plan_cache, embedding)
478 && cache.config.enabled
479 {
480 match cache.find_similar(emb, embedding_model).await {
481 Ok(Some((template, score))) => {
482 tracing::info!(
483 similarity = score,
484 tasks = template.tasks.len(),
485 "plan cache hit, adapting template"
486 );
487 match adapt_plan(provider, goal, &template, available_agents, max_tasks).await {
488 Ok(result) => return Ok(result),
489 Err(e) => {
490 tracing::warn!(
491 error = %e,
492 "plan cache: adaptation failed, falling back to full decomposition"
493 );
494 }
495 }
496 }
497 Ok(None) => {
498 tracing::debug!("plan cache miss");
499 }
500 Err(e) => {
501 tracing::warn!(error = %e, "plan cache: find_similar failed, using full decomposition");
502 }
503 }
504 }
505
506 planner.plan(goal, available_agents).await
507}
508
/// Adapts a cached [`PlanTemplate`] to a new goal via a single LLM call.
///
/// Builds an agent catalog and the cached task list into a planner-style
/// prompt, parses the typed response, converts it into a [`TaskGraph`],
/// and validates the resulting DAG. Returns the graph plus the provider's
/// last token usage, if reported.
///
/// # Errors
/// Returns [`OrchestrationError::PlanningFailed`] on serialization or
/// provider failures, plus any conversion/validation errors.
async fn adapt_plan(
    provider: &impl LlmProvider,
    goal: &str,
    template: &PlanTemplate,
    available_agents: &[SubAgentDef],
    max_tasks: u32,
) -> Result<(TaskGraph, Option<(u64, u64)>), OrchestrationError> {
    use zeph_subagent::ToolPolicy;

    // One catalog line per agent, rendering its tool policy as text.
    let agent_catalog = available_agents
        .iter()
        .map(|a| {
            let tools = match &a.tools {
                ToolPolicy::AllowList(list) => list.join(", "),
                ToolPolicy::DenyList(excluded) => {
                    format!("all except: [{}]", excluded.join(", "))
                }
                ToolPolicy::InheritAll => "all".to_string(),
            };
            format!(
                "- name: \"{}\", description: \"{}\", tools: [{}]",
                a.name, a.description, tools
            )
        })
        .collect::<Vec<_>>()
        .join("\n");

    // Only the task list is embedded in the prompt; the template's goal
    // is shown separately in the user message below.
    let template_json = serde_json::to_string(&template.tasks)
        .map_err(|e| OrchestrationError::PlanningFailed(e.to_string()))?;

    let system = format!(
        "You are a task planner. A cached plan template exists for a similar goal. \
         Adapt it for the new goal by adjusting task descriptions and adding or removing \
         tasks as needed. Keep the same JSON structure.\n\n\
         Available agents:\n{agent_catalog}\n\n\
         Rules:\n\
         - Each task must have a unique task_id (short, descriptive, kebab-case: [a-z0-9-]).\n\
         - Specify dependencies using task_id strings in depends_on.\n\
         - Do not create more than {max_tasks} tasks.\n\
         - failure_strategy is optional: \"abort\", \"retry\", \"skip\", \"ask\"."
    );

    let user = format!(
        "New goal:\n{goal}\n\nCached template (for similar goal \"{}\"):\n{template_json}\n\n\
         Adapt the template for the new goal. Return JSON: {{\"tasks\": [...]}}",
        template.goal
    );

    let messages = vec![
        Message::from_legacy(Role::System, system),
        Message::from_legacy(Role::User, user),
    ];

    let response: PlannerResponse = provider
        .chat_typed(&messages)
        .await
        .map_err(|e| OrchestrationError::PlanningFailed(e.to_string()))?;

    let usage = provider.last_usage();

    let graph = convert_response_pub(response, goal, available_agents, max_tasks)?;

    // Reject cycles / oversized plans before handing the graph back.
    dag::validate(&graph.tasks, max_tasks as usize)?;

    Ok((graph, usage))
}
585
#[cfg(test)]
mod tests {
    use super::super::graph::{TaskId, TaskNode};
    use super::*;
    use zeph_memory::store::SqliteStore;

    /// Fresh in-memory database pool for each test.
    async fn test_pool() -> DbPool {
        let store = SqliteStore::new(":memory:").await.unwrap();
        store.pool().clone()
    }

    /// Cache with default config registered against "test-model".
    async fn test_cache(pool: DbPool) -> PlanCache {
        PlanCache::new(pool, PlanCacheConfig::default(), "test-model")
            .await
            .unwrap()
    }

    /// Builds a graph from (title, description, deps-by-index) tuples.
    fn make_graph(goal: &str, tasks: &[(&str, &str, &[u32])]) -> TaskGraph {
        let mut graph = TaskGraph::new(goal);
        for (i, (title, desc, deps)) in tasks.iter().enumerate() {
            #[allow(clippy::cast_possible_truncation)]
            let mut node = TaskNode::new(i as u32, *title, *desc);
            node.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
            graph.tasks.push(node);
        }
        graph
    }

    // --- normalize_goal / goal_hash ---

    #[test]
    fn normalize_trims_and_lowercases() {
        assert_eq!(normalize_goal(" Hello World "), "hello world");
    }

    #[test]
    fn normalize_collapses_internal_whitespace() {
        // Mixed whitespace runs must collapse to a single space.
        assert_eq!(normalize_goal("hello \t\n world"), "hello world");
    }

    #[test]
    fn normalize_empty_string() {
        assert_eq!(normalize_goal(""), "");
    }

    #[test]
    fn normalize_whitespace_only() {
        assert_eq!(normalize_goal(" "), "");
    }

    #[test]
    fn goal_hash_is_deterministic() {
        let h1 = goal_hash("deploy service");
        let h2 = goal_hash("deploy service");
        assert_eq!(h1, h2);
    }

    #[test]
    fn goal_hash_differs_for_different_goals() {
        assert_ne!(goal_hash("deploy service"), goal_hash("build artifact"));
    }

    #[test]
    fn goal_hash_nonempty() {
        assert!(!goal_hash("goal").is_empty());
    }

    // --- PlanTemplate extraction ---

    #[test]
    fn template_from_empty_graph_returns_error() {
        let graph = TaskGraph::new("goal");
        assert!(PlanTemplate::from_task_graph(&graph).is_err());
    }

    #[test]
    fn template_strips_runtime_fields() {
        use crate::graph::TaskStatus;
        let mut graph = make_graph("goal", &[("Fetch data", "Download it", &[])]);
        graph.tasks[0].status = TaskStatus::Completed;
        graph.tasks[0].retry_count = 3;
        graph.tasks[0].assigned_agent = Some("agent-x".to_string());
        let template = PlanTemplate::from_task_graph(&graph).unwrap();
        assert_eq!(template.tasks[0].title, "Fetch data");
        assert_eq!(template.tasks[0].description, "Download it");
    }

    #[test]
    fn template_preserves_dependencies() {
        let graph = make_graph("goal", &[("Task A", "do A", &[]), ("Task B", "do B", &[0])]);
        let template = PlanTemplate::from_task_graph(&graph).unwrap();
        assert_eq!(template.tasks.len(), 2);
        assert!(template.tasks[0].depends_on.is_empty());
        assert_eq!(template.tasks[1].depends_on.len(), 1);
        assert_eq!(template.tasks[1].depends_on[0], template.tasks[0].task_id);
    }

    #[test]
    fn template_serde_roundtrip() {
        let graph = make_graph("goal", &[("Step one", "do step one", &[])]);
        let template = PlanTemplate::from_task_graph(&graph).unwrap();
        let json = serde_json::to_string(&template).unwrap();
        let restored: PlanTemplate = serde_json::from_str(&json).unwrap();
        assert_eq!(template.tasks[0].title, restored.tasks[0].title);
        assert_eq!(template.goal, restored.goal);
    }

    // --- embedding blob helpers ---

    #[test]
    fn embedding_blob_roundtrip() {
        let embedding = vec![1.0_f32, 0.5, 0.25, -1.0];
        let blob = embedding_to_blob(&embedding);
        let restored = blob_to_embedding(&blob).unwrap();
        assert_eq!(embedding, restored);
    }

    #[test]
    fn blob_to_embedding_odd_length_returns_none() {
        let bad_blob = vec![0u8; 5];
        assert!(blob_to_embedding(&bad_blob).is_none());
    }

    // --- PlanCache database behavior ---

    #[tokio::test]
    async fn cache_miss_on_empty_cache() {
        let pool = test_pool().await;
        let cache = test_cache(pool).await;
        let result = cache
            .find_similar(&[1.0, 0.0, 0.0], "test-model")
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_store_and_hit() {
        let pool = test_pool().await;
        let config = PlanCacheConfig {
            similarity_threshold: 0.9,
            ..PlanCacheConfig::default()
        };
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        let graph = make_graph("deploy service", &[("Build", "build it", &[])]);
        let embedding = vec![1.0_f32, 0.0, 0.0];
        cache
            .cache_plan(&graph, &embedding, "test-model")
            .await
            .unwrap();

        let result = cache
            .find_similar(&[1.0, 0.0, 0.0], "test-model")
            .await
            .unwrap();
        assert!(result.is_some());
        let (template, score) = result.unwrap();
        // Identical vectors: cosine similarity should be ~1.0.
        assert!((score - 1.0).abs() < 1e-5);
        assert_eq!(template.tasks.len(), 1);
    }

    #[tokio::test]
    async fn cache_miss_on_dissimilar_goal() {
        let pool = test_pool().await;
        let config = PlanCacheConfig {
            similarity_threshold: 0.9,
            ..PlanCacheConfig::default()
        };
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        let graph = make_graph("goal a", &[("Task", "do it", &[])]);
        cache
            .cache_plan(&graph, &[1.0_f32, 0.0, 0.0], "test-model")
            .await
            .unwrap();

        // Orthogonal embedding: similarity 0.0, below the 0.9 threshold.
        let result = cache
            .find_similar(&[0.0, 1.0, 0.0], "test-model")
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn deduplication_increments_success_count() {
        let pool = test_pool().await;
        let cache = test_cache(pool.clone()).await;

        let graph = make_graph("same goal", &[("Task", "do it", &[])]);
        let emb = vec![1.0_f32, 0.0];

        cache.cache_plan(&graph, &emb, "test-model").await.unwrap();
        cache.cache_plan(&graph, &emb, "test-model").await.unwrap();

        // Upsert on goal_hash: one row, success_count bumped to 2.
        let count: i64 = zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM plan_cache"))
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 1);

        let success: i64 = zeph_db::query_scalar(sql!("SELECT success_count FROM plan_cache"))
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(success, 2);
    }

    #[tokio::test]
    async fn eviction_removes_ttl_expired_rows() {
        let pool = test_pool().await;
        let config = PlanCacheConfig {
            ttl_days: 0,
            ..PlanCacheConfig::default()
        };
        let cache = PlanCache::new(pool.clone(), config, "test-model")
            .await
            .unwrap();

        // ttl_days = 0 makes anything accessed before "now" expired.
        let now = unix_now() - 1;
        zeph_db::query(sql!(
            "INSERT INTO plan_cache \
             (id, goal_hash, goal_text, template, task_count, created_at, last_accessed_at) \
             VALUES (?, ?, ?, ?, ?, ?, ?)"
        ))
        .bind("test-id")
        .bind("hash-1")
        .bind("goal")
        .bind("{\"goal\":\"goal\",\"tasks\":[]}")
        .bind(0_i64)
        .bind(now)
        .bind(now)
        .execute(&pool)
        .await
        .unwrap();

        let deleted = cache.evict().await.unwrap();
        assert!(deleted >= 1);

        let count: i64 = zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM plan_cache"))
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn eviction_lru_when_over_max() {
        let pool = test_pool().await;
        let config = PlanCacheConfig {
            max_templates: 2,
            ttl_days: 365,
            ..PlanCacheConfig::default()
        };
        let cache = PlanCache::new(pool.clone(), config, "test-model")
            .await
            .unwrap();

        let now = unix_now();
        // Three rows with staggered access times; the oldest must go.
        for i in 0..3_i64 {
            zeph_db::query(sql!(
                "INSERT INTO plan_cache \
                 (id, goal_hash, goal_text, template, task_count, created_at, last_accessed_at) \
                 VALUES (?, ?, ?, ?, ?, ?, ?)"
            ))
            .bind(format!("id-{i}"))
            .bind(format!("hash-{i}"))
            .bind(format!("goal-{i}"))
            .bind("{\"goal\":\"g\",\"tasks\":[]}")
            .bind(0_i64)
            .bind(now)
            .bind(now + i)
            .execute(&pool)
            .await
            .unwrap();
        }

        let deleted = cache.evict().await.unwrap();
        assert_eq!(deleted, 1);

        let count: i64 = zeph_db::query_scalar(sql!("SELECT COUNT(*) FROM plan_cache"))
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn stale_embedding_invalidated_on_new() {
        let pool = test_pool().await;
        let now = unix_now();

        // Seed a row embedded under "old-model".
        let emb = embedding_to_blob(&[1.0_f32, 0.0]);
        zeph_db::query(sql!(
            "INSERT INTO plan_cache \
             (id, goal_hash, goal_text, template, task_count, embedding, embedding_model, \
             created_at, last_accessed_at) \
             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"
        ))
        .bind("id-old")
        .bind("hash-old")
        .bind("goal old")
        .bind("{\"goal\":\"g\",\"tasks\":[]}")
        .bind(0_i64)
        .bind(&emb)
        .bind("old-model")
        .bind(now)
        .bind(now)
        .execute(&pool)
        .await
        .unwrap();

        // Constructing the cache under "new-model" must null the old row.
        let _cache = PlanCache::new(pool.clone(), PlanCacheConfig::default(), "new-model")
            .await
            .unwrap();

        let model: Option<String> = zeph_db::query_scalar(sql!(
            "SELECT embedding_model FROM plan_cache WHERE id = 'id-old'"
        ))
        .fetch_one(&pool)
        .await
        .unwrap();
        assert!(model.is_none(), "stale embedding_model should be NULL");

        let emb_col: Option<Vec<u8>> =
            zeph_db::query_scalar(sql!("SELECT embedding FROM plan_cache WHERE id = 'id-old'"))
                .fetch_one(&pool)
                .await
                .unwrap();
        assert!(emb_col.is_none(), "stale embedding should be NULL");
    }

    // --- plan_with_cache integration ---

    #[tokio::test]
    async fn disabled_cache_not_used_in_plan_with_cache() {
        use crate::planner::LlmPlanner;
        use zeph_config::OrchestrationConfig;
        use zeph_llm::mock::MockProvider;

        let pool = test_pool().await;
        // NOTE(review): relies on PlanCacheConfig::default() being
        // disabled — confirm against zeph_config.
        let config = PlanCacheConfig::default();
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        let graph_json = r#"{"tasks": [
            {"task_id": "t1", "title": "Task", "description": "do it", "depends_on": []}
        ]}"#
        .to_string();

        let provider = MockProvider::with_responses(vec![graph_json.clone()]);
        let planner = LlmPlanner::new(
            MockProvider::with_responses(vec![graph_json]),
            &OrchestrationConfig::default(),
        );

        let (graph, _) = plan_with_cache(
            &planner,
            Some(&cache),
            &provider,
            Some(&[1.0_f32, 0.0]),
            "test-model",
            "do something",
            &[],
            20,
        )
        .await
        .unwrap();

        assert_eq!(graph.tasks.len(), 1);
    }

    #[tokio::test]
    async fn plan_with_cache_with_none_embedding_skips_cache() {
        use crate::planner::LlmPlanner;
        use zeph_config::OrchestrationConfig;
        use zeph_llm::mock::MockProvider;

        let pool = test_pool().await;
        let config = PlanCacheConfig {
            enabled: true,
            similarity_threshold: 0.5,
            ..PlanCacheConfig::default()
        };
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        // Seed a cached plan that would match if an embedding were given.
        let graph = make_graph("deploy service", &[("Build", "build it", &[])]);
        cache
            .cache_plan(&graph, &[1.0_f32, 0.0], "test-model")
            .await
            .unwrap();

        let graph_json = r#"{"tasks": [
            {"task_id": "fallback-task-0", "title": "Fallback", "description": "planner fallback", "depends_on": []}
        ]}"#
        .to_string();

        let provider = MockProvider::with_responses(vec![graph_json.clone()]);
        let planner = LlmPlanner::new(
            MockProvider::with_responses(vec![graph_json]),
            &OrchestrationConfig::default(),
        );

        // No embedding: the cache must be bypassed entirely.
        let (result_graph, _) = plan_with_cache(
            &planner,
            Some(&cache),
            &provider,
            None,
            "test-model",
            "deploy service",
            &[],
            20,
        )
        .await
        .unwrap();

        assert_eq!(result_graph.tasks[0].title, "Fallback");
    }

    #[tokio::test]
    async fn adapt_plan_error_fallback_to_full_decomposition() {
        use crate::planner::LlmPlanner;
        use zeph_config::OrchestrationConfig;
        use zeph_llm::mock::MockProvider;

        let pool = test_pool().await;
        let config = PlanCacheConfig {
            enabled: true,
            similarity_threshold: 0.5,
            ..PlanCacheConfig::default()
        };
        let cache = PlanCache::new(pool, config, "test-model").await.unwrap();

        // Seed a cached plan so find_similar produces a hit.
        let graph = make_graph("deploy service", &[("Build", "build it", &[])]);
        cache
            .cache_plan(&graph, &[1.0_f32, 0.0], "test-model")
            .await
            .unwrap();

        // Adaptation provider returns junk, forcing the fallback path.
        let bad_provider = MockProvider::with_responses(vec!["not valid json".to_string()]);

        let fallback_json = r#"{"tasks": [
            {"task_id": "fallback-0", "title": "Fallback Task", "description": "via planner", "depends_on": []}
        ]}"#
        .to_string();
        let planner = LlmPlanner::new(
            MockProvider::with_responses(vec![fallback_json]),
            &OrchestrationConfig::default(),
        );

        let (result_graph, _) = plan_with_cache(
            &planner,
            Some(&cache),
            &bad_provider,
            Some(&[1.0_f32, 0.0]),
            "test-model",
            "deploy service",
            &[],
            20,
        )
        .await
        .unwrap();

        assert_eq!(result_graph.tasks[0].title, "Fallback Task");
    }
}