scouter_evaluate/evaluate/
store.rs1use chrono::{DateTime, Utc};
2use scouter_types::genai::AssertionResult;
3use std::sync::RwLock;
4use std::{
5 collections::{HashMap, HashSet},
6 sync::Arc,
7};
8use tracing::debug;
9
/// Category a task is registered under in [`TaskRegistry`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskType {
    Assertion,
    LLMJudge,
    TraceAssertion,
}
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18pub struct TaskMetadata {
19 pub task_type: TaskType,
20 pub is_conditional: bool,
21}
22
23#[derive(Debug)]
25pub struct TaskRegistry {
26 registry: HashMap<String, TaskMetadata>,
28 dependency_map: HashMap<String, Arc<Vec<String>>>,
29 skipped_tasks: RwLock<HashSet<String>>,
30}
31
32impl TaskRegistry {
33 pub fn new() -> Self {
34 debug!("Initializing TaskRegistry");
35 Self {
36 registry: HashMap::new(),
37 dependency_map: HashMap::new(),
38 skipped_tasks: RwLock::new(HashSet::new()),
39 }
40 }
41
42 pub fn mark_skipped(&self, task_id: String) {
43 self.skipped_tasks.write().unwrap().insert(task_id);
44 }
45
46 pub fn is_skipped(&self, task_id: &str) -> bool {
47 self.skipped_tasks.read().unwrap().contains(task_id)
48 }
49
50 pub fn register(&mut self, task_id: String, task_type: TaskType, is_conditional: bool) {
52 self.registry.insert(
53 task_id,
54 TaskMetadata {
55 task_type,
56 is_conditional,
57 },
58 );
59 }
60
61 pub fn get_type(&self, task_id: &str) -> Option<TaskType> {
63 self.registry.get(task_id).map(|meta| meta.task_type)
64 }
65
66 pub fn is_conditional(&self, task_id: &str) -> bool {
68 self.registry
69 .get(task_id)
70 .map(|meta| meta.is_conditional)
71 .unwrap_or(false)
72 }
73
74 pub fn contains(&self, task_id: &str) -> bool {
76 self.registry.contains_key(task_id)
77 }
78
79 pub fn register_dependencies(&mut self, task_id: String, dependencies: Vec<String>) {
84 self.dependency_map.insert(task_id, Arc::new(dependencies));
85 }
86
87 pub fn get_dependencies(&self, task_id: &str) -> Option<Arc<Vec<String>>> {
91 if !self.registry.contains_key(task_id) {
92 return None;
93 }
94
95 Some(self.dependency_map.get(task_id)?.clone())
96 }
97}
98
99impl Default for TaskRegistry {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105type AssertionResultType = (DateTime<Utc>, DateTime<Utc>, AssertionResult);
107
108#[derive(Debug)]
110pub struct AssertionResultStore {
111 store: HashMap<String, AssertionResultType>,
113}
114
115impl AssertionResultStore {
116 pub fn new() -> Self {
117 AssertionResultStore {
118 store: HashMap::new(),
119 }
120 }
121
122 pub fn store(
123 &mut self,
124 task_id: String,
125 start_time: DateTime<Utc>,
126 end_time: DateTime<Utc>,
127 result: AssertionResult,
128 ) {
129 self.store.insert(task_id, (start_time, end_time, result));
130 }
131
132 pub fn retrieve(&self, task_id: &str) -> Option<AssertionResultType> {
133 self.store.get(task_id).cloned()
134 }
135}
136
137impl Default for AssertionResultStore {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143#[derive(Debug)]
145pub struct LLMResponseStore {
146 store: HashMap<String, serde_json::Value>,
148}
149
150impl LLMResponseStore {
151 pub fn new() -> Self {
152 LLMResponseStore {
153 store: HashMap::new(),
154 }
155 }
156
157 pub fn store(&mut self, task_id: String, response: serde_json::Value) {
158 self.store.insert(task_id, response);
159 }
160
161 pub fn retrieve(&self, task_id: &str) -> Option<serde_json::Value> {
162 self.store.get(task_id).cloned()
163 }
164}
165
166impl Default for LLMResponseStore {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_registry() {
        let mut registry = TaskRegistry::new();
        registry.register("task1".to_string(), TaskType::Assertion, false);
        registry.register("task2".to_string(), TaskType::LLMJudge, false);
        assert_eq!(registry.get_type("task1"), Some(TaskType::Assertion));
        assert_eq!(registry.get_type("task2"), Some(TaskType::LLMJudge));
        assert!(registry.contains("task1"));
        assert!(!registry.contains("task3"));
    }

    #[test]
    fn test_task_registry_conditional_flag() {
        let mut registry = TaskRegistry::new();
        registry.register("cond".to_string(), TaskType::TraceAssertion, true);
        registry.register("plain".to_string(), TaskType::Assertion, false);
        assert!(registry.is_conditional("cond"));
        assert!(!registry.is_conditional("plain"));
        // Unknown ids are reported as non-conditional and have no type.
        assert!(!registry.is_conditional("missing"));
        assert_eq!(registry.get_type("missing"), None);
    }

    #[test]
    fn test_task_registry_skipped() {
        let registry = TaskRegistry::new();
        assert!(!registry.is_skipped("task1"));
        // Skip tracking works through a shared reference.
        registry.mark_skipped("task1".to_string());
        assert!(registry.is_skipped("task1"));
        assert!(!registry.is_skipped("task2"));
    }

    #[test]
    fn test_task_registry_dependencies() {
        let mut registry = TaskRegistry::new();
        registry.register("task1".to_string(), TaskType::Assertion, false);
        registry.register("task2".to_string(), TaskType::LLMJudge, false);
        registry.register_dependencies("task2".to_string(), vec!["task1".to_string()]);

        let deps = registry
            .get_dependencies("task2")
            .expect("task2 was registered with dependencies");
        assert_eq!(*deps, vec!["task1".to_string()]);

        // Registered, but no dependencies recorded.
        assert!(registry.get_dependencies("task1").is_none());

        // Dependencies recorded for a task that was never registered are hidden.
        registry.register_dependencies("ghost".to_string(), vec!["task1".to_string()]);
        assert!(registry.get_dependencies("ghost").is_none());
    }

    #[test]
    fn test_llm_response_store() {
        let mut store = LLMResponseStore::new();
        assert!(store.retrieve("task1").is_none());

        store.store("task1".to_string(), serde_json::json!({ "score": 1 }));
        assert_eq!(
            store.retrieve("task1"),
            Some(serde_json::json!({ "score": 1 }))
        );

        // A second store for the same id replaces the first.
        store.store("task1".to_string(), serde_json::json!("replaced"));
        assert_eq!(store.retrieve("task1"), Some(serde_json::json!("replaced")));
    }
}