Skip to main content

scouter_evaluate/evaluate/
evaluator.rs

1use crate::error::EvaluationError;
2use crate::evaluate::agent::AgentContextBuilder;
3use crate::evaluate::store::{AssertionResultStore, LLMResponseStore, TaskRegistry, TaskType};
4use crate::evaluate::trace::TraceContextBuilder;
5use crate::tasks::agent::execute_agent_assertions;
6use crate::tasks::trace::execute_trace_assertions;
7use crate::tasks::traits::EvaluationTask;
8use chrono::{DateTime, Utc};
9use scouter_types::genai::traits::ProfileExt;
10use scouter_types::genai::{
11    AgentAssertionTask, AssertionResult, EvalSet, ExecutionPlan, GenAIEvalProfile,
12    TraceAssertionTask,
13};
14use scouter_types::sql::TraceSpan;
15use scouter_types::{Assertion, EvalRecord};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::RwLock;
20use tokio::task::JoinSet;
21use tracing::{debug, error, instrument};
22
/// Shared, cloneable state threaded through every evaluation task for one record.
///
/// All heavy members sit behind `Arc`, so cloning this context is cheap:
/// reference-count bumps plus a copy of the small `task_stages` map.
#[derive(Debug, Clone)]
struct ExecutionContext {
    // Immutable evaluation context taken from the incoming record.
    base_context: Arc<Value>,
    // Results of completed assertion-style tasks (plain, trace, and agent
    // assertions all share this store — see `build_scoped_context`).
    assertion_store: Arc<RwLock<AssertionResultStore>>,
    // Raw LLM responses keyed by judge task id.
    llm_response_store: Arc<RwLock<LLMResponseStore>>,
    // Task metadata: type, dependencies, conditional/skipped flags.
    task_registry: Arc<RwLock<TaskRegistry>>,
    // Stage index per task id, derived from the execution plan.
    task_stages: HashMap<String, i32>,
}
31
32impl ExecutionContext {
33    fn new(base_context: Value, registry: TaskRegistry, execution_plan: &ExecutionPlan) -> Self {
34        debug!("Creating ExecutionContext");
35        Self {
36            base_context: Arc::new(base_context),
37            assertion_store: Arc::new(RwLock::new(AssertionResultStore::new())),
38            llm_response_store: Arc::new(RwLock::new(LLMResponseStore::new())),
39            task_registry: Arc::new(RwLock::new(registry)),
40            task_stages: Self::build_task_stages(execution_plan),
41        }
42    }
43
44    fn build_task_stages(execution_plan: &ExecutionPlan) -> HashMap<String, i32> {
45        execution_plan
46            .nodes
47            .iter()
48            .map(|(id, node)| (id.clone(), node.stage as i32))
49            .collect()
50    }
51
52    async fn build_scoped_context(&self, depends_on: &[String]) -> Value {
53        if depends_on.is_empty() {
54            return self.base_context.as_ref().clone();
55        }
56
57        let mut scoped_context = self.build_context_map(&self.base_context);
58        let registry = self.task_registry.read().await;
59
60        for dep_id in depends_on {
61            match registry.get_type(dep_id) {
62                Some(TaskType::Assertion) => {
63                    let store = self.assertion_store.read().await;
64                    if let Some(result) = store.retrieve(dep_id) {
65                        scoped_context.insert(dep_id.clone(), result.2.actual.clone());
66                    }
67                }
68                Some(TaskType::LLMJudge) => {
69                    let store = self.llm_response_store.read().await;
70                    if let Some(response) = store.retrieve(dep_id) {
71                        scoped_context.insert(dep_id.clone(), response.clone());
72                    }
73                }
74
75                Some(TaskType::TraceAssertion) => {
76                    // Trace assertions store their results in the assertion store
77                    let store = self.assertion_store.read().await;
78                    if let Some(result) = store.retrieve(dep_id) {
79                        scoped_context.insert(dep_id.clone(), result.2.actual.clone());
80                    }
81                }
82                Some(TaskType::AgentAssertion) => {
83                    let store = self.assertion_store.read().await;
84                    if let Some(result) = store.retrieve(dep_id) {
85                        scoped_context.insert(dep_id.clone(), result.2.actual.clone());
86                    }
87                }
88                None => {}
89            }
90        }
91
92        Value::Object(scoped_context)
93    }
94
95    fn build_context_map(&self, value: &Value) -> serde_json::Map<String, Value> {
96        match value {
97            Value::Object(obj) => obj.clone(),
98            _ => {
99                let mut map = serde_json::Map::new();
100                map.insert("context".to_string(), value.clone());
101                map
102            }
103        }
104    }
105
106    async fn store_assertion(
107        &self,
108        task_id: String,
109        start_time: DateTime<Utc>,
110        end_time: DateTime<Utc>,
111        result: AssertionResult,
112    ) {
113        self.assertion_store
114            .write()
115            .await
116            .store(task_id, start_time, end_time, result);
117    }
118
119    async fn store_llm_response(&self, task_id: String, response: Value) {
120        self.llm_response_store
121            .write()
122            .await
123            .store(task_id, response);
124    }
125}
126
/// Decides whether a task's dependencies allow it to run, and marks tasks
/// as skipped when a conditional dependency failed or was itself skipped.
struct DependencyChecker {
    // Cloned execution context (cheap: Arc-backed) for store/registry access.
    context: ExecutionContext,
}
130
impl DependencyChecker {
    /// Wraps an execution context for dependency inspection.
    fn new(context: ExecutionContext) -> Self {
        Self { context }
    }

    /// Checks whether `task_id` may execute.
    ///
    /// Returns `Some(true)` when all dependencies are satisfied,
    /// `Some(false)` when the task has been marked skipped (a dependency was
    /// skipped, or a conditional dependency did not complete or did not
    /// pass), and `None` when a required non-conditional dependency simply
    /// has not completed yet.
    async fn check_dependencies_satisfied(&self, task_id: &str) -> Option<bool> {
        debug!("Checking dependencies for task: {}", task_id);
        // Scope the registry read lock tightly so it is released before any
        // store lookups below.
        let dependencies = {
            let registry = self.context.task_registry.read().await;
            match registry.get_dependencies(task_id) {
                Some(deps) => deps,
                None => {
                    // Task exists but has no dependencies - ready to execute
                    debug!("Task '{}' has no dependencies, ready to execute", task_id);
                    return Some(true);
                }
            }
        };

        debug!("Task '{}' has dependencies: {:?}", task_id, dependencies);

        // Snapshot per-dependency flags under one short read lock.
        let dep_metadata = {
            let registry = self.context.task_registry.read().await;
            dependencies
                .iter()
                .map(|dep_id| {
                    (
                        dep_id.clone(),
                        registry.is_conditional(dep_id),
                        registry.is_skipped(dep_id),
                    )
                })
                .collect::<Vec<_>>()
        };

        for (dep_id, is_conditional, is_skipped) in dep_metadata {
            debug!(
                "Checking dependency '{}' for task '{}': conditional={}, skipped={}",
                dep_id, task_id, is_conditional, is_skipped
            );
            // A skipped dependency propagates: this task is skipped too.
            if is_skipped {
                self.mark_skipped(task_id).await;
                return Some(false);
            }

            let completed = self.check_task_completed(&dep_id).await;
            if !completed {
                // A conditional gate that never ran means this task will
                // never fire; a plain dependency just isn't ready yet.
                if is_conditional {
                    self.mark_skipped(task_id).await;
                    return Some(false);
                }
                return None;
            }

            // A conditional gate that ran but failed skips this task.
            // (`?` propagates None if the gate result is missing.)
            if is_conditional && !self.check_assertion_passed(&dep_id).await? {
                self.mark_skipped(task_id).await;
                return Some(false);
            }
        }

        Some(true)
    }

    /// Whether `task_id` has a stored result in the store matching its type.
    /// Unknown tasks count as not completed.
    async fn check_task_completed(&self, task_id: &str) -> bool {
        let registry = self.context.task_registry.read().await;
        match registry.get_type(task_id) {
            Some(TaskType::Assertion) => self
                .context
                .assertion_store
                .read()
                .await
                .retrieve(task_id)
                .is_some(),
            Some(TaskType::LLMJudge) => self
                .context
                .llm_response_store
                .read()
                .await
                .retrieve(task_id)
                .is_some(),
            Some(TaskType::TraceAssertion) => self
                .context
                .assertion_store
                .read()
                .await
                .retrieve(task_id)
                .is_some(),
            Some(TaskType::AgentAssertion) => self
                .context
                .assertion_store
                .read()
                .await
                .retrieve(task_id)
                .is_some(),
            None => false,
        }
    }

    /// The `passed` flag of a stored assertion result, or `None` if absent.
    async fn check_assertion_passed(&self, task_id: &str) -> Option<bool> {
        self.context
            .assertion_store
            .read()
            .await
            .retrieve(task_id)
            .map(|res| res.2.passed)
    }

    /// Flags `task_id` as skipped in the registry so downstream tasks see it.
    async fn mark_skipped(&self, task_id: &str) {
        self.context
            .task_registry
            .write()
            .await
            .mark_skipped(task_id.to_string());
    }

    /// Returns the subset of `task_ids` whose dependencies are satisfied now.
    /// Tasks returning `None` (not ready) or `Some(false)` (skipped) are
    /// filtered out.
    async fn filter_executable_tasks<'a>(&self, task_ids: &'a [String]) -> Vec<&'a str> {
        debug!("Filtering executable tasks from: {:?}", task_ids);
        let mut executable = Vec::with_capacity(task_ids.len());

        for task_id in task_ids {
            if let Some(true) = self.check_dependencies_satisfied(task_id).await {
                executable.push(task_id.as_str());
            }
        }

        executable
    }
}
259
/// Executes one stage of evaluation tasks against a profile, dispatching
/// each task family (assertion, judge, trace, agent) concurrently.
struct TaskExecutor {
    // Shared per-record state (stores, registry, base context).
    context: ExecutionContext,
    // The evaluation profile defining all tasks and the LLM workflow.
    profile: Arc<GenAIEvalProfile>,
    // Builds the span-derived context for trace assertions.
    trace_context_builder: TraceContextBuilder,
    // Present only when the profile has agent assertions AND the request
    // context could be built from the record; `None` fails those tasks.
    request_context_builder: Option<AgentContextBuilder>,
}
266
267impl TaskExecutor {
268    fn new(
269        context: ExecutionContext,
270        profile: Arc<GenAIEvalProfile>,
271        spans: Arc<Vec<TraceSpan>>,
272    ) -> Self {
273        debug!("Creating TaskExecutor");
274        let trace_context_builder = TraceContextBuilder::new(spans);
275
276        // Build request context builder from the eval record context if there are request assertions
277        let request_context_builder = if profile.has_agent_assertions() {
278            AgentContextBuilder::from_context(context.base_context.as_ref(), None)
279                .inspect_err(|e| error!("Failed to build request context: {:?}", e))
280                .ok()
281        } else {
282            None
283        };
284
285        Self {
286            context,
287            profile,
288            trace_context_builder,
289            request_context_builder,
290        }
291    }
292
293    #[instrument(skip_all)]
294    async fn execute_level(&self, task_ids: &[String]) -> Result<(), EvaluationError> {
295        let checker = DependencyChecker::new(self.context.clone());
296        let executable_tasks = checker.filter_executable_tasks(task_ids).await;
297
298        debug!("Executable tasks for level: {:?}", executable_tasks);
299
300        if executable_tasks.is_empty() {
301            return Ok(());
302        }
303
304        let (assertions, judges, traces_assertions, agent_assertions) =
305            self.partition_tasks(executable_tasks).await;
306
307        debug!(
308            "Executing level with {} assertions, {} LLM judges, {} trace assertions, and {} request assertions",
309            assertions.len(),
310            judges.len(),
311            traces_assertions.len(),
312            agent_assertions.len()
313        );
314
315        let _result = tokio::try_join!(
316            self.execute_assertions(&assertions),
317            self.execute_llm_judges(&judges),
318            self.execute_trace_assertions(&traces_assertions),
319            self.execute_agent_assertions(&agent_assertions)
320        )?;
321
322        Ok(())
323    }
324
325    async fn partition_tasks<'a>(
326        &self,
327        task_ids: Vec<&'a str>,
328    ) -> (Vec<&'a str>, Vec<&'a str>, Vec<&'a str>, Vec<&'a str>) {
329        let registry = self.context.task_registry.read().await;
330        let mut assertions = Vec::new();
331        let mut traces_assertions = Vec::new();
332        let mut agent_assertions = Vec::new();
333        let mut judges = Vec::new();
334
335        for id in task_ids {
336            match registry.get_type(id) {
337                Some(TaskType::Assertion) => assertions.push(id),
338                Some(TaskType::LLMJudge) => judges.push(id),
339                Some(TaskType::TraceAssertion) => traces_assertions.push(id),
340                Some(TaskType::AgentAssertion) => agent_assertions.push(id),
341                None => continue,
342            }
343        }
344
345        (assertions, judges, traces_assertions, agent_assertions)
346    }
347
348    async fn execute_assertions(&self, task_ids: &[&str]) -> Result<(), EvaluationError> {
349        debug!("Executing assertion tasks: {:?}", task_ids);
350        if task_ids.is_empty() {
351            return Ok(());
352        }
353
354        let mut join_set = JoinSet::new();
355
356        for &task_id in task_ids {
357            let task_id = task_id.to_string();
358            let context = self.context.clone();
359            let profile = self.profile.clone();
360
361            join_set.spawn(async move {
362                Self::execute_assertion_task(&task_id, &context, &profile).await
363            });
364        }
365
366        while let Some(result) = join_set.join_next().await {
367            result.map_err(|e| {
368                EvaluationError::GenAIEvaluatorError(format!("Task join error: {}", e))
369            })??;
370        }
371
372        Ok(())
373    }
374
375    async fn execute_trace_assertions(&self, task_ids: &[&str]) -> Result<(), EvaluationError> {
376        debug!("Executing trace assertion tasks: {:?}", task_ids);
377        if task_ids.is_empty() {
378            return Ok(());
379        }
380        let tasks: Vec<TraceAssertionTask> = task_ids
381            .iter()
382            .filter_map(|&task_id| self.profile.get_trace_assertion_by_id(task_id))
383            .cloned()
384            .collect();
385
386        debug!("Executing {} trace assertion tasks", tasks.len());
387
388        let start_time = Utc::now();
389        let results =
390            execute_trace_assertions(&self.trace_context_builder, &tasks).inspect_err(|e| {
391                error!("Failed to execute trace assertions: {:?}", e);
392            })?;
393        let end_time = Utc::now();
394
395        for (task_id, result) in results.results {
396            self.context
397                .store_assertion(task_id, start_time, end_time, result)
398                .await;
399        }
400
401        Ok(())
402    }
403
404    async fn execute_agent_assertions(&self, task_ids: &[&str]) -> Result<(), EvaluationError> {
405        debug!("Executing agent assertion tasks: {:?}", task_ids);
406        if task_ids.is_empty() {
407            return Ok(());
408        }
409
410        let tasks: Vec<AgentAssertionTask> = task_ids
411            .iter()
412            .filter_map(|&task_id| self.profile.get_agent_assertion_by_id(task_id))
413            .cloned()
414            .collect();
415
416        debug!("Executing {} agent assertion tasks", tasks.len());
417
418        let start_time = Utc::now();
419        let results = match &self.request_context_builder {
420            Some(ctx) => execute_agent_assertions(ctx, &tasks).inspect_err(|e| {
421                error!("Failed to execute agent assertions: {:?}", e);
422            })?,
423            None => {
424                // No request context available - fail all tasks
425                let results = tasks
426                    .iter()
427                    .map(|task| {
428                        (
429                            task.id.clone(),
430                            AssertionResult {
431                                passed: false,
432                                actual: serde_json::Value::Null,
433                                expected: serde_json::Value::Null,
434                                message: "No request context available for evaluation".to_string(),
435                            },
436                        )
437                    })
438                    .collect();
439                scouter_types::genai::AssertionResults { results }
440            }
441        };
442
443        let end_time = Utc::now();
444        for (task_id, result) in results.results {
445            self.context
446                .store_assertion(task_id, start_time, end_time, result)
447                .await;
448        }
449
450        Ok(())
451    }
452
453    async fn execute_llm_judges(&self, task_ids: &[&str]) -> Result<(), EvaluationError> {
454        debug!("Executing LLM judge tasks: {:?}", task_ids);
455        if task_ids.is_empty() {
456            return Ok(());
457        }
458
459        let mut join_set = JoinSet::new();
460
461        for &task_id in task_ids {
462            let task_id = task_id.to_string();
463            let context = self.context.clone();
464            let profile = self.profile.clone();
465
466            join_set.spawn(async move {
467                let result = Self::execute_llm_judge_task(&task_id, &context, &profile).await;
468                result
469            });
470        }
471
472        let mut results = HashMap::with_capacity(task_ids.len());
473        while let Some(result) = join_set.join_next().await {
474            let (judge_id, start_time, response) = result.map_err(|e| {
475                EvaluationError::GenAIEvaluatorError(format!("Task join error: {}", e))
476            })??;
477            results.insert(judge_id, (start_time, response));
478        }
479
480        self.process_llm_judge_results(results).await?;
481        Ok(())
482    }
483
484    #[instrument(skip_all, fields(task_id = %task_id))]
485    async fn execute_assertion_task(
486        task_id: &str,
487        context: &ExecutionContext,
488        profile: &GenAIEvalProfile,
489    ) -> Result<(), EvaluationError> {
490        let start_time = Utc::now();
491
492        let task = profile
493            .get_assertion_by_id(task_id)
494            .ok_or_else(|| EvaluationError::TaskNotFound(task_id.to_string()))?;
495
496        let scoped_context = context.build_scoped_context(&task.depends_on).await;
497        let result = task.execute(&scoped_context)?;
498
499        let end_time = Utc::now();
500        context
501            .store_assertion(task_id.to_string(), start_time, end_time, result)
502            .await;
503        Ok(())
504    }
505
506    #[instrument(skip_all, fields(task_id = %task_id))]
507    async fn execute_llm_judge_task(
508        task_id: &str,
509        context: &ExecutionContext,
510        profile: &GenAIEvalProfile,
511    ) -> Result<(String, DateTime<Utc>, serde_json::Value), EvaluationError> {
512        debug!("Starting LLM judge task: {}", task_id);
513        let start_time = Utc::now();
514        let judge = profile
515            .get_llm_judge_by_id(task_id)
516            .ok_or_else(|| EvaluationError::TaskNotFound(task_id.to_string()))?;
517
518        debug!("Building scoped context for: {}", task_id);
519        let scoped_context = context.build_scoped_context(&judge.depends_on).await;
520
521        let workflow = profile.workflow.as_ref().ok_or_else(|| {
522            EvaluationError::GenAIEvaluatorError("No workflow defined".to_string())
523        })?;
524
525        debug!("Executing workflow task: {}", task_id);
526
527        // This is where the actual LLM call happens - ensure it's awaited
528        let response = workflow
529            .execute_task(task_id, &scoped_context)
530            .await
531            .inspect_err(|e| error!("LLM task {} failed: {:?}", task_id, e))?;
532
533        debug!("Successfully completed LLM judge task: {}", task_id);
534        Ok((task_id.to_string(), start_time, response))
535    }
536
537    async fn process_llm_judge_results(
538        &self,
539        results: HashMap<String, (DateTime<Utc>, Value)>,
540    ) -> Result<(), EvaluationError> {
541        for (task_id, (start_time, response)) in results {
542            if let Some(task) = self.profile.get_llm_judge_by_id(&task_id) {
543                let assertion_result = task.execute(&response)?;
544
545                self.context
546                    .store_llm_response(task_id.clone(), response)
547                    .await;
548
549                self.context
550                    .store_assertion(task_id, start_time, Utc::now(), assertion_result)
551                    .await;
552            }
553        }
554        Ok(())
555    }
556}
557
/// Aggregates the stored task results of a finished run into an `EvalSet`.
struct ResultCollector {
    // The execution context holding the populated result stores.
    context: ExecutionContext,
}
561
impl ResultCollector {
    /// Wraps the finished execution context for result extraction.
    fn new(context: ExecutionContext) -> Self {
        Self { context }
    }

    /// Collects every stored task result into an `EvalSet` with a workflow
    /// summary (pass/fail tally, pass rate, duration).
    ///
    /// Conditional tasks (`condition == true`) still produce a record but
    /// are excluded from the pass/fail counts. Tasks with no stored result
    /// (skipped or never run) are omitted entirely. A stage of -1 marks a
    /// task absent from the execution plan.
    async fn build_eval_set(
        &self,
        record: &EvalRecord,
        profile: &GenAIEvalProfile,
        duration_ms: i64,
        execution_plan: ExecutionPlan,
    ) -> EvalSet {
        let mut passed_count = 0;
        let mut failed_count = 0;
        let mut records = Vec::new();

        let assert_store = self.context.assertion_store.read().await;

        // Plain assertions.
        for assertion in &profile.tasks.assertion {
            if let Some((start_time, end_time, result)) = assert_store.retrieve(&assertion.id) {
                // Conditional gates are recorded but don't affect the tally.
                if !assertion.condition {
                    if result.passed {
                        passed_count += 1;
                    } else {
                        failed_count += 1;
                    }
                }

                let stage = *self.context.task_stages.get(&assertion.id).unwrap_or(&-1);

                records.push(scouter_types::EvalTaskResult {
                    created_at: chrono::Utc::now(),
                    start_time,
                    end_time,
                    record_uid: record.uid.clone(),
                    entity_id: record.entity_id,
                    task_id: assertion.id.clone(),
                    task_type: assertion.task_type.clone(),
                    passed: result.passed,
                    value: result.to_metric_value(),
                    assertion: Assertion::FieldPath(assertion.context_path.clone()),
                    expected: result.expected.clone(),
                    actual: result.actual.clone(),
                    message: result.message.clone(),
                    operator: assertion.operator.clone(),
                    entity_uid: String::new(),
                    condition: assertion.condition,
                    stage,
                });
            }
        }

        // LLM judges (their evaluated assertion results live in the same store).
        for judge in &profile.tasks.judge {
            if let Some((start_time, end_time, result)) = assert_store.retrieve(&judge.id) {
                if !judge.condition {
                    if result.passed {
                        passed_count += 1;
                    } else {
                        failed_count += 1;
                    }
                }

                let stage = *self.context.task_stages.get(&judge.id).unwrap_or(&-1);

                records.push(scouter_types::EvalTaskResult {
                    created_at: chrono::Utc::now(),
                    start_time,
                    end_time,
                    record_uid: record.uid.clone(),
                    entity_id: record.entity_id,
                    task_id: judge.id.clone(),
                    task_type: judge.task_type.clone(),
                    passed: result.passed,
                    value: result.to_metric_value(),
                    assertion: Assertion::FieldPath(judge.context_path.clone()),
                    // NOTE(review): judges report the configured
                    // `expected_value`, unlike the other loops which use
                    // `result.expected` — confirm this asymmetry is intended.
                    expected: judge.expected_value.clone(),
                    actual: result.actual.clone(),
                    message: result.message.clone(),
                    operator: judge.operator.clone(),
                    entity_uid: String::new(),
                    condition: judge.condition,
                    stage,
                });
            }
        }

        // Trace assertions.
        for trace_assertion in &profile.tasks.trace {
            if let Some((start_time, end_time, result)) = assert_store.retrieve(&trace_assertion.id)
            {
                if !trace_assertion.condition {
                    if result.passed {
                        passed_count += 1;
                    } else {
                        failed_count += 1;
                    }
                }

                let stage = *self
                    .context
                    .task_stages
                    .get(&trace_assertion.id)
                    .unwrap_or(&-1);

                records.push(scouter_types::EvalTaskResult {
                    created_at: chrono::Utc::now(),
                    start_time,
                    end_time,
                    record_uid: record.uid.clone(),
                    entity_id: record.entity_id,
                    task_id: trace_assertion.id.clone(),
                    task_type: trace_assertion.task_type.clone(),
                    passed: result.passed,
                    value: result.to_metric_value(),
                    assertion: Assertion::TraceAssertion(trace_assertion.assertion.clone()),
                    expected: result.expected.clone(),
                    actual: result.actual.clone(),
                    message: result.message.clone(),
                    operator: trace_assertion.operator.clone(),
                    entity_uid: String::new(),
                    condition: trace_assertion.condition,
                    stage,
                });
            }
        }

        // Agent assertions.
        for agent_assertion in &profile.tasks.agent {
            if let Some((start_time, end_time, result)) = assert_store.retrieve(&agent_assertion.id)
            {
                if !agent_assertion.condition {
                    if result.passed {
                        passed_count += 1;
                    } else {
                        failed_count += 1;
                    }
                }

                let stage = *self
                    .context
                    .task_stages
                    .get(&agent_assertion.id)
                    .unwrap_or(&-1);

                records.push(scouter_types::EvalTaskResult {
                    created_at: chrono::Utc::now(),
                    start_time,
                    end_time,
                    record_uid: record.uid.clone(),
                    entity_id: record.entity_id,
                    task_id: agent_assertion.id.clone(),
                    task_type: agent_assertion.task_type.clone(),
                    passed: result.passed,
                    value: result.to_metric_value(),
                    assertion: Assertion::AgentAssertion(agent_assertion.assertion.clone()),
                    expected: result.expected.clone(),
                    actual: result.actual.clone(),
                    message: result.message.clone(),
                    operator: agent_assertion.operator.clone(),
                    entity_uid: String::new(),
                    condition: agent_assertion.condition,
                    stage,
                });
            }
        }

        // Workflow-level summary; pass rate guards against division by zero
        // when no non-conditional tasks produced results.
        let workflow_record = scouter_types::GenAIEvalWorkflowResult {
            created_at: chrono::Utc::now(),
            id: record.id,
            entity_id: record.entity_id,
            record_uid: record.uid.clone(),
            total_tasks: passed_count + failed_count,
            passed_tasks: passed_count,
            failed_tasks: failed_count,
            pass_rate: if passed_count + failed_count == 0 {
                0.0
            } else {
                passed_count as f64 / (passed_count + failed_count) as f64
            },
            duration_ms,
            entity_uid: String::new(),
            execution_plan,
        };

        EvalSet::new(records, workflow_record)
    }
}
747
/// Stateless entry point for evaluating a single `EvalRecord` against a
/// `GenAIEvalProfile`.
pub struct GenAIEvaluator;
749
impl GenAIEvaluator {
    /// Evaluates one record: registers all profile tasks, builds the
    /// execution plan, runs each stage in order (tasks within a stage run
    /// concurrently), and collects every result into an `EvalSet`.
    ///
    /// # Errors
    /// Propagates plan-construction failures and the first stage execution
    /// error (logged with its stage index).
    #[instrument(skip_all, fields(record_uid = %record.uid))]
    pub async fn process_event_record(
        record: &EvalRecord,
        profile: Arc<GenAIEvalProfile>,
        spans: Arc<Vec<TraceSpan>>,
    ) -> Result<EvalSet, EvaluationError> {
        // Wall-clock start; duration_ms below covers execution + setup.
        let begin = chrono::Utc::now();

        let mut registry = TaskRegistry::new();
        Self::register_tasks(&mut registry, &profile);

        let execution_plan = profile.get_execution_plan()?;

        let context = ExecutionContext::new(record.context.clone(), registry, &execution_plan);
        let executor = TaskExecutor::new(context.clone(), profile.clone(), spans);

        debug!(
            "Starting evaluation for record: {} with {} stages",
            record.uid,
            execution_plan.stages.len()
        );

        // Stages run sequentially so later stages can consume earlier results.
        for (stage_idx, stage_tasks) in execution_plan.stages.iter().enumerate() {
            debug!(
                "Executing stage {} with {} tasks",
                stage_idx,
                stage_tasks.len()
            );
            executor
                .execute_level(stage_tasks)
                .await
                .inspect_err(|e| error!("Failed to execute stage {}: {:?}", stage_idx, e))?;
        }

        let end = chrono::Utc::now();
        let duration_ms = (end - begin).num_milliseconds();

        let collector = ResultCollector::new(context);
        let eval_set = collector
            .build_eval_set(record, &profile, duration_ms, execution_plan)
            .await;

        Ok(eval_set)
    }

    /// Registers every task from all four profile task families with its
    /// type, conditional flag, and (when present) dependency list.
    fn register_tasks(registry: &mut TaskRegistry, profile: &GenAIEvalProfile) {
        for task in &profile.tasks.assertion {
            registry.register(task.id.clone(), TaskType::Assertion, task.condition);
            if !task.depends_on.is_empty() {
                registry.register_dependencies(task.id.clone(), task.depends_on.clone());
            }
        }

        for task in &profile.tasks.judge {
            registry.register(task.id.clone(), TaskType::LLMJudge, task.condition);
            if !task.depends_on.is_empty() {
                registry.register_dependencies(task.id.clone(), task.depends_on.clone());
            }
        }

        for task in &profile.tasks.trace {
            registry.register(task.id.clone(), TaskType::TraceAssertion, task.condition);
            if !task.depends_on.is_empty() {
                registry.register_dependencies(task.id.clone(), task.depends_on.clone());
            }
        }

        for task in &profile.tasks.agent {
            registry.register(task.id.clone(), TaskType::AgentAssertion, task.condition);
            if !task.depends_on.is_empty() {
                registry.register_dependencies(task.id.clone(), task.depends_on.clone());
            }
        }
    }
}
826
827#[cfg(test)]
828mod tests {
829
830    use chrono::Utc;
831    use potato_head::mock::{create_score_prompt, LLMTestServer};
832    use scouter_mocks::{
833        create_multi_service_trace, create_nested_trace, create_sequence_pattern_trace,
834        create_simple_trace, create_trace_with_attributes, create_trace_with_errors, init_tracing,
835    };
836    use scouter_types::genai::{
837        AggregationType, SpanFilter, SpanStatus, TraceAssertion, TraceAssertionTask,
838    };
839    use scouter_types::genai::{
840        AssertionTask, ComparisonOperator, GenAIAlertConfig, GenAIEvalConfig, GenAIEvalProfile,
841        LLMJudgeTask,
842    };
843    use scouter_types::genai::{EvaluationTaskType, EvaluationTasks};
844    use scouter_types::EvalRecord;
845    use serde_json::Value;
846    use std::sync::Arc;
847
848    use crate::evaluate::GenAIEvaluator;
849
850    async fn create_assert_judge_profile() -> GenAIEvalProfile {
851        let prompt = create_score_prompt(Some(vec!["input".to_string()]));
852
853        let assertion_level_1 = AssertionTask {
854            id: "input_check".to_string(),
855            context_path: Some("input.foo".to_string()),
856            operator: ComparisonOperator::Equals,
857            expected_value: Value::String("bar".to_string()),
858            description: Some("Check if input.foo is bar".to_string()),
859            task_type: EvaluationTaskType::Assertion,
860            depends_on: vec![],
861            result: None,
862            condition: false,
863            item_context_path: None,
864        };
865
866        let judge_task_level_1 = LLMJudgeTask::new_rs(
867            "query_relevance",
868            prompt.clone(),
869            Value::Number(1.into()),
870            Some("score".to_string()),
871            ComparisonOperator::GreaterThanOrEqual,
872            None,
873            None,
874            None,
875            None,
876        );
877
878        let assert_query_score = AssertionTask {
879            id: "assert_score".to_string(),
880            context_path: Some("query_relevance.score".to_string()),
881            operator: ComparisonOperator::IsNumeric,
882            expected_value: Value::Bool(true),
883            depends_on: vec!["query_relevance".to_string()],
884            task_type: EvaluationTaskType::Assertion,
885            description: Some("Check that score is numeric".to_string()),
886            result: None,
887            condition: false,
888            item_context_path: None,
889        };
890
891        let assert_query_reason = AssertionTask {
892            id: "assert_reason".to_string(),
893            context_path: Some("query_relevance.reason".to_string()),
894            operator: ComparisonOperator::IsString,
895            expected_value: Value::Bool(true),
896            depends_on: vec!["query_relevance".to_string()],
897            task_type: EvaluationTaskType::Assertion,
898            description: Some("Check that reason is alphabetic".to_string()),
899            result: None,
900            condition: false,
901            item_context_path: None,
902        };
903
904        let tasks = EvaluationTasks::new()
905            .add_task(assertion_level_1)
906            .add_task(judge_task_level_1)
907            .add_task(assert_query_score)
908            .add_task(assert_query_reason)
909            .build();
910
911        let alert_config = GenAIAlertConfig::default();
912
913        let drift_config =
914            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();
915
916        GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
917    }
918
919    async fn create_assert_profile() -> GenAIEvalProfile {
920        let assert1 = AssertionTask {
921            id: "input_foo_check".to_string(),
922            context_path: Some("input.foo".to_string()),
923            operator: ComparisonOperator::Equals,
924            expected_value: Value::String("bar".to_string()),
925            description: Some("Check if input.foo is bar".to_string()),
926            task_type: EvaluationTaskType::Assertion,
927            depends_on: vec![],
928            result: None,
929            condition: false,
930            item_context_path: None,
931        };
932        let assert2 = AssertionTask {
933            id: "input_bar_check".to_string(),
934            context_path: Some("input.bar".to_string()),
935            operator: ComparisonOperator::IsNumeric,
936            expected_value: Value::Bool(true),
937            depends_on: vec![],
938            task_type: EvaluationTaskType::Assertion,
939            description: Some("Check that bar is numeric".to_string()),
940            result: None,
941            condition: false,
942            item_context_path: None,
943        };
944
945        let assert3 = AssertionTask {
946            id: "input_baz_check".to_string(),
947            context_path: Some("input.baz".to_string()),
948            operator: ComparisonOperator::HasLengthEqual,
949            expected_value: Value::Number(3.into()),
950            depends_on: vec![],
951            task_type: EvaluationTaskType::Assertion,
952            description: Some("Check that baz has length 3".to_string()),
953            result: None,
954            condition: false,
955            item_context_path: None,
956        };
957
958        let tasks = EvaluationTasks::new()
959            .add_task(assert1)
960            .add_task(assert2)
961            .add_task(assert3)
962            .build();
963
964        let alert_config = GenAIAlertConfig::default();
965
966        let drift_config =
967            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();
968
969        GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
970    }
971
972    async fn create_trace_profile_simple() -> GenAIEvalProfile {
973        let trace_task = TraceAssertionTask {
974            id: "check_span_sequence".to_string(),
975            assertion: TraceAssertion::SpanSequence {
976                span_names: vec![
977                    "root".to_string(),
978                    "child_1".to_string(),
979                    "child_2".to_string(),
980                ],
981            },
982            operator: ComparisonOperator::Equals,
983            expected_value: Value::Bool(true),
984            description: Some("Verify span execution order".to_string()),
985            task_type: EvaluationTaskType::TraceAssertion,
986            depends_on: vec![],
987            condition: false,
988            result: None,
989        };
990
991        let tasks = EvaluationTasks::new().add_task(trace_task).build();
992
993        let alert_config = GenAIAlertConfig::default();
994        let drift_config =
995            GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
996                .unwrap();
997
998        GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
999    }
1000
1001    async fn create_trace_profile_with_filters() -> GenAIEvalProfile {
1002        let span_count_task = TraceAssertionTask {
1003            id: "count_error_spans".to_string(),
1004            assertion: TraceAssertion::SpanCount {
1005                filter: SpanFilter::WithStatus {
1006                    status: SpanStatus::Error,
1007                },
1008            },
1009            operator: ComparisonOperator::Equals,
1010            expected_value: Value::Number(1.into()),
1011            description: Some("Count spans with error status".to_string()),
1012            task_type: EvaluationTaskType::TraceAssertion,
1013            depends_on: vec![],
1014            condition: false,
1015            result: None,
1016        };
1017
1018        let span_exists_task = TraceAssertionTask {
1019            id: "check_recovery_span".to_string(),
1020            assertion: TraceAssertion::SpanExists {
1021                filter: SpanFilter::ByName {
1022                    name: "recovery".to_string(),
1023                },
1024            },
1025            operator: ComparisonOperator::Equals,
1026            expected_value: Value::Bool(true),
1027            description: Some("Verify recovery span exists".to_string()),
1028            task_type: EvaluationTaskType::TraceAssertion,
1029            depends_on: vec![],
1030            condition: false,
1031            result: None,
1032        };
1033
1034        let tasks = EvaluationTasks::new()
1035            .add_task(span_count_task)
1036            .add_task(span_exists_task)
1037            .build();
1038
1039        let alert_config = GenAIAlertConfig::default();
1040        let drift_config =
1041            GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1042                .unwrap();
1043
1044        GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
1045    }
1046
1047    async fn create_trace_profile_with_attributes() -> GenAIEvalProfile {
1048        let attribute_task = TraceAssertionTask {
1049            id: "check_model_name".to_string(),
1050            assertion: TraceAssertion::SpanAttribute {
1051                filter: SpanFilter::ByName {
1052                    name: "api_call".to_string(),
1053                },
1054                attribute_key: "model".to_string(),
1055            },
1056            operator: ComparisonOperator::Equals,
1057            expected_value: Value::String("gpt-4".to_string()),
1058            description: Some("Verify model attribute".to_string()),
1059            task_type: EvaluationTaskType::TraceAssertion,
1060            depends_on: vec![],
1061            condition: false,
1062            result: None,
1063        };
1064
1065        let aggregation_task = TraceAssertionTask {
1066            id: "sum_token_output".to_string(),
1067            assertion: TraceAssertion::SpanAggregation {
1068                filter: SpanFilter::ByName {
1069                    name: "api_call".to_string(),
1070                },
1071                attribute_key: "tokens.output".to_string(),
1072                aggregation: AggregationType::Sum,
1073            },
1074            operator: ComparisonOperator::Equals,
1075            expected_value: Value::Number(300.into()),
1076            description: Some("Sum output tokens".to_string()),
1077            task_type: EvaluationTaskType::TraceAssertion,
1078            depends_on: vec![],
1079            condition: false,
1080            result: None,
1081        };
1082
1083        let tasks = EvaluationTasks::new()
1084            .add_task(attribute_task)
1085            .add_task(aggregation_task)
1086            .build();
1087
1088        let alert_config = GenAIAlertConfig::default();
1089        let drift_config =
1090            GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1091                .unwrap();
1092
1093        GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
1094    }
1095
1096    async fn create_trace_profile_complex() -> GenAIEvalProfile {
1097        let sequence_count_task = TraceAssertionTask {
1098            id: "count_tool_agent_sequence".to_string(),
1099            assertion: TraceAssertion::SpanCount {
1100                filter: SpanFilter::Sequence {
1101                    names: vec!["call_tool".to_string(), "run_agent".to_string()],
1102                },
1103            },
1104            operator: ComparisonOperator::Equals,
1105            expected_value: Value::Number(2.into()),
1106            description: Some("Count tool->agent sequences".to_string()),
1107            task_type: EvaluationTaskType::TraceAssertion,
1108            depends_on: vec![],
1109            condition: false,
1110            result: None,
1111        };
1112
1113        let trace_duration_task = TraceAssertionTask {
1114            id: "check_trace_duration".to_string(),
1115            assertion: TraceAssertion::TraceDuration {},
1116            operator: ComparisonOperator::LessThanOrEqual,
1117            expected_value: Value::Number(1000.into()),
1118            description: Some("Verify trace completes within 1s".to_string()),
1119            task_type: EvaluationTaskType::TraceAssertion,
1120            depends_on: vec![],
1121            condition: false,
1122            result: None,
1123        };
1124
1125        let service_count_task = TraceAssertionTask {
1126            id: "check_service_count".to_string(),
1127            assertion: TraceAssertion::TraceServiceCount {},
1128            operator: ComparisonOperator::Equals,
1129            expected_value: Value::Number(1.into()),
1130            description: Some("Verify single service".to_string()),
1131            task_type: EvaluationTaskType::TraceAssertion,
1132            depends_on: vec![],
1133            condition: false,
1134            result: None,
1135        };
1136
1137        let tasks = EvaluationTasks::new()
1138            .add_task(sequence_count_task)
1139            .add_task(trace_duration_task)
1140            .add_task(service_count_task)
1141            .build();
1142
1143        let alert_config = GenAIAlertConfig::default();
1144        let drift_config =
1145            GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1146                .unwrap();
1147
1148        GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
1149    }
1150
1151    async fn create_trace_profile_with_dependencies() -> GenAIEvalProfile {
1152        let error_check = TraceAssertionTask {
1153            id: "check_has_errors".to_string(),
1154            assertion: TraceAssertion::TraceErrorCount {},
1155            operator: ComparisonOperator::GreaterThan,
1156            expected_value: Value::Number(0.into()),
1157            description: Some("Check if trace has errors".to_string()),
1158            task_type: EvaluationTaskType::TraceAssertion,
1159            depends_on: vec![],
1160            condition: true,
1161            result: None,
1162        };
1163
1164        let recovery_check = TraceAssertionTask {
1165            id: "check_recovery_exists".to_string(),
1166            assertion: TraceAssertion::SpanExists {
1167                filter: SpanFilter::ByName {
1168                    name: "recovery".to_string(),
1169                },
1170            },
1171            operator: ComparisonOperator::Equals,
1172            expected_value: Value::Bool(true),
1173            description: Some("Verify recovery span exists when errors present".to_string()),
1174            task_type: EvaluationTaskType::TraceAssertion,
1175            depends_on: vec!["check_has_errors".to_string()],
1176            condition: false,
1177            result: None,
1178        };
1179
1180        let tasks = EvaluationTasks::new()
1181            .add_task(error_check)
1182            .add_task(recovery_check)
1183            .build();
1184
1185        let alert_config = GenAIAlertConfig::default();
1186        let drift_config =
1187            GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1188                .unwrap();
1189
1190        GenAIEvalProfile::new(drift_config, tasks).await.unwrap()
1191    }
1192
1193    #[test]
1194    fn test_evaluator_assert_judge_all_pass() {
1195        let mut mock = LLMTestServer::new();
1196        mock.start_server().unwrap();
1197        let runtime = tokio::runtime::Runtime::new().unwrap();
1198        let profile = runtime.block_on(async { create_assert_judge_profile().await });
1199
1200        assert!(profile.has_llm_tasks());
1201        assert!(profile.has_assertions());
1202
1203        let context = serde_json::json!({
1204        "input": {
1205            "foo": "bar" }
1206        });
1207
1208        let record = EvalRecord::new_rs(
1209            context,
1210            Utc::now(),
1211            "UID123".to_string(),
1212            "ENTITY123".to_string(),
1213            None,
1214            None,
1215        );
1216
1217        let result_set = runtime.block_on(async {
1218            GenAIEvaluator::process_event_record(&record, Arc::new(profile), Arc::new(vec![])).await
1219        });
1220
1221        let eval_set = result_set.unwrap();
1222        assert!(eval_set.passed_tasks() == 4);
1223        assert!(eval_set.failed_tasks() == 0);
1224
1225        mock.stop_server().unwrap();
1226    }
1227
1228    #[test]
1229    fn test_evaluator_assert_one_fail() {
1230        let mut mock = LLMTestServer::new();
1231        mock.start_server().unwrap();
1232        let runtime = tokio::runtime::Runtime::new().unwrap();
1233        let profile = runtime.block_on(async { create_assert_profile().await });
1234
1235        assert!(!profile.has_llm_tasks());
1236        assert!(profile.has_assertions());
1237
1238        // we want task "input_bar_check" to fail (is_numeric on non-numeric)
1239        let context = serde_json::json!({
1240            "input": {
1241                "foo": "bar",
1242                "bar": "not_a_number",
1243                "baz": [1, 2, 3]}
1244        });
1245
1246        let record = EvalRecord::new_rs(
1247            context,
1248            Utc::now(),
1249            "UID123".to_string(),
1250            "ENTITY123".to_string(),
1251            None,
1252            None,
1253        );
1254
1255        let result_set = runtime.block_on(async {
1256            GenAIEvaluator::process_event_record(&record, Arc::new(profile), Arc::new(vec![])).await
1257        });
1258
1259        let eval_set = result_set.unwrap();
1260        assert!(eval_set.passed_tasks() == 2);
1261        assert!(eval_set.failed_tasks() == 1);
1262
1263        mock.stop_server().unwrap();
1264    }
1265
1266    #[test]
1267    fn test_evaluator_trace_simple_sequence() {
1268        init_tracing();
1269        let runtime = tokio::runtime::Runtime::new().unwrap();
1270        let profile = runtime.block_on(create_trace_profile_simple());
1271        let spans = Arc::new(create_simple_trace());
1272
1273        let context = serde_json::json!({});
1274        let record = EvalRecord::new_rs(
1275            context,
1276            Utc::now(),
1277            "TRACE_UID_001".to_string(),
1278            "ENTITY_001".to_string(),
1279            None,
1280            None,
1281        );
1282
1283        let result = runtime.block_on(GenAIEvaluator::process_event_record(
1284            &record,
1285            Arc::new(profile),
1286            spans,
1287        ));
1288
1289        let eval_set = result.unwrap();
1290        assert_eq!(eval_set.passed_tasks(), 1);
1291        assert_eq!(eval_set.failed_tasks(), 0);
1292    }
1293
1294    #[test]
1295    fn test_evaluator_trace_error_detection() {
1296        let runtime = tokio::runtime::Runtime::new().unwrap();
1297        let profile = runtime.block_on(create_trace_profile_with_filters());
1298        let spans = Arc::new(create_trace_with_errors());
1299
1300        let context = serde_json::json!({});
1301        let record = EvalRecord::new_rs(
1302            context,
1303            Utc::now(),
1304            "TRACE_UID_002".to_string(),
1305            "ENTITY_002".to_string(),
1306            None,
1307            None,
1308        );
1309
1310        let result = runtime.block_on(GenAIEvaluator::process_event_record(
1311            &record,
1312            Arc::new(profile),
1313            spans,
1314        ));
1315
1316        let eval_set = result.unwrap();
1317        assert_eq!(eval_set.passed_tasks(), 2);
1318        assert_eq!(eval_set.failed_tasks(), 0);
1319    }
1320
1321    #[test]
1322    fn test_evaluator_trace_attribute_extraction() {
1323        init_tracing();
1324        let runtime = tokio::runtime::Runtime::new().unwrap();
1325        let profile = runtime.block_on(create_trace_profile_with_attributes());
1326        let spans = Arc::new(create_trace_with_attributes());
1327
1328        let context = serde_json::json!({});
1329        let record = EvalRecord::new_rs(
1330            context,
1331            Utc::now(),
1332            "TRACE_UID_003".to_string(),
1333            "ENTITY_003".to_string(),
1334            None,
1335            None,
1336        );
1337
1338        let result = runtime.block_on(GenAIEvaluator::process_event_record(
1339            &record,
1340            Arc::new(profile),
1341            spans,
1342        ));
1343
1344        let eval_set = result.unwrap();
1345        assert_eq!(eval_set.passed_tasks(), 2);
1346        assert_eq!(eval_set.failed_tasks(), 0);
1347    }
1348
1349    #[test]
1350    fn test_evaluator_trace_sequence_pattern() {
1351        init_tracing();
1352        let runtime = tokio::runtime::Runtime::new().unwrap();
1353        let profile = runtime.block_on(create_trace_profile_complex());
1354        let spans = Arc::new(create_sequence_pattern_trace());
1355
1356        let context = serde_json::json!({});
1357        let record = EvalRecord::new_rs(
1358            context,
1359            Utc::now(),
1360            "TRACE_UID_004".to_string(),
1361            "ENTITY_004".to_string(),
1362            None,
1363            None,
1364        );
1365
1366        let result = runtime.block_on(GenAIEvaluator::process_event_record(
1367            &record,
1368            Arc::new(profile),
1369            spans,
1370        ));
1371
1372        let eval_set = result.unwrap();
1373        assert_eq!(eval_set.passed_tasks(), 3);
1374        assert_eq!(eval_set.failed_tasks(), 0);
1375    }
1376
1377    #[test]
1378    fn test_evaluator_trace_conditional_dependency() {
1379        let runtime = tokio::runtime::Runtime::new().unwrap();
1380        let profile = runtime.block_on(create_trace_profile_with_dependencies());
1381        let spans = Arc::new(create_trace_with_errors());
1382
1383        let context = serde_json::json!({});
1384        let record = EvalRecord::new_rs(
1385            context,
1386            Utc::now(),
1387            "TRACE_UID_005".to_string(),
1388            "ENTITY_005".to_string(),
1389            None,
1390            None,
1391        );
1392
1393        let result = runtime.block_on(GenAIEvaluator::process_event_record(
1394            &record,
1395            Arc::new(profile),
1396            spans,
1397        ));
1398
1399        let eval_set = result.unwrap();
1400        assert_eq!(eval_set.passed_tasks(), 1); // first task is conditional and is excluded
1401        assert_eq!(eval_set.failed_tasks(), 0);
1402    }
1403
1404    #[test]
1405    fn test_evaluator_trace_multi_service() {
1406        let runtime = tokio::runtime::Runtime::new().unwrap();
1407
1408        let task = TraceAssertionTask {
1409            id: "check_service_count".to_string(),
1410            assertion: TraceAssertion::TraceServiceCount {},
1411            operator: ComparisonOperator::Equals,
1412            expected_value: Value::Number(3.into()),
1413            description: Some("Verify three services".to_string()),
1414            task_type: EvaluationTaskType::TraceAssertion,
1415            depends_on: vec![],
1416            condition: false,
1417            result: None,
1418        };
1419
1420        let tasks = EvaluationTasks::new().add_task(task).build();
1421        let alert_config = GenAIAlertConfig::default();
1422        let drift_config =
1423            GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1424                .unwrap();
1425
1426        let profile = runtime
1427            .block_on(GenAIEvalProfile::new(drift_config, tasks))
1428            .unwrap();
1429        let spans = Arc::new(create_multi_service_trace());
1430
1431        let context = serde_json::json!({});
1432        let record = EvalRecord::new_rs(
1433            context,
1434            Utc::now(),
1435            "TRACE_UID_006".to_string(),
1436            "ENTITY_006".to_string(),
1437            None,
1438            None,
1439        );
1440
1441        let result = runtime.block_on(GenAIEvaluator::process_event_record(
1442            &record,
1443            Arc::new(profile),
1444            spans,
1445        ));
1446
1447        let eval_set = result.unwrap();
1448        assert_eq!(eval_set.passed_tasks(), 1);
1449        assert_eq!(eval_set.failed_tasks(), 0);
1450    }
1451
1452    #[test]
1453    fn test_evaluator_trace_assertion_failure() {
1454        let runtime = tokio::runtime::Runtime::new().unwrap();
1455
1456        let task = TraceAssertionTask {
1457            id: "check_wrong_sequence".to_string(),
1458            assertion: TraceAssertion::SpanSequence {
1459                span_names: vec![
1460                    "root".to_string(),
1461                    "wrong_child".to_string(),
1462                    "child_2".to_string(),
1463                ],
1464            },
1465            operator: ComparisonOperator::Equals,
1466            expected_value: Value::Bool(true),
1467            description: Some("Verify incorrect span order".to_string()),
1468            task_type: EvaluationTaskType::TraceAssertion,
1469            depends_on: vec![],
1470            condition: false,
1471            result: None,
1472        };
1473
1474        let tasks = EvaluationTasks::new().add_task(task).build();
1475        let alert_config = GenAIAlertConfig::default();
1476        let drift_config =
1477            GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1478                .unwrap();
1479
1480        let profile = runtime
1481            .block_on(GenAIEvalProfile::new(drift_config, tasks))
1482            .unwrap();
1483        let spans = Arc::new(create_simple_trace());
1484
1485        let context = serde_json::json!({});
1486        let record = EvalRecord::new_rs(
1487            context,
1488            Utc::now(),
1489            "TRACE_UID_007".to_string(),
1490            "ENTITY_007".to_string(),
1491            None,
1492            None,
1493        );
1494
1495        let result = runtime.block_on(GenAIEvaluator::process_event_record(
1496            &record,
1497            Arc::new(profile),
1498            spans,
1499        ));
1500
1501        let eval_set = result.unwrap();
1502        assert_eq!(eval_set.passed_tasks(), 0);
1503        assert_eq!(eval_set.failed_tasks(), 1);
1504    }
1505
1506    #[test]
1507    fn test_evaluator_trace_mixed_assertions() {
1508        init_tracing();
1509        let runtime = tokio::runtime::Runtime::new().unwrap();
1510
1511        let trace_task = TraceAssertionTask {
1512            id: "check_max_depth".to_string(),
1513            assertion: TraceAssertion::TraceMaxDepth {},
1514            operator: ComparisonOperator::Equals,
1515            expected_value: Value::Number(2.into()),
1516            description: Some("Verify max depth".to_string()),
1517            task_type: EvaluationTaskType::TraceAssertion,
1518            depends_on: vec![],
1519            condition: false,
1520            result: None,
1521        };
1522
1523        let regular_assertion = AssertionTask {
1524            id: "check_context".to_string(),
1525            context_path: Some("metadata.version".to_string()),
1526            operator: ComparisonOperator::Equals,
1527            expected_value: Value::String("1.0.0".to_string()),
1528            description: Some("Verify version".to_string()),
1529            task_type: EvaluationTaskType::Assertion,
1530            depends_on: vec![],
1531            result: None,
1532            condition: false,
1533            item_context_path: None,
1534        };
1535
1536        let tasks = EvaluationTasks::new()
1537            .add_task(trace_task)
1538            .add_task(regular_assertion)
1539            .build();
1540
1541        let alert_config = GenAIAlertConfig::default();
1542        let drift_config =
1543            GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1544                .unwrap();
1545
1546        let profile = runtime
1547            .block_on(GenAIEvalProfile::new(drift_config, tasks))
1548            .unwrap();
1549        let spans = Arc::new(create_nested_trace());
1550
1551        let context = serde_json::json!({
1552            "metadata": {
1553                "version": "1.0.0"
1554            }
1555        });
1556
1557        let record = EvalRecord::new_rs(
1558            context,
1559            Utc::now(),
1560            "TRACE_UID_008".to_string(),
1561            "ENTITY_008".to_string(),
1562            None,
1563            None,
1564        );
1565
1566        let result = runtime.block_on(GenAIEvaluator::process_event_record(
1567            &record,
1568            Arc::new(profile),
1569            spans,
1570        ));
1571
1572        let eval_set = result.unwrap();
1573        assert_eq!(eval_set.passed_tasks(), 2);
1574        assert_eq!(eval_set.failed_tasks(), 0);
1575    }
1576
1577    #[test]
1578    fn test_evaluator_trace_duration_filter() {
1579        init_tracing();
1580        let runtime = tokio::runtime::Runtime::new().unwrap();
1581
1582        let task = TraceAssertionTask {
1583            id: "check_slow_spans".to_string(),
1584            assertion: TraceAssertion::SpanCount {
1585                filter: SpanFilter::WithDuration {
1586                    min_ms: Some(100.0),
1587                    max_ms: None,
1588                },
1589            },
1590            operator: ComparisonOperator::GreaterThanOrEqual,
1591            expected_value: Value::Number(2.into()),
1592            description: Some("Count spans over 100ms".to_string()),
1593            task_type: EvaluationTaskType::TraceAssertion,
1594            depends_on: vec![],
1595            condition: false,
1596            result: None,
1597        };
1598
1599        let tasks = EvaluationTasks::new().add_task(task).build();
1600        let alert_config = GenAIAlertConfig::default();
1601        let drift_config =
1602            GenAIEvalConfig::new("scouter", "trace_test", "0.1.0", 1.0, alert_config, None)
1603                .unwrap();
1604
1605        let profile = runtime
1606            .block_on(GenAIEvalProfile::new(drift_config, tasks))
1607            .unwrap();
1608        let spans = Arc::new(create_nested_trace());
1609
1610        let context = serde_json::json!({});
1611        let record = EvalRecord::new_rs(
1612            context,
1613            Utc::now(),
1614            "TRACE_UID_009".to_string(),
1615            "ENTITY_009".to_string(),
1616            None,
1617            None,
1618        );
1619
1620        let result = runtime.block_on(GenAIEvaluator::process_event_record(
1621            &record,
1622            Arc::new(profile),
1623            spans,
1624        ));
1625
1626        let eval_set = result.unwrap();
1627        assert_eq!(eval_set.passed_tasks(), 1);
1628        assert_eq!(eval_set.failed_tasks(), 0);
1629    }
1630}