1use crate::error::EvaluationError;
2use crate::evaluate::agent::AgentContextBuilder;
3use crate::evaluate::store::{AssertionResultStore, LLMResponseStore, TaskRegistry, TaskType};
4use crate::evaluate::trace::TraceContextBuilder;
5use crate::tasks::agent::execute_agent_assertions;
6use crate::tasks::trace::execute_trace_assertions;
7use crate::tasks::traits::EvaluationTask;
8use chrono::{DateTime, Utc};
9use scouter_types::genai::traits::ProfileExt;
10use scouter_types::genai::{
11 AgentAssertionTask, AssertionResult, EvalSet, ExecutionPlan, GenAIEvalProfile,
12 TraceAssertionTask,
13};
14use scouter_types::sql::TraceSpan;
15use scouter_types::{Assertion, EvalRecord};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::RwLock;
20use tokio::task::JoinSet;
21use tracing::{debug, error, instrument};
22
/// Shared, cheaply-clonable state threaded through a single evaluation run.
///
/// The stores and registry are wrapped in `Arc<RwLock<…>>`, so cloning this
/// context (e.g. into spawned tasks) shares the same underlying data.
#[derive(Debug, Clone)]
struct ExecutionContext {
    // Immutable base input context for the record under evaluation.
    base_context: Arc<Value>,
    // Completed assertion results (incl. trace/agent assertions), keyed by task id.
    assertion_store: Arc<RwLock<AssertionResultStore>>,
    // Raw LLM judge responses, keyed by task id.
    llm_response_store: Arc<RwLock<LLMResponseStore>>,
    // Task metadata: type, dependencies, conditional and skipped flags.
    task_registry: Arc<RwLock<TaskRegistry>>,
    // task id -> stage index, derived from the execution plan.
    task_stages: HashMap<String, i32>,
}
31
32impl ExecutionContext {
33 fn new(base_context: Value, registry: TaskRegistry, execution_plan: &ExecutionPlan) -> Self {
34 debug!("Creating ExecutionContext");
35 Self {
36 base_context: Arc::new(base_context),
37 assertion_store: Arc::new(RwLock::new(AssertionResultStore::new())),
38 llm_response_store: Arc::new(RwLock::new(LLMResponseStore::new())),
39 task_registry: Arc::new(RwLock::new(registry)),
40 task_stages: Self::build_task_stages(execution_plan),
41 }
42 }
43
44 fn build_task_stages(execution_plan: &ExecutionPlan) -> HashMap<String, i32> {
45 execution_plan
46 .nodes
47 .iter()
48 .map(|(id, node)| (id.clone(), node.stage as i32))
49 .collect()
50 }
51
52 async fn build_scoped_context(&self, depends_on: &[String]) -> Value {
53 if depends_on.is_empty() {
54 return self.base_context.as_ref().clone();
55 }
56
57 let mut scoped_context = self.build_context_map(&self.base_context);
58 let registry = self.task_registry.read().await;
59
60 for dep_id in depends_on {
61 match registry.get_type(dep_id) {
62 Some(TaskType::Assertion) => {
63 let store = self.assertion_store.read().await;
64 if let Some(result) = store.retrieve(dep_id) {
65 scoped_context.insert(dep_id.clone(), result.2.actual.clone());
66 }
67 }
68 Some(TaskType::LLMJudge) => {
69 let store = self.llm_response_store.read().await;
70 if let Some(response) = store.retrieve(dep_id) {
71 scoped_context.insert(dep_id.clone(), response.clone());
72 }
73 }
74
75 Some(TaskType::TraceAssertion) => {
76 let store = self.assertion_store.read().await;
78 if let Some(result) = store.retrieve(dep_id) {
79 scoped_context.insert(dep_id.clone(), result.2.actual.clone());
80 }
81 }
82 Some(TaskType::AgentAssertion) => {
83 let store = self.assertion_store.read().await;
84 if let Some(result) = store.retrieve(dep_id) {
85 scoped_context.insert(dep_id.clone(), result.2.actual.clone());
86 }
87 }
88 None => {}
89 }
90 }
91
92 Value::Object(scoped_context)
93 }
94
95 fn build_context_map(&self, value: &Value) -> serde_json::Map<String, Value> {
96 match value {
97 Value::Object(obj) => obj.clone(),
98 _ => {
99 let mut map = serde_json::Map::new();
100 map.insert("context".to_string(), value.clone());
101 map
102 }
103 }
104 }
105
106 async fn store_assertion(
107 &self,
108 task_id: String,
109 start_time: DateTime<Utc>,
110 end_time: DateTime<Utc>,
111 result: AssertionResult,
112 ) {
113 self.assertion_store
114 .write()
115 .await
116 .store(task_id, start_time, end_time, result);
117 }
118
119 async fn store_llm_response(&self, task_id: String, response: Value) {
120 self.llm_response_store
121 .write()
122 .await
123 .store(task_id, response);
124 }
125}
126
/// Decides whether a task's dependencies are satisfied, marking dependents
/// as skipped when a conditional dependency failed or was itself skipped.
struct DependencyChecker {
    context: ExecutionContext,
}
130
131impl DependencyChecker {
132 fn new(context: ExecutionContext) -> Self {
133 Self { context }
134 }
135
136 async fn check_dependencies_satisfied(&self, task_id: &str) -> Option<bool> {
137 debug!("Checking dependencies for task: {}", task_id);
138 let dependencies = {
139 let registry = self.context.task_registry.read().await;
140 match registry.get_dependencies(task_id) {
141 Some(deps) => deps,
142 None => {
143 debug!("Task '{}' has no dependencies, ready to execute", task_id);
145 return Some(true);
146 }
147 }
148 };
149
150 debug!("Task '{}' has dependencies: {:?}", task_id, dependencies);
151
152 let dep_metadata = {
153 let registry = self.context.task_registry.read().await;
154 dependencies
155 .iter()
156 .map(|dep_id| {
157 (
158 dep_id.clone(),
159 registry.is_conditional(dep_id),
160 registry.is_skipped(dep_id),
161 )
162 })
163 .collect::<Vec<_>>()
164 };
165
166 for (dep_id, is_conditional, is_skipped) in dep_metadata {
167 debug!(
168 "Checking dependency '{}' for task '{}': conditional={}, skipped={}",
169 dep_id, task_id, is_conditional, is_skipped
170 );
171 if is_skipped {
172 self.mark_skipped(task_id).await;
173 return Some(false);
174 }
175
176 let completed = self.check_task_completed(&dep_id).await;
177 if !completed {
178 if is_conditional {
179 self.mark_skipped(task_id).await;
180 return Some(false);
181 }
182 return None;
183 }
184
185 if is_conditional && !self.check_assertion_passed(&dep_id).await? {
186 self.mark_skipped(task_id).await;
187 return Some(false);
188 }
189 }
190
191 Some(true)
192 }
193
194 async fn check_task_completed(&self, task_id: &str) -> bool {
195 let registry = self.context.task_registry.read().await;
196 match registry.get_type(task_id) {
197 Some(TaskType::Assertion) => self
198 .context
199 .assertion_store
200 .read()
201 .await
202 .retrieve(task_id)
203 .is_some(),
204 Some(TaskType::LLMJudge) => self
205 .context
206 .llm_response_store
207 .read()
208 .await
209 .retrieve(task_id)
210 .is_some(),
211 Some(TaskType::TraceAssertion) => self
212 .context
213 .assertion_store
214 .read()
215 .await
216 .retrieve(task_id)
217 .is_some(),
218 Some(TaskType::AgentAssertion) => self
219 .context
220 .assertion_store
221 .read()
222 .await
223 .retrieve(task_id)
224 .is_some(),
225 None => false,
226 }
227 }
228
229 async fn check_assertion_passed(&self, task_id: &str) -> Option<bool> {
230 self.context
231 .assertion_store
232 .read()
233 .await
234 .retrieve(task_id)
235 .map(|res| res.2.passed)
236 }
237
238 async fn mark_skipped(&self, task_id: &str) {
239 self.context
240 .task_registry
241 .write()
242 .await
243 .mark_skipped(task_id.to_string());
244 }
245
246 async fn filter_executable_tasks<'a>(&self, task_ids: &'a [String]) -> Vec<&'a str> {
247 debug!("Filtering executable tasks from: {:?}", task_ids);
248 let mut executable = Vec::with_capacity(task_ids.len());
249
250 for task_id in task_ids {
251 if let Some(true) = self.check_dependencies_satisfied(task_id).await {
252 executable.push(task_id.as_str());
253 }
254 }
255
256 executable
257 }
258}
259
/// Executes one stage of the execution plan, dispatching each task kind
/// (plain assertions, LLM judges, trace assertions, agent assertions) to its
/// executor and persisting results into the shared context.
struct TaskExecutor {
    context: ExecutionContext,
    profile: Arc<GenAIEvalProfile>,
    // Builds span-derived context for trace assertions.
    trace_context_builder: TraceContextBuilder,
    // Present only when the profile declares agent assertions; also `None`
    // when building it from the base context failed (the error is logged).
    request_context_builder: Option<AgentContextBuilder>,
}
266
267impl TaskExecutor {
268 fn new(
269 context: ExecutionContext,
270 profile: Arc<GenAIEvalProfile>,
271 spans: Arc<Vec<TraceSpan>>,
272 ) -> Self {
273 debug!("Creating TaskExecutor");
274 let trace_context_builder = TraceContextBuilder::new(spans);
275
276 let request_context_builder = if profile.has_agent_assertions() {
278 AgentContextBuilder::from_context(context.base_context.as_ref(), None)
279 .inspect_err(|e| error!("Failed to build request context: {:?}", e))
280 .ok()
281 } else {
282 None
283 };
284
285 Self {
286 context,
287 profile,
288 trace_context_builder,
289 request_context_builder,
290 }
291 }
292
293 #[instrument(skip_all)]
294 async fn execute_level(&self, task_ids: &[String]) -> Result<(), EvaluationError> {
295 let checker = DependencyChecker::new(self.context.clone());
296 let executable_tasks = checker.filter_executable_tasks(task_ids).await;
297
298 debug!("Executable tasks for level: {:?}", executable_tasks);
299
300 if executable_tasks.is_empty() {
301 return Ok(());
302 }
303
304 let (assertions, judges, traces_assertions, agent_assertions) =
305 self.partition_tasks(executable_tasks).await;
306
307 debug!(
308 "Executing level with {} assertions, {} LLM judges, {} trace assertions, and {} request assertions",
309 assertions.len(),
310 judges.len(),
311 traces_assertions.len(),
312 agent_assertions.len()
313 );
314
315 let _result = tokio::try_join!(
316 self.execute_assertions(&assertions),
317 self.execute_llm_judges(&judges),
318 self.execute_trace_assertions(&traces_assertions),
319 self.execute_agent_assertions(&agent_assertions)
320 )?;
321
322 Ok(())
323 }
324
325 async fn partition_tasks<'a>(
326 &self,
327 task_ids: Vec<&'a str>,
328 ) -> (Vec<&'a str>, Vec<&'a str>, Vec<&'a str>, Vec<&'a str>) {
329 let registry = self.context.task_registry.read().await;
330 let mut assertions = Vec::new();
331 let mut traces_assertions = Vec::new();
332 let mut agent_assertions = Vec::new();
333 let mut judges = Vec::new();
334
335 for id in task_ids {
336 match registry.get_type(id) {
337 Some(TaskType::Assertion) => assertions.push(id),
338 Some(TaskType::LLMJudge) => judges.push(id),
339 Some(TaskType::TraceAssertion) => traces_assertions.push(id),
340 Some(TaskType::AgentAssertion) => agent_assertions.push(id),
341 None => continue,
342 }
343 }
344
345 (assertions, judges, traces_assertions, agent_assertions)
346 }
347
348 async fn execute_assertions(&self, task_ids: &[&str]) -> Result<(), EvaluationError> {
349 debug!("Executing assertion tasks: {:?}", task_ids);
350 if task_ids.is_empty() {
351 return Ok(());
352 }
353
354 let mut join_set = JoinSet::new();
355
356 for &task_id in task_ids {
357 let task_id = task_id.to_string();
358 let context = self.context.clone();
359 let profile = self.profile.clone();
360
361 join_set.spawn(async move {
362 Self::execute_assertion_task(&task_id, &context, &profile).await
363 });
364 }
365
366 while let Some(result) = join_set.join_next().await {
367 result.map_err(|e| {
368 EvaluationError::GenAIEvaluatorError(format!("Task join error: {}", e))
369 })??;
370 }
371
372 Ok(())
373 }
374
375 async fn execute_trace_assertions(&self, task_ids: &[&str]) -> Result<(), EvaluationError> {
376 debug!("Executing trace assertion tasks: {:?}", task_ids);
377 if task_ids.is_empty() {
378 return Ok(());
379 }
380 let tasks: Vec<TraceAssertionTask> = task_ids
381 .iter()
382 .filter_map(|&task_id| self.profile.get_trace_assertion_by_id(task_id))
383 .cloned()
384 .collect();
385
386 debug!("Executing {} trace assertion tasks", tasks.len());
387
388 let start_time = Utc::now();
389 let results =
390 execute_trace_assertions(&self.trace_context_builder, &tasks).inspect_err(|e| {
391 error!("Failed to execute trace assertions: {:?}", e);
392 })?;
393 let end_time = Utc::now();
394
395 for (task_id, result) in results.results {
396 self.context
397 .store_assertion(task_id, start_time, end_time, result)
398 .await;
399 }
400
401 Ok(())
402 }
403
404 async fn execute_agent_assertions(&self, task_ids: &[&str]) -> Result<(), EvaluationError> {
405 debug!("Executing agent assertion tasks: {:?}", task_ids);
406 if task_ids.is_empty() {
407 return Ok(());
408 }
409
410 let tasks: Vec<AgentAssertionTask> = task_ids
411 .iter()
412 .filter_map(|&task_id| self.profile.get_agent_assertion_by_id(task_id))
413 .cloned()
414 .collect();
415
416 debug!("Executing {} agent assertion tasks", tasks.len());
417
418 let start_time = Utc::now();
419 let results = match &self.request_context_builder {
420 Some(ctx) => execute_agent_assertions(ctx, &tasks).inspect_err(|e| {
421 error!("Failed to execute agent assertions: {:?}", e);
422 })?,
423 None => {
424 let results = tasks
426 .iter()
427 .map(|task| {
428 (
429 task.id.clone(),
430 AssertionResult {
431 passed: false,
432 actual: serde_json::Value::Null,
433 expected: serde_json::Value::Null,
434 message: "No request context available for evaluation".to_string(),
435 },
436 )
437 })
438 .collect();
439 scouter_types::genai::AssertionResults { results }
440 }
441 };
442
443 let end_time = Utc::now();
444 for (task_id, result) in results.results {
445 self.context
446 .store_assertion(task_id, start_time, end_time, result)
447 .await;
448 }
449
450 Ok(())
451 }
452
453 async fn execute_llm_judges(&self, task_ids: &[&str]) -> Result<(), EvaluationError> {
454 debug!("Executing LLM judge tasks: {:?}", task_ids);
455 if task_ids.is_empty() {
456 return Ok(());
457 }
458
459 let mut join_set = JoinSet::new();
460
461 for &task_id in task_ids {
462 let task_id = task_id.to_string();
463 let context = self.context.clone();
464 let profile = self.profile.clone();
465
466 join_set.spawn(async move {
467 let result = Self::execute_llm_judge_task(&task_id, &context, &profile).await;
468 result
469 });
470 }
471
472 let mut results = HashMap::with_capacity(task_ids.len());
473 while let Some(result) = join_set.join_next().await {
474 let (judge_id, start_time, response) = result.map_err(|e| {
475 EvaluationError::GenAIEvaluatorError(format!("Task join error: {}", e))
476 })??;
477 results.insert(judge_id, (start_time, response));
478 }
479
480 self.process_llm_judge_results(results).await?;
481 Ok(())
482 }
483
484 #[instrument(skip_all, fields(task_id = %task_id))]
485 async fn execute_assertion_task(
486 task_id: &str,
487 context: &ExecutionContext,
488 profile: &GenAIEvalProfile,
489 ) -> Result<(), EvaluationError> {
490 let start_time = Utc::now();
491
492 let task = profile
493 .get_assertion_by_id(task_id)
494 .ok_or_else(|| EvaluationError::TaskNotFound(task_id.to_string()))?;
495
496 let scoped_context = context.build_scoped_context(&task.depends_on).await;
497 let result = task.execute(&scoped_context)?;
498
499 let end_time = Utc::now();
500 context
501 .store_assertion(task_id.to_string(), start_time, end_time, result)
502 .await;
503 Ok(())
504 }
505
506 #[instrument(skip_all, fields(task_id = %task_id))]
507 async fn execute_llm_judge_task(
508 task_id: &str,
509 context: &ExecutionContext,
510 profile: &GenAIEvalProfile,
511 ) -> Result<(String, DateTime<Utc>, serde_json::Value), EvaluationError> {
512 debug!("Starting LLM judge task: {}", task_id);
513 let start_time = Utc::now();
514 let judge = profile
515 .get_llm_judge_by_id(task_id)
516 .ok_or_else(|| EvaluationError::TaskNotFound(task_id.to_string()))?;
517
518 debug!("Building scoped context for: {}", task_id);
519 let scoped_context = context.build_scoped_context(&judge.depends_on).await;
520
521 let workflow = profile.workflow.as_ref().ok_or_else(|| {
522 EvaluationError::GenAIEvaluatorError("No workflow defined".to_string())
523 })?;
524
525 debug!("Executing workflow task: {}", task_id);
526
527 let response = workflow
529 .execute_task(task_id, &scoped_context)
530 .await
531 .inspect_err(|e| error!("LLM task {} failed: {:?}", task_id, e))?;
532
533 debug!("Successfully completed LLM judge task: {}", task_id);
534 Ok((task_id.to_string(), start_time, response))
535 }
536
537 async fn process_llm_judge_results(
538 &self,
539 results: HashMap<String, (DateTime<Utc>, Value)>,
540 ) -> Result<(), EvaluationError> {
541 for (task_id, (start_time, response)) in results {
542 if let Some(task) = self.profile.get_llm_judge_by_id(&task_id) {
543 let assertion_result = task.execute(&response)?;
544
545 self.context
546 .store_llm_response(task_id.clone(), response)
547 .await;
548
549 self.context
550 .store_assertion(task_id, start_time, Utc::now(), assertion_result)
551 .await;
552 }
553 }
554 Ok(())
555 }
556}
557
/// Aggregates stored task results from an `ExecutionContext` into the final
/// `EvalSet` returned to callers.
struct ResultCollector {
    context: ExecutionContext,
}
561
impl ResultCollector {
    fn new(context: ExecutionContext) -> Self {
        Self { context }
    }

    /// Builds the final `EvalSet`: one `EvalTaskResult` per stored task
    /// result (plain assertions, judges, trace and agent assertions) plus a
    /// workflow-level summary record.
    ///
    /// Conditional (gating) tasks are recorded but excluded from the
    /// pass/fail tally and the pass rate. Tasks with no stored result are
    /// omitted entirely.
    async fn build_eval_set(
        &self,
        record: &EvalRecord,
        profile: &GenAIEvalProfile,
        duration_ms: i64,
        execution_plan: ExecutionPlan,
    ) -> EvalSet {
        let mut passed_count = 0;
        let mut failed_count = 0;
        let mut records = Vec::new();

        let assert_store = self.context.assertion_store.read().await;

        // Plain assertions.
        for assertion in &profile.tasks.assertion {
            if let Some((start_time, end_time, result)) = assert_store.retrieve(&assertion.id) {
                // Gating tasks don't count toward pass/fail totals.
                if !assertion.condition {
                    if result.passed {
                        passed_count += 1;
                    } else {
                        failed_count += 1;
                    }
                }

                // -1 marks a task id absent from the execution plan's stages.
                let stage = *self.context.task_stages.get(&assertion.id).unwrap_or(&-1);

                records.push(scouter_types::EvalTaskResult {
                    created_at: chrono::Utc::now(),
                    start_time,
                    end_time,
                    record_uid: record.uid.clone(),
                    entity_id: record.entity_id,
                    task_id: assertion.id.clone(),
                    task_type: assertion.task_type.clone(),
                    passed: result.passed,
                    value: result.to_metric_value(),
                    assertion: Assertion::FieldPath(assertion.context_path.clone()),
                    expected: result.expected.clone(),
                    actual: result.actual.clone(),
                    message: result.message.clone(),
                    operator: assertion.operator.clone(),
                    entity_uid: String::new(),
                    condition: assertion.condition,
                    stage,
                });
            }
        }

        // LLM judges (their scored results live in the assertion store).
        for judge in &profile.tasks.judge {
            if let Some((start_time, end_time, result)) = assert_store.retrieve(&judge.id) {
                if !judge.condition {
                    if result.passed {
                        passed_count += 1;
                    } else {
                        failed_count += 1;
                    }
                }

                let stage = *self.context.task_stages.get(&judge.id).unwrap_or(&-1);

                records.push(scouter_types::EvalTaskResult {
                    created_at: chrono::Utc::now(),
                    start_time,
                    end_time,
                    record_uid: record.uid.clone(),
                    entity_id: record.entity_id,
                    task_id: judge.id.clone(),
                    task_type: judge.task_type.clone(),
                    passed: result.passed,
                    value: result.to_metric_value(),
                    assertion: Assertion::FieldPath(judge.context_path.clone()),
                    // NOTE(review): judges report the task's configured
                    // `expected_value`, while the other three loops report
                    // `result.expected` — confirm this asymmetry is intended.
                    expected: judge.expected_value.clone(),
                    actual: result.actual.clone(),
                    message: result.message.clone(),
                    operator: judge.operator.clone(),
                    entity_uid: String::new(),
                    condition: judge.condition,
                    stage,
                });
            }
        }

        // Trace assertions.
        for trace_assertion in &profile.tasks.trace {
            if let Some((start_time, end_time, result)) = assert_store.retrieve(&trace_assertion.id)
            {
                if !trace_assertion.condition {
                    if result.passed {
                        passed_count += 1;
                    } else {
                        failed_count += 1;
                    }
                }

                let stage = *self
                    .context
                    .task_stages
                    .get(&trace_assertion.id)
                    .unwrap_or(&-1);

                records.push(scouter_types::EvalTaskResult {
                    created_at: chrono::Utc::now(),
                    start_time,
                    end_time,
                    record_uid: record.uid.clone(),
                    entity_id: record.entity_id,
                    task_id: trace_assertion.id.clone(),
                    task_type: trace_assertion.task_type.clone(),
                    passed: result.passed,
                    value: result.to_metric_value(),
                    assertion: Assertion::TraceAssertion(trace_assertion.assertion.clone()),
                    expected: result.expected.clone(),
                    actual: result.actual.clone(),
                    message: result.message.clone(),
                    operator: trace_assertion.operator.clone(),
                    entity_uid: String::new(),
                    condition: trace_assertion.condition,
                    stage,
                });
            }
        }

        // Agent assertions.
        for agent_assertion in &profile.tasks.agent {
            if let Some((start_time, end_time, result)) = assert_store.retrieve(&agent_assertion.id)
            {
                if !agent_assertion.condition {
                    if result.passed {
                        passed_count += 1;
                    } else {
                        failed_count += 1;
                    }
                }

                let stage = *self
                    .context
                    .task_stages
                    .get(&agent_assertion.id)
                    .unwrap_or(&-1);

                records.push(scouter_types::EvalTaskResult {
                    created_at: chrono::Utc::now(),
                    start_time,
                    end_time,
                    record_uid: record.uid.clone(),
                    entity_id: record.entity_id,
                    task_id: agent_assertion.id.clone(),
                    task_type: agent_assertion.task_type.clone(),
                    passed: result.passed,
                    value: result.to_metric_value(),
                    assertion: Assertion::AgentAssertion(agent_assertion.assertion.clone()),
                    expected: result.expected.clone(),
                    actual: result.actual.clone(),
                    message: result.message.clone(),
                    operator: agent_assertion.operator.clone(),
                    entity_uid: String::new(),
                    condition: agent_assertion.condition,
                    stage,
                });
            }
        }

        // Workflow-level summary; pass rate guards against divide-by-zero
        // when no non-conditional tasks produced results.
        let workflow_record = scouter_types::GenAIEvalWorkflowResult {
            created_at: chrono::Utc::now(),
            id: record.id,
            entity_id: record.entity_id,
            record_uid: record.uid.clone(),
            total_tasks: passed_count + failed_count,
            passed_tasks: passed_count,
            failed_tasks: failed_count,
            pass_rate: if passed_count + failed_count == 0 {
                0.0
            } else {
                passed_count as f64 / (passed_count + failed_count) as f64
            },
            duration_ms,
            entity_uid: String::new(),
            execution_plan,
        };

        EvalSet::new(records, workflow_record)
    }
}
747
748pub struct GenAIEvaluator;
749
750impl GenAIEvaluator {
751 #[instrument(skip_all, fields(record_uid = %record.uid))]
752 pub async fn process_event_record(
753 record: &EvalRecord,
754 profile: Arc<GenAIEvalProfile>,
755 spans: Arc<Vec<TraceSpan>>,
756 ) -> Result<EvalSet, EvaluationError> {
757 let begin = chrono::Utc::now();
758
759 let mut registry = TaskRegistry::new();
760 Self::register_tasks(&mut registry, &profile);
761
762 let execution_plan = profile.get_execution_plan()?;
763
764 let context = ExecutionContext::new(record.context.clone(), registry, &execution_plan);
765 let executor = TaskExecutor::new(context.clone(), profile.clone(), spans);
766
767 debug!(
768 "Starting evaluation for record: {} with {} stages",
769 record.uid,
770 execution_plan.stages.len()
771 );
772
773 for (stage_idx, stage_tasks) in execution_plan.stages.iter().enumerate() {
774 debug!(
775 "Executing stage {} with {} tasks",
776 stage_idx,
777 stage_tasks.len()
778 );
779 executor
780 .execute_level(stage_tasks)
781 .await
782 .inspect_err(|e| error!("Failed to execute stage {}: {:?}", stage_idx, e))?;
783 }
784
785 let end = chrono::Utc::now();
786 let duration_ms = (end - begin).num_milliseconds();
787
788 let collector = ResultCollector::new(context);
789 let eval_set = collector
790 .build_eval_set(record, &profile, duration_ms, execution_plan)
791 .await;
792
793 Ok(eval_set)
794 }
795
796 fn register_tasks(registry: &mut TaskRegistry, profile: &GenAIEvalProfile) {
797 for task in &profile.tasks.assertion {
798 registry.register(task.id.clone(), TaskType::Assertion, task.condition);
799 if !task.depends_on.is_empty() {
800 registry.register_dependencies(task.id.clone(), task.depends_on.clone());
801 }
802 }
803
804 for task in &profile.tasks.judge {
805 registry.register(task.id.clone(), TaskType::LLMJudge, task.condition);
806 if !task.depends_on.is_empty() {
807 registry.register_dependencies(task.id.clone(), task.depends_on.clone());
808 }
809 }
810
811 for task in &profile.tasks.trace {
812 registry.register(task.id.clone(), TaskType::TraceAssertion, task.condition);
813 if !task.depends_on.is_empty() {
814 registry.register_dependencies(task.id.clone(), task.depends_on.clone());
815 }
816 }
817
818 for task in &profile.tasks.agent {
819 registry.register(task.id.clone(), TaskType::AgentAssertion, task.condition);
820 if !task.depends_on.is_empty() {
821 registry.register_dependencies(task.id.clone(), task.depends_on.clone());
822 }
823 }
824 }
825}
826
827#[cfg(test)]
828mod tests {
829
830 use chrono::Utc;
831 use potato_head::mock::{create_score_prompt, LLMTestServer};
832 use scouter_mocks::{
833 create_multi_service_trace, create_nested_trace, create_sequence_pattern_trace,
834 create_simple_trace, create_trace_with_attributes, create_trace_with_errors, init_tracing,
835 };
836 use scouter_types::genai::{
837 AggregationType, SpanFilter, SpanStatus, TraceAssertion, TraceAssertionTask,
838 };
839 use scouter_types::genai::{
840 AssertionTask, ComparisonOperator, GenAIAlertConfig, GenAIEvalConfig, GenAIEvalProfile,
841 LLMJudgeTask,
842 };
843 use scouter_types::genai::{EvaluationTaskType, EvaluationTasks};
844 use scouter_types::EvalRecord;
845 use serde_json::Value;
846 use std::sync::Arc;
847
848 use crate::evaluate::GenAIEvaluator;
849
    /// Test profile mixing task kinds: one standalone assertion, one LLM
    /// judge ("query_relevance"), and two assertions that depend on the
    /// judge's output (two execution stages).
    async fn create_assert_judge_profile() -> GenAIEvalProfile {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        // Stage 1: direct check on the record's input context.
        let assertion_level_1 = AssertionTask {
            id: "input_check".to_string(),
            context_path: Some("input.foo".to_string()),
            operator: ComparisonOperator::Equals,
            expected_value: Value::String("bar".to_string()),
            description: Some("Check if input.foo is bar".to_string()),
            task_type: EvaluationTaskType::Assertion,
            depends_on: vec![],
            result: None,
            condition: false,
            item_context_path: None,
        };

        // Stage 1: LLM judge whose response feeds the dependent assertions.
        let judge_task_level_1 = LLMJudgeTask::new_rs(
            "query_relevance",
            prompt.clone(),
            Value::Number(1.into()),
            Some("score".to_string()),
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );

        // Stage 2: depends on the judge's "score" field.
        let assert_query_score = AssertionTask {
            id: "assert_score".to_string(),
            context_path: Some("query_relevance.score".to_string()),
            operator: ComparisonOperator::IsNumeric,
            expected_value: Value::Bool(true),
            depends_on: vec!["query_relevance".to_string()],
            task_type: EvaluationTaskType::Assertion,
            description: Some("Check that score is numeric".to_string()),
            result: None,
            condition: false,
            item_context_path: None,
        };

        // Stage 2: depends on the judge's "reason" field.
        let assert_query_reason = AssertionTask {
            id: "assert_reason".to_string(),
            context_path: Some("query_relevance.reason".to_string()),
            operator: ComparisonOperator::IsString,
            expected_value: Value::Bool(true),
            depends_on: vec!["query_relevance".to_string()],
            task_type: EvaluationTaskType::Assertion,
            description: Some("Check that reason is alphabetic".to_string()),
            result: None,
            condition: false,
            item_context_path: None,
        };

        let tasks = EvaluationTasks::new()
            .add_task(assertion_level_1)
            .add_task(judge_task_level_1)
            .add_task(assert_query_score)
            .add_task(assert_query_reason)
            .build();

        let alert_config = GenAIAlertConfig::default();

        let drift_config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();

        GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
    }
918
    /// Test profile of three independent (single-stage) assertions against
    /// the record's input context: equality, numeric check, and length check.
    async fn create_assert_profile() -> GenAIEvalProfile {
        let assert1 = AssertionTask {
            id: "input_foo_check".to_string(),
            context_path: Some("input.foo".to_string()),
            operator: ComparisonOperator::Equals,
            expected_value: Value::String("bar".to_string()),
            description: Some("Check if input.foo is bar".to_string()),
            task_type: EvaluationTaskType::Assertion,
            depends_on: vec![],
            result: None,
            condition: false,
            item_context_path: None,
        };
        let assert2 = AssertionTask {
            id: "input_bar_check".to_string(),
            context_path: Some("input.bar".to_string()),
            operator: ComparisonOperator::IsNumeric,
            expected_value: Value::Bool(true),
            depends_on: vec![],
            task_type: EvaluationTaskType::Assertion,
            description: Some("Check that bar is numeric".to_string()),
            result: None,
            condition: false,
            item_context_path: None,
        };

        let assert3 = AssertionTask {
            id: "input_baz_check".to_string(),
            context_path: Some("input.baz".to_string()),
            operator: ComparisonOperator::HasLengthEqual,
            expected_value: Value::Number(3.into()),
            depends_on: vec![],
            task_type: EvaluationTaskType::Assertion,
            description: Some("Check that baz has length 3".to_string()),
            result: None,
            condition: false,
            item_context_path: None,
        };

        let tasks = EvaluationTasks::new()
            .add_task(assert1)
            .add_task(assert2)
            .add_task(assert3)
            .build();

        let alert_config = GenAIAlertConfig::default();

        let drift_config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();

        GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
    }
971
    /// Test profile with a single trace assertion verifying the span
    /// execution order root -> child_1 -> child_2.
    async fn create_trace_profile_simple() -> GenAIEvalProfile {
        let trace_task = TraceAssertionTask {
            id: "check_span_sequence".to_string(),
            assertion: TraceAssertion::SpanSequence {
                span_names: vec![
                    "root".to_string(),
                    "child_1".to_string(),
                    "child_2".to_string(),
                ],
            },
            operator: ComparisonOperator::Equals,
            expected_value: Value::Bool(true),
            description: Some("Verify span execution order".to_string()),
            task_type: EvaluationTaskType::TraceAssertion,
            depends_on: vec![],
            condition: false,
            result: None,
        };

        let tasks = EvaluationTasks::new().add_task(trace_task).build();

        let alert_config = GenAIAlertConfig::default();
        let drift_config =
            GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
                .unwrap();

        GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
    }
1000
    /// Test profile exercising span filters: counts error-status spans and
    /// checks that a span named "recovery" exists.
    async fn create_trace_profile_with_filters() -> GenAIEvalProfile {
        // Expect exactly one span with error status.
        let span_count_task = TraceAssertionTask {
            id: "count_error_spans".to_string(),
            assertion: TraceAssertion::SpanCount {
                filter: SpanFilter::WithStatus {
                    status: SpanStatus::Error,
                },
            },
            operator: ComparisonOperator::Equals,
            expected_value: Value::Number(1.into()),
            description: Some("Count spans with error status".to_string()),
            task_type: EvaluationTaskType::TraceAssertion,
            depends_on: vec![],
            condition: false,
            result: None,
        };

        // Expect a "recovery" span to be present.
        let span_exists_task = TraceAssertionTask {
            id: "check_recovery_span".to_string(),
            assertion: TraceAssertion::SpanExists {
                filter: SpanFilter::ByName {
                    name: "recovery".to_string(),
                },
            },
            operator: ComparisonOperator::Equals,
            expected_value: Value::Bool(true),
            description: Some("Verify recovery span exists".to_string()),
            task_type: EvaluationTaskType::TraceAssertion,
            depends_on: vec![],
            condition: false,
            result: None,
        };

        let tasks = EvaluationTasks::new()
            .add_task(span_count_task)
            .add_task(span_exists_task)
            .build();

        let alert_config = GenAIAlertConfig::default();
        let drift_config =
            GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
                .unwrap();

        GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
    }
1046
    /// Test profile exercising span attributes: reads the "model" attribute
    /// off "api_call" spans and sums their "tokens.output" attribute.
    async fn create_trace_profile_with_attributes() -> GenAIEvalProfile {
        // Expect api_call spans to carry model == "gpt-4".
        let attribute_task = TraceAssertionTask {
            id: "check_model_name".to_string(),
            assertion: TraceAssertion::SpanAttribute {
                filter: SpanFilter::ByName {
                    name: "api_call".to_string(),
                },
                attribute_key: "model".to_string(),
            },
            operator: ComparisonOperator::Equals,
            expected_value: Value::String("gpt-4".to_string()),
            description: Some("Verify model attribute".to_string()),
            task_type: EvaluationTaskType::TraceAssertion,
            depends_on: vec![],
            condition: false,
            result: None,
        };

        // Expect total output tokens across api_call spans to equal 300.
        let aggregation_task = TraceAssertionTask {
            id: "sum_token_output".to_string(),
            assertion: TraceAssertion::SpanAggregation {
                filter: SpanFilter::ByName {
                    name: "api_call".to_string(),
                },
                attribute_key: "tokens.output".to_string(),
                aggregation: AggregationType::Sum,
            },
            operator: ComparisonOperator::Equals,
            expected_value: Value::Number(300.into()),
            description: Some("Sum output tokens".to_string()),
            task_type: EvaluationTaskType::TraceAssertion,
            depends_on: vec![],
            condition: false,
            result: None,
        };

        let tasks = EvaluationTasks::new()
            .add_task(attribute_task)
            .add_task(aggregation_task)
            .build();

        let alert_config = GenAIAlertConfig::default();
        let drift_config =
            GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
                .unwrap();

        GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
    }
1095
1096 async fn create_trace_profile_complex() -> GenAIEvalProfile {
1097 let sequence_count_task = TraceAssertionTask {
1098 id: "count_tool_agent_sequence".to_string(),
1099 assertion: TraceAssertion::SpanCount {
1100 filter: SpanFilter::Sequence {
1101 names: vec!["call_tool".to_string(), "run_agent".to_string()],
1102 },
1103 },
1104 operator: ComparisonOperator::Equals,
1105 expected_value: Value::Number(2.into()),
1106 description: Some("Count tool->agent sequences".to_string()),
1107 task_type: EvaluationTaskType::TraceAssertion,
1108 depends_on: vec![],
1109 condition: false,
1110 result: None,
1111 };
1112
1113 let trace_duration_task = TraceAssertionTask {
1114 id: "check_trace_duration".to_string(),
1115 assertion: TraceAssertion::TraceDuration {},
1116 operator: ComparisonOperator::LessThanOrEqual,
1117 expected_value: Value::Number(1000.into()),
1118 description: Some("Verify trace completes within 1s".to_string()),
1119 task_type: EvaluationTaskType::TraceAssertion,
1120 depends_on: vec![],
1121 condition: false,
1122 result: None,
1123 };
1124
1125 let service_count_task = TraceAssertionTask {
1126 id: "check_service_count".to_string(),
1127 assertion: TraceAssertion::TraceServiceCount {},
1128 operator: ComparisonOperator::Equals,
1129 expected_value: Value::Number(1.into()),
1130 description: Some("Verify single service".to_string()),
1131 task_type: EvaluationTaskType::TraceAssertion,
1132 depends_on: vec![],
1133 condition: false,
1134 result: None,
1135 };
1136
1137 let tasks = EvaluationTasks::new()
1138 .add_task(sequence_count_task)
1139 .add_task(trace_duration_task)
1140 .add_task(service_count_task)
1141 .build();
1142
1143 let alert_config = GenAIAlertConfig::default();
1144 let drift_config =
1145 GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1146 .unwrap();
1147
1148 GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
1149 }
1150
1151 async fn create_trace_profile_with_dependencies() -> GenAIEvalProfile {
1152 let error_check = TraceAssertionTask {
1153 id: "check_has_errors".to_string(),
1154 assertion: TraceAssertion::TraceErrorCount {},
1155 operator: ComparisonOperator::GreaterThan,
1156 expected_value: Value::Number(0.into()),
1157 description: Some("Check if trace has errors".to_string()),
1158 task_type: EvaluationTaskType::TraceAssertion,
1159 depends_on: vec![],
1160 condition: true,
1161 result: None,
1162 };
1163
1164 let recovery_check = TraceAssertionTask {
1165 id: "check_recovery_exists".to_string(),
1166 assertion: TraceAssertion::SpanExists {
1167 filter: SpanFilter::ByName {
1168 name: "recovery".to_string(),
1169 },
1170 },
1171 operator: ComparisonOperator::Equals,
1172 expected_value: Value::Bool(true),
1173 description: Some("Verify recovery span exists when errors present".to_string()),
1174 task_type: EvaluationTaskType::TraceAssertion,
1175 depends_on: vec!["check_has_errors".to_string()],
1176 condition: false,
1177 result: None,
1178 };
1179
1180 let tasks = EvaluationTasks::new()
1181 .add_task(error_check)
1182 .add_task(recovery_check)
1183 .build();
1184
1185 let alert_config = GenAIAlertConfig::default();
1186 let drift_config =
1187 GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1188 .unwrap();
1189
1190 GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
1191 }
1192
1193 #[test]
1194 fn test_evaluator_assert_judge_all_pass() {
1195 let mut mock = LLMTestServer::new();
1196 mock.start_server().unwrap();
1197 let runtime = tokio::runtime::Runtime::new().unwrap();
1198 let profile = runtime.block_on(async { create_assert_judge_profile().await });
1199
1200 assert!(profile.has_llm_tasks());
1201 assert!(profile.has_assertions());
1202
1203 let context = serde_json::json!({
1204 "input": {
1205 "foo": "bar" }
1206 });
1207
1208 let record = EvalRecord::new_rs(
1209 context,
1210 Utc::now(),
1211 "UID123".to_string(),
1212 "ENTITY123".to_string(),
1213 None,
1214 None,
1215 );
1216
1217 let result_set = runtime.block_on(async {
1218 GenAIEvaluator::process_event_record(&record, Arc::new(profile), Arc::new(vec![])).await
1219 });
1220
1221 let eval_set = result_set.unwrap();
1222 assert!(eval_set.passed_tasks() == 4);
1223 assert!(eval_set.failed_tasks() == 0);
1224
1225 mock.stop_server().unwrap();
1226 }
1227
1228 #[test]
1229 fn test_evaluator_assert_one_fail() {
1230 let mut mock = LLMTestServer::new();
1231 mock.start_server().unwrap();
1232 let runtime = tokio::runtime::Runtime::new().unwrap();
1233 let profile = runtime.block_on(async { create_assert_profile().await });
1234
1235 assert!(!profile.has_llm_tasks());
1236 assert!(profile.has_assertions());
1237
1238 let context = serde_json::json!({
1240 "input": {
1241 "foo": "bar",
1242 "bar": "not_a_number",
1243 "baz": [1, 2, 3]}
1244 });
1245
1246 let record = EvalRecord::new_rs(
1247 context,
1248 Utc::now(),
1249 "UID123".to_string(),
1250 "ENTITY123".to_string(),
1251 None,
1252 None,
1253 );
1254
1255 let result_set = runtime.block_on(async {
1256 GenAIEvaluator::process_event_record(&record, Arc::new(profile), Arc::new(vec![])).await
1257 });
1258
1259 let eval_set = result_set.unwrap();
1260 assert!(eval_set.passed_tasks() == 2);
1261 assert!(eval_set.failed_tasks() == 1);
1262
1263 mock.stop_server().unwrap();
1264 }
1265
1266 #[test]
1267 fn test_evaluator_trace_simple_sequence() {
1268 init_tracing();
1269 let runtime = tokio::runtime::Runtime::new().unwrap();
1270 let profile = runtime.block_on(create_trace_profile_simple());
1271 let spans = Arc::new(create_simple_trace());
1272
1273 let context = serde_json::json!({});
1274 let record = EvalRecord::new_rs(
1275 context,
1276 Utc::now(),
1277 "TRACE_UID_001".to_string(),
1278 "ENTITY_001".to_string(),
1279 None,
1280 None,
1281 );
1282
1283 let result = runtime.block_on(GenAIEvaluator::process_event_record(
1284 &record,
1285 Arc::new(profile),
1286 spans,
1287 ));
1288
1289 let eval_set = result.unwrap();
1290 assert_eq!(eval_set.passed_tasks(), 1);
1291 assert_eq!(eval_set.failed_tasks(), 0);
1292 }
1293
1294 #[test]
1295 fn test_evaluator_trace_error_detection() {
1296 let runtime = tokio::runtime::Runtime::new().unwrap();
1297 let profile = runtime.block_on(create_trace_profile_with_filters());
1298 let spans = Arc::new(create_trace_with_errors());
1299
1300 let context = serde_json::json!({});
1301 let record = EvalRecord::new_rs(
1302 context,
1303 Utc::now(),
1304 "TRACE_UID_002".to_string(),
1305 "ENTITY_002".to_string(),
1306 None,
1307 None,
1308 );
1309
1310 let result = runtime.block_on(GenAIEvaluator::process_event_record(
1311 &record,
1312 Arc::new(profile),
1313 spans,
1314 ));
1315
1316 let eval_set = result.unwrap();
1317 assert_eq!(eval_set.passed_tasks(), 2);
1318 assert_eq!(eval_set.failed_tasks(), 0);
1319 }
1320
1321 #[test]
1322 fn test_evaluator_trace_attribute_extraction() {
1323 init_tracing();
1324 let runtime = tokio::runtime::Runtime::new().unwrap();
1325 let profile = runtime.block_on(create_trace_profile_with_attributes());
1326 let spans = Arc::new(create_trace_with_attributes());
1327
1328 let context = serde_json::json!({});
1329 let record = EvalRecord::new_rs(
1330 context,
1331 Utc::now(),
1332 "TRACE_UID_003".to_string(),
1333 "ENTITY_003".to_string(),
1334 None,
1335 None,
1336 );
1337
1338 let result = runtime.block_on(GenAIEvaluator::process_event_record(
1339 &record,
1340 Arc::new(profile),
1341 spans,
1342 ));
1343
1344 let eval_set = result.unwrap();
1345 assert_eq!(eval_set.passed_tasks(), 2);
1346 assert_eq!(eval_set.failed_tasks(), 0);
1347 }
1348
1349 #[test]
1350 fn test_evaluator_trace_sequence_pattern() {
1351 init_tracing();
1352 let runtime = tokio::runtime::Runtime::new().unwrap();
1353 let profile = runtime.block_on(create_trace_profile_complex());
1354 let spans = Arc::new(create_sequence_pattern_trace());
1355
1356 let context = serde_json::json!({});
1357 let record = EvalRecord::new_rs(
1358 context,
1359 Utc::now(),
1360 "TRACE_UID_004".to_string(),
1361 "ENTITY_004".to_string(),
1362 None,
1363 None,
1364 );
1365
1366 let result = runtime.block_on(GenAIEvaluator::process_event_record(
1367 &record,
1368 Arc::new(profile),
1369 spans,
1370 ));
1371
1372 let eval_set = result.unwrap();
1373 assert_eq!(eval_set.passed_tasks(), 3);
1374 assert_eq!(eval_set.failed_tasks(), 0);
1375 }
1376
1377 #[test]
1378 fn test_evaluator_trace_conditional_dependency() {
1379 let runtime = tokio::runtime::Runtime::new().unwrap();
1380 let profile = runtime.block_on(create_trace_profile_with_dependencies());
1381 let spans = Arc::new(create_trace_with_errors());
1382
1383 let context = serde_json::json!({});
1384 let record = EvalRecord::new_rs(
1385 context,
1386 Utc::now(),
1387 "TRACE_UID_005".to_string(),
1388 "ENTITY_005".to_string(),
1389 None,
1390 None,
1391 );
1392
1393 let result = runtime.block_on(GenAIEvaluator::process_event_record(
1394 &record,
1395 Arc::new(profile),
1396 spans,
1397 ));
1398
1399 let eval_set = result.unwrap();
1400 assert_eq!(eval_set.passed_tasks(), 1); assert_eq!(eval_set.failed_tasks(), 0);
1402 }
1403
1404 #[test]
1405 fn test_evaluator_trace_multi_service() {
1406 let runtime = tokio::runtime::Runtime::new().unwrap();
1407
1408 let task = TraceAssertionTask {
1409 id: "check_service_count".to_string(),
1410 assertion: TraceAssertion::TraceServiceCount {},
1411 operator: ComparisonOperator::Equals,
1412 expected_value: Value::Number(3.into()),
1413 description: Some("Verify three services".to_string()),
1414 task_type: EvaluationTaskType::TraceAssertion,
1415 depends_on: vec![],
1416 condition: false,
1417 result: None,
1418 };
1419
1420 let tasks = EvaluationTasks::new().add_task(task).build();
1421 let alert_config = GenAIAlertConfig::default();
1422 let drift_config =
1423 GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1424 .unwrap();
1425
1426 let profile = runtime
1427 .block_on(GenAIEvalProfile::new(drift_config, tasks))
1428 .unwrap();
1429 let spans = Arc::new(create_multi_service_trace());
1430
1431 let context = serde_json::json!({});
1432 let record = EvalRecord::new_rs(
1433 context,
1434 Utc::now(),
1435 "TRACE_UID_006".to_string(),
1436 "ENTITY_006".to_string(),
1437 None,
1438 None,
1439 );
1440
1441 let result = runtime.block_on(GenAIEvaluator::process_event_record(
1442 &record,
1443 Arc::new(profile),
1444 spans,
1445 ));
1446
1447 let eval_set = result.unwrap();
1448 assert_eq!(eval_set.passed_tasks(), 1);
1449 assert_eq!(eval_set.failed_tasks(), 0);
1450 }
1451
1452 #[test]
1453 fn test_evaluator_trace_assertion_failure() {
1454 let runtime = tokio::runtime::Runtime::new().unwrap();
1455
1456 let task = TraceAssertionTask {
1457 id: "check_wrong_sequence".to_string(),
1458 assertion: TraceAssertion::SpanSequence {
1459 span_names: vec![
1460 "root".to_string(),
1461 "wrong_child".to_string(),
1462 "child_2".to_string(),
1463 ],
1464 },
1465 operator: ComparisonOperator::Equals,
1466 expected_value: Value::Bool(true),
1467 description: Some("Verify incorrect span order".to_string()),
1468 task_type: EvaluationTaskType::TraceAssertion,
1469 depends_on: vec![],
1470 condition: false,
1471 result: None,
1472 };
1473
1474 let tasks = EvaluationTasks::new().add_task(task).build();
1475 let alert_config = GenAIAlertConfig::default();
1476 let drift_config =
1477 GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1478 .unwrap();
1479
1480 let profile = runtime
1481 .block_on(GenAIEvalProfile::new(drift_config, tasks))
1482 .unwrap();
1483 let spans = Arc::new(create_simple_trace());
1484
1485 let context = serde_json::json!({});
1486 let record = EvalRecord::new_rs(
1487 context,
1488 Utc::now(),
1489 "TRACE_UID_007".to_string(),
1490 "ENTITY_007".to_string(),
1491 None,
1492 None,
1493 );
1494
1495 let result = runtime.block_on(GenAIEvaluator::process_event_record(
1496 &record,
1497 Arc::new(profile),
1498 spans,
1499 ));
1500
1501 let eval_set = result.unwrap();
1502 assert_eq!(eval_set.passed_tasks(), 0);
1503 assert_eq!(eval_set.failed_tasks(), 1);
1504 }
1505
1506 #[test]
1507 fn test_evaluator_trace_mixed_assertions() {
1508 init_tracing();
1509 let runtime = tokio::runtime::Runtime::new().unwrap();
1510
1511 let trace_task = TraceAssertionTask {
1512 id: "check_max_depth".to_string(),
1513 assertion: TraceAssertion::TraceMaxDepth {},
1514 operator: ComparisonOperator::Equals,
1515 expected_value: Value::Number(2.into()),
1516 description: Some("Verify max depth".to_string()),
1517 task_type: EvaluationTaskType::TraceAssertion,
1518 depends_on: vec![],
1519 condition: false,
1520 result: None,
1521 };
1522
1523 let regular_assertion = AssertionTask {
1524 id: "check_context".to_string(),
1525 context_path: Some("metadata.version".to_string()),
1526 operator: ComparisonOperator::Equals,
1527 expected_value: Value::String("1.0.0".to_string()),
1528 description: Some("Verify version".to_string()),
1529 task_type: EvaluationTaskType::Assertion,
1530 depends_on: vec![],
1531 result: None,
1532 condition: false,
1533 item_context_path: None,
1534 };
1535
1536 let tasks = EvaluationTasks::new()
1537 .add_task(trace_task)
1538 .add_task(regular_assertion)
1539 .build();
1540
1541 let alert_config = GenAIAlertConfig::default();
1542 let drift_config =
1543 GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1544 .unwrap();
1545
1546 let profile = runtime
1547 .block_on(GenAIEvalProfile::new(drift_config, tasks))
1548 .unwrap();
1549 let spans = Arc::new(create_nested_trace());
1550
1551 let context = serde_json::json!({
1552 "metadata": {
1553 "version": "1.0.0"
1554 }
1555 });
1556
1557 let record = EvalRecord::new_rs(
1558 context,
1559 Utc::now(),
1560 "TRACE_UID_008".to_string(),
1561 "ENTITY_008".to_string(),
1562 None,
1563 None,
1564 );
1565
1566 let result = runtime.block_on(GenAIEvaluator::process_event_record(
1567 &record,
1568 Arc::new(profile),
1569 spans,
1570 ));
1571
1572 let eval_set = result.unwrap();
1573 assert_eq!(eval_set.passed_tasks(), 2);
1574 assert_eq!(eval_set.failed_tasks(), 0);
1575 }
1576
1577 #[test]
1578 fn test_evaluator_trace_duration_filter() {
1579 init_tracing();
1580 let runtime = tokio::runtime::Runtime::new().unwrap();
1581
1582 let task = TraceAssertionTask {
1583 id: "check_slow_spans".to_string(),
1584 assertion: TraceAssertion::SpanCount {
1585 filter: SpanFilter::WithDuration {
1586 min_ms: Some(100.0),
1587 max_ms: None,
1588 },
1589 },
1590 operator: ComparisonOperator::GreaterThanOrEqual,
1591 expected_value: Value::Number(2.into()),
1592 description: Some("Count spans over 100ms".to_string()),
1593 task_type: EvaluationTaskType::TraceAssertion,
1594 depends_on: vec![],
1595 condition: false,
1596 result: None,
1597 };
1598
1599 let tasks = EvaluationTasks::new().add_task(task).build();
1600 let alert_config = GenAIAlertConfig::default();
1601 let drift_config =
1602 GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1603 .unwrap();
1604
1605 let profile = runtime
1606 .block_on(GenAIEvalProfile::new(drift_config, tasks))
1607 .unwrap();
1608 let spans = Arc::new(create_nested_trace());
1609
1610 let context = serde_json::json!({});
1611 let record = EvalRecord::new_rs(
1612 context,
1613 Utc::now(),
1614 "TRACE_UID_009".to_string(),
1615 "ENTITY_009".to_string(),
1616 None,
1617 None,
1618 );
1619
1620 let result = runtime.block_on(GenAIEvaluator::process_event_record(
1621 &record,
1622 Arc::new(profile),
1623 spans,
1624 ));
1625
1626 let eval_set = result.unwrap();
1627 assert_eq!(eval_set.passed_tasks(), 1);
1628 assert_eq!(eval_set.failed_tasks(), 0);
1629 }
1630}