scouter_evaluate/evaluate/
store.rs1use chrono::{DateTime, Utc};
2use scouter_types::genai::AssertionResult;
3use std::sync::RwLock;
4use std::{
5 collections::{HashMap, HashSet},
6 sync::Arc,
7};
8use tracing::debug;
9
/// Kind of evaluation task tracked by a [`TaskRegistry`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskType {
    /// An assertion task.
    Assertion,
    /// An LLM-judge task.
    LLMJudge,
}

/// Metadata recorded for a task when it is registered in a [`TaskRegistry`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TaskMetadata {
    /// The kind of task that was registered.
    pub task_type: TaskType,
    /// Whether the task was registered as conditional.
    pub is_conditional: bool,
}
21
22#[derive(Debug)]
24pub struct TaskRegistry {
25 registry: HashMap<String, TaskMetadata>,
27 dependency_map: HashMap<String, Arc<Vec<String>>>,
28 skipped_tasks: RwLock<HashSet<String>>,
29}
30
31impl TaskRegistry {
32 pub fn new() -> Self {
33 debug!("Initializing TaskRegistry");
34 Self {
35 registry: HashMap::new(),
36 dependency_map: HashMap::new(),
37 skipped_tasks: RwLock::new(HashSet::new()),
38 }
39 }
40
41 pub fn mark_skipped(&self, task_id: String) {
42 self.skipped_tasks.write().unwrap().insert(task_id);
43 }
44
45 pub fn is_skipped(&self, task_id: &str) -> bool {
46 self.skipped_tasks.read().unwrap().contains(task_id)
47 }
48
49 pub fn register(&mut self, task_id: String, task_type: TaskType, is_conditional: bool) {
51 self.registry.insert(
52 task_id,
53 TaskMetadata {
54 task_type,
55 is_conditional,
56 },
57 );
58 }
59
60 pub fn get_type(&self, task_id: &str) -> Option<TaskType> {
62 self.registry.get(task_id).map(|meta| meta.task_type)
63 }
64
65 pub fn is_conditional(&self, task_id: &str) -> bool {
67 self.registry
68 .get(task_id)
69 .map(|meta| meta.is_conditional)
70 .unwrap_or(false)
71 }
72
73 pub fn contains(&self, task_id: &str) -> bool {
75 self.registry.contains_key(task_id)
76 }
77
78 pub fn register_dependencies(&mut self, task_id: String, dependencies: Vec<String>) {
83 self.dependency_map.insert(task_id, Arc::new(dependencies));
84 }
85
86 pub fn get_dependencies(&self, task_id: &str) -> Option<Arc<Vec<String>>> {
90 if !self.registry.contains_key(task_id) {
91 return None;
92 }
93
94 Some(self.dependency_map.get(task_id)?.clone())
95 }
96}
97
98impl Default for TaskRegistry {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104type AssertionResultType = (DateTime<Utc>, DateTime<Utc>, AssertionResult);
106
107#[derive(Debug)]
109pub struct AssertionResultStore {
110 store: HashMap<String, AssertionResultType>,
112}
113
114impl AssertionResultStore {
115 pub fn new() -> Self {
116 AssertionResultStore {
117 store: HashMap::new(),
118 }
119 }
120
121 pub fn store(
122 &mut self,
123 task_id: String,
124 start_time: DateTime<Utc>,
125 end_time: DateTime<Utc>,
126 result: AssertionResult,
127 ) {
128 self.store.insert(task_id, (start_time, end_time, result));
129 }
130
131 pub fn retrieve(&self, task_id: &str) -> Option<AssertionResultType> {
132 self.store.get(task_id).cloned()
133 }
134}
135
136impl Default for AssertionResultStore {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142#[derive(Debug)]
144pub struct LLMResponseStore {
145 store: HashMap<String, serde_json::Value>,
147}
148
149impl LLMResponseStore {
150 pub fn new() -> Self {
151 LLMResponseStore {
152 store: HashMap::new(),
153 }
154 }
155
156 pub fn store(&mut self, task_id: String, response: serde_json::Value) {
157 self.store.insert(task_id, response);
158 }
159
160 pub fn retrieve(&self, task_id: &str) -> Option<serde_json::Value> {
161 self.store.get(task_id).cloned()
162 }
163}
164
165impl Default for LLMResponseStore {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_registry() {
        let mut registry = TaskRegistry::new();
        registry.register("task1".to_string(), TaskType::Assertion, false);
        registry.register("task2".to_string(), TaskType::LLMJudge, false);
        assert_eq!(registry.get_type("task1"), Some(TaskType::Assertion));
        assert_eq!(registry.get_type("task2"), Some(TaskType::LLMJudge));
        assert_eq!(registry.get_type("task3"), None);
        assert!(registry.contains("task1"));
        assert!(!registry.contains("task3"));
    }

    #[test]
    fn test_conditional_flag_and_reregistration() {
        let mut registry = TaskRegistry::default();
        registry.register("task1".to_string(), TaskType::Assertion, false);
        assert!(!registry.is_conditional("task1"));
        // Unregistered tasks are never reported as conditional.
        assert!(!registry.is_conditional("missing"));

        // Registering the same id again overwrites its metadata.
        registry.register("task1".to_string(), TaskType::LLMJudge, true);
        assert!(registry.is_conditional("task1"));
        assert_eq!(registry.get_type("task1"), Some(TaskType::LLMJudge));
    }

    #[test]
    fn test_dependencies_require_registration() {
        let mut registry = TaskRegistry::new();
        registry.register("child".to_string(), TaskType::Assertion, true);
        registry.register("no_deps".to_string(), TaskType::Assertion, false);
        registry.register_dependencies("child".to_string(), vec!["parent".to_string()]);
        registry.register_dependencies("orphan".to_string(), vec!["parent".to_string()]);

        let deps = registry
            .get_dependencies("child")
            .expect("child is registered and has dependencies");
        assert_eq!(*deps, vec!["parent".to_string()]);

        // Registered, but no dependency list was stored.
        assert!(registry.get_dependencies("no_deps").is_none());
        // Dependencies were stored, but the task itself was never registered.
        assert!(registry.get_dependencies("orphan").is_none());
        assert!(registry.get_dependencies("missing").is_none());
    }

    #[test]
    fn test_skipped_tasks() {
        let registry = TaskRegistry::new();
        assert!(!registry.is_skipped("task1"));
        registry.mark_skipped("task1".to_string());
        assert!(registry.is_skipped("task1"));
        assert!(!registry.is_skipped("task2"));
    }

    #[test]
    fn test_llm_response_store() {
        let mut store = LLMResponseStore::new();
        assert!(store.retrieve("task1").is_none());

        store.store("task1".to_string(), serde_json::json!({ "score": 1 }));
        assert_eq!(
            store.retrieve("task1"),
            Some(serde_json::json!({ "score": 1 }))
        );

        // A later response for the same task replaces the earlier one.
        store.store("task1".to_string(), serde_json::json!({ "score": 2 }));
        assert_eq!(
            store.retrieve("task1"),
            Some(serde_json::json!({ "score": 2 }))
        );
    }
}