// scouter_evaluate/evaluate/store.rs

1use chrono::{DateTime, Utc};
2use scouter_types::genai::AssertionResult;
3use std::sync::RwLock;
4use std::{
5    collections::{HashMap, HashSet},
6    sync::Arc,
7};
8use tracing::debug;
9
/// Kind of evaluation task recorded in a task's metadata.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskType {
    /// An assertion task.
    Assertion,
    /// An LLM-judge task.
    LLMJudge,
}
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
17pub struct TaskMetadata {
18    pub task_type: TaskType,
19    pub is_conditional: bool,
20}
21
22/// Registry that tracks task IDs and their types for store routing
23#[derive(Debug)]
24pub struct TaskRegistry {
25    /// Maps task_id -> TaskType
26    registry: HashMap<String, TaskMetadata>,
27    dependency_map: HashMap<String, Arc<Vec<String>>>,
28    skipped_tasks: RwLock<HashSet<String>>,
29}
30
31impl TaskRegistry {
32    pub fn new() -> Self {
33        debug!("Initializing TaskRegistry");
34        Self {
35            registry: HashMap::new(),
36            dependency_map: HashMap::new(),
37            skipped_tasks: RwLock::new(HashSet::new()),
38        }
39    }
40
41    pub fn mark_skipped(&self, task_id: String) {
42        self.skipped_tasks.write().unwrap().insert(task_id);
43    }
44
45    pub fn is_skipped(&self, task_id: &str) -> bool {
46        self.skipped_tasks.read().unwrap().contains(task_id)
47    }
48
49    /// Register a task with its type and conditional status
50    pub fn register(&mut self, task_id: String, task_type: TaskType, is_conditional: bool) {
51        self.registry.insert(
52            task_id,
53            TaskMetadata {
54                task_type,
55                is_conditional,
56            },
57        );
58    }
59
60    /// Get the type of a task by ID
61    pub fn get_type(&self, task_id: &str) -> Option<TaskType> {
62        self.registry.get(task_id).map(|meta| meta.task_type)
63    }
64
65    /// Check if a task is marked as conditional
66    pub fn is_conditional(&self, task_id: &str) -> bool {
67        self.registry
68            .get(task_id)
69            .map(|meta| meta.is_conditional)
70            .unwrap_or(false)
71    }
72
73    /// Check if a task is registered
74    pub fn contains(&self, task_id: &str) -> bool {
75        self.registry.contains_key(task_id)
76    }
77
78    /// Register dependencies for a task
79    /// # Arguments
80    /// * `task_id` - The ID of the task
81    /// * `dependencies` - A list of task IDs that this task depends on
82    pub fn register_dependencies(&mut self, task_id: String, dependencies: Vec<String>) {
83        self.dependency_map.insert(task_id, Arc::new(dependencies));
84    }
85
86    /// For a given task ID, get its dependencies
87    /// # Arguments
88    /// * `task_id` - The ID of the task
89    pub fn get_dependencies(&self, task_id: &str) -> Option<Arc<Vec<String>>> {
90        if !self.registry.contains_key(task_id) {
91            return None;
92        }
93
94        Some(self.dependency_map.get(task_id)?.clone())
95    }
96}
97
98impl Default for TaskRegistry {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104// start time, end time, result
105type AssertionResultType = (DateTime<Utc>, DateTime<Utc>, AssertionResult);
106
107/// Store for assertion results
108#[derive(Debug)]
109pub struct AssertionResultStore {
110    /// Internal store mapping task IDs to results
111    store: HashMap<String, AssertionResultType>,
112}
113
114impl AssertionResultStore {
115    pub fn new() -> Self {
116        AssertionResultStore {
117            store: HashMap::new(),
118        }
119    }
120
121    pub fn store(
122        &mut self,
123        task_id: String,
124        start_time: DateTime<Utc>,
125        end_time: DateTime<Utc>,
126        result: AssertionResult,
127    ) {
128        self.store.insert(task_id, (start_time, end_time, result));
129    }
130
131    pub fn retrieve(&self, task_id: &str) -> Option<AssertionResultType> {
132        self.store.get(task_id).cloned()
133    }
134}
135
136impl Default for AssertionResultStore {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142/// Store for raw LLM responses
143#[derive(Debug)]
144pub struct LLMResponseStore {
145    /// Internal store mapping task IDs to raw LLM responses
146    store: HashMap<String, serde_json::Value>,
147}
148
149impl LLMResponseStore {
150    pub fn new() -> Self {
151        LLMResponseStore {
152            store: HashMap::new(),
153        }
154    }
155
156    pub fn store(&mut self, task_id: String, response: serde_json::Value) {
157        self.store.insert(task_id, response);
158    }
159
160    pub fn retrieve(&self, task_id: &str) -> Option<serde_json::Value> {
161        self.store.get(task_id).cloned()
162    }
163}
164
165impl Default for LLMResponseStore {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
// Minimal smoke test for the task registry.
#[cfg(test)]
mod tests {
    use super::*;

    /// Registered tasks report the type they were registered with, and
    /// `contains` distinguishes registered from unregistered IDs.
    #[test]
    fn test_task_registry() {
        let expected = [
            ("task1", TaskType::Assertion),
            ("task2", TaskType::LLMJudge),
        ];

        let mut registry = TaskRegistry::new();
        for &(task_id, task_type) in expected.iter() {
            registry.register(task_id.to_string(), task_type, false);
        }

        for &(task_id, task_type) in expected.iter() {
            assert_eq!(registry.get_type(task_id), Some(task_type));
        }
        assert!(registry.contains("task1"));
        assert!(!registry.contains("task3"));
    }
}
185}