// scouter_evaluate/evaluate/store.rs

use std::{
    collections::{HashMap, HashSet},
    sync::{Arc, PoisonError, RwLock},
};

use chrono::{DateTime, Utc};
use scouter_types::genai::AssertionResult;
use tracing::debug;

/// Kind of task recorded in [`TaskRegistry`]; used to route a task to the matching store.
///
/// Derives `Hash` (in addition to `Eq`) so task types can be used as keys in
/// `HashMap`/`HashSet`, e.g. to group or count tasks by type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskType {
    Assertion,
    LLMJudge,
    TraceAssertion,
}

17#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18pub struct TaskMetadata {
19    pub task_type: TaskType,
20    pub is_conditional: bool,
21}
22
23/// Registry that tracks task IDs and their types for store routing
24#[derive(Debug)]
25pub struct TaskRegistry {
26    /// Maps task_id -> TaskType
27    registry: HashMap<String, TaskMetadata>,
28    dependency_map: HashMap<String, Arc<Vec<String>>>,
29    skipped_tasks: RwLock<HashSet<String>>,
30}
31
32impl TaskRegistry {
33    pub fn new() -> Self {
34        debug!("Initializing TaskRegistry");
35        Self {
36            registry: HashMap::new(),
37            dependency_map: HashMap::new(),
38            skipped_tasks: RwLock::new(HashSet::new()),
39        }
40    }
41
42    pub fn mark_skipped(&self, task_id: String) {
43        self.skipped_tasks.write().unwrap().insert(task_id);
44    }
45
46    pub fn is_skipped(&self, task_id: &str) -> bool {
47        self.skipped_tasks.read().unwrap().contains(task_id)
48    }
49
50    /// Register a task with its type and conditional status
51    pub fn register(&mut self, task_id: String, task_type: TaskType, is_conditional: bool) {
52        self.registry.insert(
53            task_id,
54            TaskMetadata {
55                task_type,
56                is_conditional,
57            },
58        );
59    }
60
61    /// Get the type of a task by ID
62    pub fn get_type(&self, task_id: &str) -> Option<TaskType> {
63        self.registry.get(task_id).map(|meta| meta.task_type)
64    }
65
66    /// Check if a task is marked as conditional
67    pub fn is_conditional(&self, task_id: &str) -> bool {
68        self.registry
69            .get(task_id)
70            .map(|meta| meta.is_conditional)
71            .unwrap_or(false)
72    }
73
74    /// Check if a task is registered
75    pub fn contains(&self, task_id: &str) -> bool {
76        self.registry.contains_key(task_id)
77    }
78
79    /// Register dependencies for a task
80    /// # Arguments
81    /// * `task_id` - The ID of the task
82    /// * `dependencies` - A list of task IDs that this task depends on
83    pub fn register_dependencies(&mut self, task_id: String, dependencies: Vec<String>) {
84        self.dependency_map.insert(task_id, Arc::new(dependencies));
85    }
86
87    /// For a given task ID, get its dependencies
88    /// # Arguments
89    /// * `task_id` - The ID of the task
90    pub fn get_dependencies(&self, task_id: &str) -> Option<Arc<Vec<String>>> {
91        if !self.registry.contains_key(task_id) {
92            return None;
93        }
94
95        Some(self.dependency_map.get(task_id)?.clone())
96    }
97}
98
99impl Default for TaskRegistry {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105// start time, end time, result
106type AssertionResultType = (DateTime<Utc>, DateTime<Utc>, AssertionResult);
107
108/// Store for assertion results
109#[derive(Debug)]
110pub struct AssertionResultStore {
111    /// Internal store mapping task IDs to results
112    store: HashMap<String, AssertionResultType>,
113}
114
115impl AssertionResultStore {
116    pub fn new() -> Self {
117        AssertionResultStore {
118            store: HashMap::new(),
119        }
120    }
121
122    pub fn store(
123        &mut self,
124        task_id: String,
125        start_time: DateTime<Utc>,
126        end_time: DateTime<Utc>,
127        result: AssertionResult,
128    ) {
129        self.store.insert(task_id, (start_time, end_time, result));
130    }
131
132    pub fn retrieve(&self, task_id: &str) -> Option<AssertionResultType> {
133        self.store.get(task_id).cloned()
134    }
135}
136
137impl Default for AssertionResultStore {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143/// Store for raw LLM responses
144#[derive(Debug)]
145pub struct LLMResponseStore {
146    /// Internal store mapping task IDs to raw LLM responses
147    store: HashMap<String, serde_json::Value>,
148}
149
150impl LLMResponseStore {
151    pub fn new() -> Self {
152        LLMResponseStore {
153            store: HashMap::new(),
154        }
155    }
156
157    pub fn store(&mut self, task_id: String, response: serde_json::Value) {
158        self.store.insert(task_id, response);
159    }
160
161    pub fn retrieve(&self, task_id: &str) -> Option<serde_json::Value> {
162        self.store.get(task_id).cloned()
163    }
164}
165
166impl Default for LLMResponseStore {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
// Unit tests for the task registry and the response store.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_registry() {
        let mut registry = TaskRegistry::new();
        registry.register("task1".to_string(), TaskType::Assertion, false);
        registry.register("task2".to_string(), TaskType::LLMJudge, false);
        assert_eq!(registry.get_type("task1"), Some(TaskType::Assertion));
        assert_eq!(registry.get_type("task2"), Some(TaskType::LLMJudge));
        assert!(registry.contains("task1"));
        assert!(!registry.contains("task3"));
    }

    #[test]
    fn test_default_registry_is_empty() {
        let registry = TaskRegistry::default();
        assert!(!registry.contains("task1"));
        assert_eq!(registry.get_type("task1"), None);
        assert!(!registry.is_skipped("task1"));
        assert!(registry.get_dependencies("task1").is_none());
    }

    #[test]
    fn test_conditional_flag() {
        let mut registry = TaskRegistry::new();
        registry.register("cond".to_string(), TaskType::TraceAssertion, true);
        registry.register("plain".to_string(), TaskType::Assertion, false);
        assert!(registry.is_conditional("cond"));
        assert!(!registry.is_conditional("plain"));
        // Unknown tasks are treated as not conditional.
        assert!(!registry.is_conditional("missing"));
    }

    #[test]
    fn test_skipped_tasks() {
        let registry = TaskRegistry::new();
        assert!(!registry.is_skipped("task1"));
        registry.mark_skipped("task1".to_string());
        assert!(registry.is_skipped("task1"));
        assert!(!registry.is_skipped("task2"));
    }

    #[test]
    fn test_dependencies() {
        let mut registry = TaskRegistry::new();
        registry.register("task1".to_string(), TaskType::Assertion, false);
        registry.register("task2".to_string(), TaskType::LLMJudge, false);
        registry.register_dependencies("task2".to_string(), vec!["task1".to_string()]);

        assert_eq!(
            registry.get_dependencies("task2").as_deref(),
            Some(&vec!["task1".to_string()])
        );
        // Registered task without dependencies.
        assert!(registry.get_dependencies("task1").is_none());
        // Dependencies recorded for an unregistered task are not exposed.
        registry.register_dependencies("ghost".to_string(), vec!["task1".to_string()]);
        assert!(registry.get_dependencies("ghost").is_none());
    }

    #[test]
    fn test_llm_response_store() {
        let mut store = LLMResponseStore::new();
        assert!(store.retrieve("task1").is_none());
        store.store("task1".to_string(), serde_json::json!({"score": 1}));
        assert_eq!(store.retrieve("task1"), Some(serde_json::json!({"score": 1})));
        // A second store for the same task replaces the first response.
        store.store("task1".to_string(), serde_json::json!({"score": 2}));
        assert_eq!(store.retrieve("task1"), Some(serde_json::json!({"score": 2})));
    }
}