Skip to main content

scouter_types/genai/
eval.rs

1use crate::error::TypeError;
2use crate::genai::traits::TaskAccessor;
3use crate::PyHelperFuncs;
4use core::fmt::Debug;
5use potato_head::prompt_types::Prompt;
6use potato_head::Provider;
7use pyo3::prelude::*;
8use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PySlice, PyString};
9use pyo3::IntoPyObjectExt;
10use pythonize::{depythonize, pythonize};
11use serde::de::DeserializeOwned;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::fmt::Display;
16use std::path::PathBuf;
17use std::str::FromStr;
18use tracing::error;
19
20pub fn deserialize_from_path<T: DeserializeOwned>(path: PathBuf) -> Result<T, TypeError> {
21    let content = std::fs::read_to_string(&path)?;
22
23    let extension = path
24        .extension()
25        .and_then(|ext| ext.to_str())
26        .ok_or_else(|| TypeError::Error(format!("Invalid file path: {:?}", path)))?;
27
28    let item = match extension.to_lowercase().as_str() {
29        "json" => serde_json::from_str(&content)?,
30        "yaml" | "yml" => serde_yaml::from_str(&content)?,
31        _ => {
32            return Err(TypeError::Error(format!(
33                "Unsupported file extension '{}'. Expected .json, .yaml, or .yml",
34                extension
35            )))
36        }
37    };
38
39    Ok(item)
40}
41
42// Default functions for task types during deserialization
43fn default_assertion_task_type() -> EvaluationTaskType {
44    EvaluationTaskType::Assertion
45}
46
47fn default_trace_assertion_task_type() -> EvaluationTaskType {
48    EvaluationTaskType::TraceAssertion
49}
50
51fn default_agent_assertion_task_type() -> EvaluationTaskType {
52    EvaluationTaskType::AgentAssertion
53}
54
55#[pyclass]
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57pub struct AssertionResult {
58    #[pyo3(get)]
59    pub passed: bool,
60    pub actual: Value,
61
62    #[pyo3(get)]
63    pub message: String,
64
65    pub expected: Value,
66}
67
68impl AssertionResult {
69    pub fn new(passed: bool, actual: Value, message: String, expected: Value) -> Self {
70        Self {
71            passed,
72            actual,
73            message,
74            expected,
75        }
76    }
77    pub fn to_metric_value(&self) -> f64 {
78        if self.passed {
79            1.0
80        } else {
81            0.0
82        }
83    }
84}
85
86#[pymethods]
87impl AssertionResult {
88    pub fn __str__(&self) -> String {
89        PyHelperFuncs::__str__(self)
90    }
91
92    #[getter]
93    pub fn get_actual<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
94        let py_value = pythonize(py, &self.actual)?;
95        Ok(py_value)
96    }
97
98    #[getter]
99    pub fn get_expected<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
100        let py_value = pythonize(py, &self.expected)?;
101        Ok(py_value)
102    }
103}
104
105#[pyclass]
106#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
107pub struct AssertionResults {
108    #[pyo3(get)]
109    pub results: HashMap<String, AssertionResult>,
110}
111
112#[pymethods]
113impl AssertionResults {
114    pub fn __str__(&self) -> String {
115        PyHelperFuncs::__str__(self)
116    }
117
118    pub fn __getitem__(&self, key: &str) -> Result<AssertionResult, TypeError> {
119        if let Some(result) = self.results.get(key) {
120            Ok(result.clone())
121        } else {
122            Err(TypeError::KeyNotFound {
123                key: key.to_string(),
124            })
125        }
126    }
127}
128
129#[pyclass]
130#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
131pub struct AssertionTask {
132    #[pyo3(get, set)]
133    pub id: String,
134
135    #[pyo3(get, set)]
136    #[serde(default)]
137    pub context_path: Option<String>,
138
139    #[pyo3(get, set)]
140    #[serde(default)]
141    pub item_context_path: Option<String>,
142
143    #[pyo3(get, set)]
144    pub operator: ComparisonOperator,
145
146    pub expected_value: Value,
147
148    #[pyo3(get, set)]
149    #[serde(default)]
150    pub description: Option<String>,
151
152    #[pyo3(get, set)]
153    #[serde(default)]
154    pub depends_on: Vec<String>,
155
156    #[serde(default = "default_assertion_task_type")]
157    #[pyo3(get)]
158    pub task_type: EvaluationTaskType,
159
160    #[serde(skip_serializing_if = "Option::is_none")]
161    pub result: Option<AssertionResult>,
162
163    #[serde(default)]
164    pub condition: bool,
165}
166
167#[pymethods]
168impl AssertionTask {
169    #[new]
170    /// Creates a new AssertionTask
171    /// # Examples
172    /// ```python
173    ///
174    /// # assumed passed context at runtime
175    /// # context = {
176    /// #     "response": {
177    /// #         "user": {
178    /// #             "age": 25
179    /// #         }
180    /// #     }
181    /// # }
182    ///
183    /// task = AssertionTask(
184    ///     id="Check User Age",
185    ///     context_path="response.user.age",
186    ///     operator=ComparisonOperator.GREATER_THAN,
187    ///     expected_value=18,
188    ///     description="Check if user is an adult"
189    /// )
190    ///
191    /// # assumed passed context at runtime
192    /// # context = {
193    /// #     "user": {
194    /// #         "age": 25
195    /// #     }
196    /// # }
197    ///
198    /// task = AssertionTask(
199    ///     id="Check User Age",
200    ///     context_path="user.age",
201    ///     operator=ComparisonOperator.GREATER_THAN,
202    ///     expected_value=18,
203    ///     description="Check if user is an adult"
204    /// )
205    ///
206    ///  /// # assume non-map context at runtime
207    /// # context = 25
208    ///
209    /// task = AssertionTask(
210    ///     id="Check User Age",
211    ///     operator=ComparisonOperator.GREATER_THAN,
212    ///     expected_value=18,
213    ///     description="Check if user is an adult"
214    /// )
215    /// ```
216    /// # Arguments
217    /// * `context_path`: The path to the field to be asserted
218    /// * `operator`: The comparison operator to use
219    /// * `expected_value`: The expected value for the assertion
220    /// * `description`: Optional description for the assertion
221    /// # Returns
222    /// A new AssertionTask object
223    #[allow(clippy::too_many_arguments)]
224    #[pyo3(signature = (id, context_path, expected_value, operator, item_context_path=None, description=None, depends_on=None, condition=None))]
225    pub fn new(
226        id: String,
227        context_path: Option<String>,
228        expected_value: &Bound<'_, PyAny>,
229        operator: ComparisonOperator,
230        item_context_path: Option<String>,
231        description: Option<String>,
232        depends_on: Option<Vec<String>>,
233        condition: Option<bool>,
234    ) -> Result<Self, TypeError> {
235        let expected_value = depythonize(expected_value)?;
236        let condition = condition.unwrap_or(false);
237
238        Ok(Self {
239            id: id.to_lowercase(),
240            context_path,
241            item_context_path,
242            operator,
243            expected_value,
244            description,
245            task_type: EvaluationTaskType::Assertion,
246            depends_on: depends_on.unwrap_or_default(),
247            result: None,
248            condition,
249        })
250    }
251
252    #[getter]
253    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
254        let py_value = pythonize(py, &self.expected_value)?;
255        Ok(py_value)
256    }
257
258    #[staticmethod]
259    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
260        deserialize_from_path(path)
261    }
262}
263
264impl AssertionTask {}
265
266impl TaskAccessor for AssertionTask {
267    fn context_path(&self) -> Option<&str> {
268        self.context_path.as_deref()
269    }
270
271    fn item_context_path(&self) -> Option<&str> {
272        self.item_context_path.as_deref()
273    }
274
275    fn id(&self) -> &str {
276        &self.id
277    }
278
279    fn operator(&self) -> &ComparisonOperator {
280        &self.operator
281    }
282
283    fn task_type(&self) -> &EvaluationTaskType {
284        &self.task_type
285    }
286
287    fn expected_value(&self) -> &Value {
288        &self.expected_value
289    }
290
291    fn depends_on(&self) -> &[String] {
292        &self.depends_on
293    }
294
295    fn add_result(&mut self, result: AssertionResult) {
296        self.result = Some(result);
297    }
298}
299
300pub trait ValueExt {
301    /// Convert value to length for HasLength comparisons
302    fn to_length(&self) -> Option<i64>;
303
304    /// Extract numeric value for comparison
305    fn as_numeric(&self) -> Option<f64>;
306
307    /// Check if value is truthy
308    fn is_truthy(&self) -> bool;
309}
310
311impl ValueExt for Value {
312    fn to_length(&self) -> Option<i64> {
313        match self {
314            Value::Array(arr) => Some(arr.len() as i64),
315            Value::String(s) => Some(s.chars().count() as i64),
316            Value::Object(obj) => Some(obj.len() as i64),
317            _ => None,
318        }
319    }
320
321    fn as_numeric(&self) -> Option<f64> {
322        match self {
323            Value::Number(n) => n.as_f64(),
324            _ => None,
325        }
326    }
327
328    fn is_truthy(&self) -> bool {
329        match self {
330            Value::Null => false,
331            Value::Bool(b) => *b,
332            Value::Number(n) => n.as_f64() != Some(0.0),
333            Value::String(s) => !s.is_empty(),
334            Value::Array(arr) => !arr.is_empty(),
335            Value::Object(obj) => !obj.is_empty(),
336        }
337    }
338}
339
340/// Primary class for defining an LLM as a Judge in evaluation workflows
341#[pyclass]
342#[derive(Debug, Serialize, Clone, PartialEq)]
343pub struct LLMJudgeTask {
344    #[pyo3(get, set)]
345    pub id: String,
346
347    #[pyo3(get)]
348    pub prompt: Prompt,
349
350    #[pyo3(get)]
351    #[serde(default)]
352    pub context_path: Option<String>,
353
354    pub expected_value: Value,
355
356    #[pyo3(get)]
357    pub operator: ComparisonOperator,
358
359    #[pyo3(get)]
360    pub task_type: EvaluationTaskType,
361
362    #[pyo3(get, set)]
363    #[serde(default)]
364    pub depends_on: Vec<String>,
365
366    #[pyo3(get, set)]
367    #[serde(default)]
368    pub max_retries: Option<u32>,
369
370    #[serde(skip_serializing_if = "Option::is_none")]
371    pub result: Option<AssertionResult>,
372
373    #[serde(default)]
374    pub description: Option<String>,
375
376    #[pyo3(get, set)]
377    #[serde(default)]
378    pub condition: bool,
379}
380
381#[derive(Debug, Deserialize)]
382#[serde(untagged)]
383pub enum PromptConfig {
384    Path { path: String },
385    Inline(Box<Prompt>),
386}
387
388#[derive(Debug, Deserialize)]
389struct LLMJudgeTaskConfig {
390    pub id: String,
391    pub prompt: PromptConfig,
392    pub expected_value: Value,
393    pub operator: ComparisonOperator,
394    pub context_path: Option<String>,
395    pub description: Option<String>,
396    pub depends_on: Vec<String>,
397    pub max_retries: Option<u32>,
398    pub condition: bool,
399}
400
401impl LLMJudgeTaskConfig {
402    pub fn into_task(self) -> Result<LLMJudgeTask, TypeError> {
403        let prompt = match self.prompt {
404            PromptConfig::Path { path } => {
405                Prompt::from_path(PathBuf::from(path)).inspect_err(|e| {
406                    error!("Failed to deserialize Prompt from path: {}", e);
407                })?
408            }
409            PromptConfig::Inline(prompt) => *prompt,
410        };
411
412        Ok(LLMJudgeTask {
413            id: self.id.to_lowercase(),
414            prompt,
415            expected_value: self.expected_value,
416            operator: self.operator,
417            context_path: self.context_path,
418            description: self.description,
419            depends_on: self.depends_on,
420            max_retries: self.max_retries.or(Some(3)),
421            task_type: EvaluationTaskType::LLMJudge,
422            result: None,
423            condition: self.condition,
424        })
425    }
426}
427
428#[derive(Debug, Deserialize)]
429struct LLMJudgeTaskInternal {
430    pub id: String,
431    pub prompt: Prompt,
432    pub context_path: Option<String>,
433    pub expected_value: Value,
434    pub operator: ComparisonOperator,
435    pub task_type: EvaluationTaskType,
436    pub depends_on: Vec<String>,
437    pub max_retries: Option<u32>,
438    pub result: Option<AssertionResult>,
439    pub description: Option<String>,
440    pub condition: bool,
441}
442impl LLMJudgeTaskInternal {
443    pub fn into_task(self) -> LLMJudgeTask {
444        LLMJudgeTask {
445            id: self.id.to_lowercase(),
446            prompt: self.prompt,
447            context_path: self.context_path,
448            expected_value: self.expected_value,
449            operator: self.operator,
450            task_type: self.task_type,
451            depends_on: self.depends_on,
452            max_retries: self.max_retries.or(Some(3)),
453            result: self.result,
454            description: self.description,
455            condition: self.condition,
456        }
457    }
458}
459
460#[derive(Debug, Deserialize)]
461#[serde(untagged)]
462enum LLMJudgeFormat {
463    Full(Box<LLMJudgeTaskInternal>),
464    Generic(LLMJudgeTaskConfig),
465}
466
467impl<'de> Deserialize<'de> for LLMJudgeTask {
468    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
469    where
470        D: serde::Deserializer<'de>,
471    {
472        let format = LLMJudgeFormat::deserialize(deserializer)?;
473
474        match format {
475            LLMJudgeFormat::Generic(config) => config.into_task().map_err(serde::de::Error::custom),
476            LLMJudgeFormat::Full(internal) => Ok(internal.into_task()),
477        }
478    }
479}
480
481#[pymethods]
482impl LLMJudgeTask {
483    /// Creates a new LLMJudgeTask
484    /// # Examples
485    /// ```python
486    /// task = LLMJudgeTask(
487    ///     id="Sentiment Analysis Judge",
488    ///     prompt=prompt_object,
489    ///     expected_value="Positive",
490    ///     operator=ComparisonOperator.EQUALS
491    /// )
492    /// # Arguments
493    /// * `id: The id of the judge task
494    /// * `prompt`: The prompt object to be used for evaluation
495    /// * `expected_value`: The expected value for the judgement
496    /// * `context_path`: Optional context path to extract from the context for evaluation
497    /// * `operator`: The comparison operator to use
498    /// * `depends_on`: Optional list of task IDs this task depends on
499    /// * `max_retries`: Optional maximum number of retries for this task (defaults to 3 if not provided)
500    /// # Returns
501    /// A new LLMJudgeTask object
502    #[new]
503    #[pyo3(signature = (id, prompt, expected_value,  context_path,operator, description=None, depends_on=None, max_retries=None, condition=None))]
504    #[allow(clippy::too_many_arguments)]
505    pub fn new(
506        id: &str,
507        prompt: Prompt,
508        expected_value: &Bound<'_, PyAny>,
509        context_path: Option<String>,
510        operator: ComparisonOperator,
511        description: Option<String>,
512        depends_on: Option<Vec<String>>,
513        max_retries: Option<u32>,
514        condition: Option<bool>,
515    ) -> Result<Self, TypeError> {
516        let expected_value = depythonize(expected_value)?;
517
518        Ok(Self {
519            id: id.to_lowercase(),
520            prompt,
521            expected_value,
522            operator,
523            task_type: EvaluationTaskType::LLMJudge,
524            depends_on: depends_on.unwrap_or_default(),
525            max_retries: max_retries.or(Some(3)),
526            context_path,
527            result: None,
528            description,
529            condition: condition.unwrap_or(false),
530        })
531    }
532
533    pub fn __str__(&self) -> String {
534        // serialize the struct to a string
535        PyHelperFuncs::__str__(self)
536    }
537
538    #[getter]
539    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
540        let py_value = pythonize(py, &self.expected_value)?;
541        Ok(py_value)
542    }
543
544    #[staticmethod]
545    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
546        deserialize_from_path(path)
547    }
548}
549
550impl LLMJudgeTask {
551    /// Creates a new LLMJudgeTask with Rust types
552    /// # Arguments
553    /// * `id: The id of the judge task
554    /// * `prompt`: The prompt object to be used for evaluation
555    /// * `expected_value`: The expected value for the judgement
556    /// * `context_path`: Optional context path to extract from the context for evaluation
557    /// * `operator`: The comparison operator to use
558    /// * `depends_on`: Optional list of task IDs this task depends on
559    /// * `max_retries`: Optional maximum number of retries for this task (defaults to 3 if not provided)
560    /// # Returns
561    /// A new LLMJudgeTask object
562    #[allow(clippy::too_many_arguments)]
563    pub fn new_rs(
564        id: &str,
565        prompt: Prompt,
566        expected_value: Value,
567        context_path: Option<String>,
568        operator: ComparisonOperator,
569        depends_on: Option<Vec<String>>,
570        max_retries: Option<u32>,
571        description: Option<String>,
572        condition: Option<bool>,
573    ) -> Self {
574        Self {
575            id: id.to_lowercase(),
576            prompt,
577            expected_value,
578            operator,
579            task_type: EvaluationTaskType::LLMJudge,
580            depends_on: depends_on.unwrap_or_default(),
581            max_retries: max_retries.or(Some(3)),
582            context_path,
583            result: None,
584            description,
585            condition: condition.unwrap_or(false),
586        }
587    }
588}
589
590impl TaskAccessor for LLMJudgeTask {
591    fn context_path(&self) -> Option<&str> {
592        self.context_path.as_deref()
593    }
594
595    fn item_context_path(&self) -> Option<&str> {
596        None
597    }
598
599    fn id(&self) -> &str {
600        &self.id
601    }
602    fn task_type(&self) -> &EvaluationTaskType {
603        &self.task_type
604    }
605
606    fn operator(&self) -> &ComparisonOperator {
607        &self.operator
608    }
609
610    fn expected_value(&self) -> &Value {
611        &self.expected_value
612    }
613    fn depends_on(&self) -> &[String] {
614        &self.depends_on
615    }
616    fn add_result(&mut self, result: AssertionResult) {
617        self.result = Some(result);
618    }
619}
620
621#[pyclass(eq, eq_int)]
622#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
623pub enum SpanStatus {
624    Ok,
625    Error,
626    Unset,
627}
628
629#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
630pub struct PyValueWrapper(pub Value);
631
632impl<'py> IntoPyObject<'py> for PyValueWrapper {
633    type Target = PyAny;
634    type Output = Bound<'py, Self::Target>;
635    type Error = TypeError;
636
637    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
638        pythonize(py, &self.0).map_err(TypeError::from)
639    }
640}
641
642impl<'a, 'py> FromPyObject<'a, 'py> for PyValueWrapper {
643    type Error = TypeError;
644
645    fn extract(ob: pyo3::Borrowed<'a, 'py, pyo3::PyAny>) -> Result<Self, Self::Error> {
646        let value: Value = depythonize(&ob)?;
647        Ok(PyValueWrapper(value))
648    }
649}
650
651/// Filter configuration for selecting spans to assert on
652#[pyclass(eq)]
653#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
654pub enum SpanFilter {
655    /// Match spans by exact name
656    ByName { name: String },
657
658    /// Match spans by name pattern (regex)
659    ByNamePattern { pattern: String },
660
661    /// Match spans with specific attribute key
662    WithAttribute { key: String },
663
664    /// Match spans with specific attribute key-value pair
665    WithAttributeValue { key: String, value: PyValueWrapper },
666
667    /// Match spans by status code
668    WithStatus { status: SpanStatus },
669
670    /// Match spans with duration constraints
671    WithDuration {
672        min_ms: Option<f64>,
673        max_ms: Option<f64>,
674    },
675
676    /// Match a sequence of span names in order
677    Sequence { names: Vec<String> },
678
679    /// Combine multiple filters with AND logic
680    And { filters: Vec<SpanFilter> },
681
682    /// Combine multiple filters with OR logic
683    Or { filters: Vec<SpanFilter> },
684}
685
686#[pymethods]
687impl SpanFilter {
688    #[staticmethod]
689    pub fn by_name(name: String) -> Self {
690        SpanFilter::ByName { name }
691    }
692
693    #[staticmethod]
694    pub fn by_name_pattern(pattern: String) -> Self {
695        SpanFilter::ByNamePattern { pattern }
696    }
697
698    #[staticmethod]
699    pub fn with_attribute(key: String) -> Self {
700        SpanFilter::WithAttribute { key }
701    }
702
703    #[staticmethod]
704    pub fn with_attribute_value(key: String, value: &Bound<'_, PyAny>) -> Result<Self, TypeError> {
705        let value = PyValueWrapper(depythonize(value)?);
706        Ok(SpanFilter::WithAttributeValue { key, value })
707    }
708
709    #[staticmethod]
710    pub fn with_status(status: SpanStatus) -> Self {
711        SpanFilter::WithStatus { status }
712    }
713
714    #[staticmethod]
715    #[pyo3(signature = (min_ms=None, max_ms=None))]
716    pub fn with_duration(min_ms: Option<f64>, max_ms: Option<f64>) -> Self {
717        SpanFilter::WithDuration { min_ms, max_ms }
718    }
719
720    #[staticmethod]
721    pub fn sequence(names: Vec<String>) -> Self {
722        SpanFilter::Sequence { names }
723    }
724
725    pub fn and_(&self, other: SpanFilter) -> Self {
726        match self {
727            SpanFilter::And { filters } => {
728                let mut new_filters = filters.clone();
729                new_filters.push(other);
730                SpanFilter::And {
731                    filters: new_filters,
732                }
733            }
734            _ => SpanFilter::And {
735                filters: vec![self.clone(), other],
736            },
737        }
738    }
739
740    pub fn or_(&self, other: SpanFilter) -> Self {
741        match self {
742            SpanFilter::Or { filters } => {
743                let mut new_filters = filters.clone();
744                new_filters.push(other);
745                SpanFilter::Or {
746                    filters: new_filters,
747                }
748            }
749            _ => SpanFilter::Or {
750                filters: vec![self.clone(), other],
751            },
752        }
753    }
754}
755
756#[pyclass(eq, eq_int)]
757#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
758pub enum AggregationType {
759    Count,
760    Sum,
761    Average,
762    Min,
763    Max,
764    First,
765    Last,
766}
767
768/// Mode for aggregating results across multiple attribute values
769#[pyclass(eq)]
770#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
771pub enum MultiResponseMode {
772    /// At least one value must pass
773    Any,
774    /// All values must pass
775    All,
776}
777
778/// Inner task to run on each extracted attribute value
779#[pyclass]
780#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
781pub enum AttributeFilterTask {
782    /// Run a deterministic assertion on the raw attribute value
783    Assertion(AssertionTask),
784    /// Parse through AgentContextBuilder then run agent assertion
785    AgentAssertion(AgentAssertionTask),
786}
787
788#[pymethods]
789impl AttributeFilterTask {
790    #[staticmethod]
791    pub fn assertion(task: AssertionTask) -> Self {
792        AttributeFilterTask::Assertion(task)
793    }
794
795    #[staticmethod]
796    pub fn agent_assertion(task: AgentAssertionTask) -> Self {
797        AttributeFilterTask::AgentAssertion(task)
798    }
799}
800
801/// Unified assertion target that can operate on traces or filtered spans
802#[pyclass(eq)]
803#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
804#[allow(clippy::large_enum_variant)]
805pub enum TraceAssertion {
806    /// Check if spans exist in a specific order
807    SpanSequence { span_names: Vec<String> },
808
809    /// Check if all specified span names exist (order doesn't matter)
810    SpanSet { span_names: Vec<String> },
811
812    /// Count spans matching a filter
813    SpanCount { filter: SpanFilter },
814
815    /// Check if any span matching filter exists
816    SpanExists { filter: SpanFilter },
817
818    /// Get attribute value from span(s) matching filter
819    SpanAttribute {
820        filter: SpanFilter,
821        attribute_key: String,
822    },
823
824    /// Get duration of span(s) matching filter
825    SpanDuration { filter: SpanFilter },
826
827    /// Aggregate a numeric attribute across filtered spans
828    SpanAggregation {
829        filter: SpanFilter,
830        attribute_key: String,
831        aggregation: AggregationType,
832    },
833
834    /// Check total duration of entire trace
835    TraceDuration {},
836
837    /// Count total spans in trace
838    TraceSpanCount {},
839
840    /// Count spans with errors in trace
841    TraceErrorCount {},
842
843    /// Count unique services in trace
844    TraceServiceCount {},
845
846    /// Get maximum depth of span tree
847    TraceMaxDepth {},
848
849    /// Get trace-level attribute
850    TraceAttribute { attribute_key: String },
851
852    /// Filter spans by attribute key, run inner task on each value, aggregate results
853    AttributeFilter {
854        key: String,
855        task: AttributeFilterTask,
856        mode: MultiResponseMode,
857    },
858}
859
860// implement to_string for TraceAssertion
861impl Display for TraceAssertion {
862    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
863        // serde serialize to string
864        let s = serde_json::to_string(self).unwrap_or_default();
865        write!(f, "{}", s)
866    }
867}
868
869#[pymethods]
870impl TraceAssertion {
871    #[staticmethod]
872    pub fn span_sequence(span_names: Vec<String>) -> Self {
873        TraceAssertion::SpanSequence { span_names }
874    }
875
876    #[staticmethod]
877    pub fn span_set(span_names: Vec<String>) -> Self {
878        TraceAssertion::SpanSet { span_names }
879    }
880
881    #[staticmethod]
882    pub fn span_count(filter: SpanFilter) -> Self {
883        TraceAssertion::SpanCount { filter }
884    }
885
886    #[staticmethod]
887    pub fn span_exists(filter: SpanFilter) -> Self {
888        TraceAssertion::SpanExists { filter }
889    }
890
891    #[staticmethod]
892    pub fn span_attribute(filter: SpanFilter, attribute_key: String) -> Self {
893        TraceAssertion::SpanAttribute {
894            filter,
895            attribute_key,
896        }
897    }
898
899    #[staticmethod]
900    pub fn span_duration(filter: SpanFilter) -> Self {
901        TraceAssertion::SpanDuration { filter }
902    }
903
904    #[staticmethod]
905    pub fn span_aggregation(
906        filter: SpanFilter,
907        attribute_key: String,
908        aggregation: AggregationType,
909    ) -> Self {
910        TraceAssertion::SpanAggregation {
911            filter,
912            attribute_key,
913            aggregation,
914        }
915    }
916
917    #[staticmethod]
918    pub fn trace_duration() -> Self {
919        TraceAssertion::TraceDuration {}
920    }
921
922    #[staticmethod]
923    pub fn trace_span_count() -> Self {
924        TraceAssertion::TraceSpanCount {}
925    }
926
927    #[staticmethod]
928    pub fn trace_error_count() -> Self {
929        TraceAssertion::TraceErrorCount {}
930    }
931
932    #[staticmethod]
933    pub fn trace_service_count() -> Self {
934        TraceAssertion::TraceServiceCount {}
935    }
936
937    #[staticmethod]
938    pub fn trace_max_depth() -> Self {
939        TraceAssertion::TraceMaxDepth {}
940    }
941
942    #[staticmethod]
943    pub fn trace_attribute(attribute_key: String) -> Self {
944        TraceAssertion::TraceAttribute { attribute_key }
945    }
946
947    #[staticmethod]
948    pub fn attribute_filter(
949        key: String,
950        task: AttributeFilterTask,
951        mode: MultiResponseMode,
952    ) -> Self {
953        TraceAssertion::AttributeFilter { key, task, mode }
954    }
955
956    pub fn model_dump_json(&self) -> String {
957        serde_json::to_string(self).unwrap_or_default()
958    }
959}
960
961#[pyclass]
962#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
963pub struct TraceAssertionTask {
964    #[pyo3(get, set)]
965    pub id: String,
966
967    #[pyo3(get, set)]
968    pub assertion: TraceAssertion,
969
970    #[pyo3(get, set)]
971    pub operator: ComparisonOperator,
972
973    pub expected_value: Value,
974
975    #[pyo3(get, set)]
976    #[serde(default)]
977    pub description: Option<String>,
978
979    #[pyo3(get, set)]
980    #[serde(default)]
981    pub depends_on: Vec<String>,
982
983    #[serde(default = "default_trace_assertion_task_type")]
984    #[pyo3(get)]
985    pub task_type: EvaluationTaskType,
986
987    #[serde(skip_serializing_if = "Option::is_none")]
988    pub result: Option<AssertionResult>,
989
990    #[pyo3(get, set)]
991    #[serde(default)]
992    pub condition: bool,
993}
994
995#[pymethods]
996impl TraceAssertionTask {
997    /// Creates a new TraceAssertionTask
998    ///
999    /// # Examples
1000    /// ```python
1001    /// # Check execution order of spans
1002    /// task = TraceAssertionTask(
1003    ///     id="verify_agent_workflow",
1004    ///     assertion=TraceAssertion.span_sequence(["call_tool", "run_agent", "double_check"]),
1005    ///     operator=ComparisonOperator.SequenceMatches,
1006    ///     expected_value=True
1007    /// )
1008    ///
1009    /// # Check all required spans exist
1010    /// task = TraceAssertionTask(
1011    ///     id="verify_required_steps",
1012    ///     assertion=TraceAssertion.span_set(["call_tool", "run_agent", "double_check"]),
1013    ///     operator=ComparisonOperator.ContainsAll,
1014    ///     expected_value=True
1015    /// )
1016    ///
1017    /// # Check total trace duration
1018    /// task = TraceAssertionTask(
1019    ///     id="verify_performance",
1020    ///     assertion=TraceAssertion.trace_duration(),
1021    ///     operator=ComparisonOperator.LessThan,
1022    ///     expected_value=5000.0
1023    /// )
1024    ///
1025    /// # Check count of specific spans
1026    /// task = TraceAssertionTask(
1027    ///     id="verify_retry_count",
1028    ///     assertion=TraceAssertion.span_count(
1029    ///         SpanFilter.by_name("retry_operation")
1030    ///     ),
1031    ///     operator=ComparisonOperator.LessThanOrEqual,
1032    ///     expected_value=3
1033    /// )
1034    ///
1035    /// # Check span attribute
1036    /// task = TraceAssertionTask(
1037    ///     id="verify_model_used",
1038    ///     assertion=TraceAssertion.span_attribute(
1039    ///         SpanFilter.by_name("llm.generate"),
1040    ///         "model"
1041    ///     ),
1042    ///     operator=ComparisonOperator.Equals,
1043    ///     expected_value="gpt-4"
1044    /// )
1045    /// ```
1046    #[new]
1047    /// Creates a new TraceAssertionTask
1048    #[pyo3(signature = (id, assertion, expected_value, operator, description=None, depends_on=None, condition=None))]
1049    pub fn new(
1050        id: String,
1051        assertion: TraceAssertion,
1052        expected_value: &Bound<'_, PyAny>,
1053        operator: ComparisonOperator,
1054        description: Option<String>,
1055        depends_on: Option<Vec<String>>,
1056        condition: Option<bool>,
1057    ) -> Result<Self, TypeError> {
1058        let expected_value = depythonize(expected_value)?;
1059
1060        Ok(Self {
1061            id: id.to_lowercase(),
1062            assertion,
1063            operator,
1064            expected_value,
1065            description,
1066            task_type: EvaluationTaskType::TraceAssertion,
1067            depends_on: depends_on.unwrap_or_default(),
1068            result: None,
1069            condition: condition.unwrap_or(false),
1070        })
1071    }
1072
1073    pub fn __str__(&self) -> String {
1074        // serialize the struct to a string
1075        PyHelperFuncs::__str__(self)
1076    }
1077
1078    #[getter]
1079    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1080        let py_value = pythonize(py, &self.expected_value)?;
1081        Ok(py_value)
1082    }
1083
1084    #[staticmethod]
1085    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1086        deserialize_from_path(path)
1087    }
1088}
1089
1090impl TaskAccessor for TraceAssertionTask {
1091    fn context_path(&self) -> Option<&str> {
1092        None
1093    }
1094
1095    fn item_context_path(&self) -> Option<&str> {
1096        None
1097    }
1098
1099    fn id(&self) -> &str {
1100        &self.id
1101    }
1102
1103    fn operator(&self) -> &ComparisonOperator {
1104        &self.operator
1105    }
1106
1107    fn task_type(&self) -> &EvaluationTaskType {
1108        &self.task_type
1109    }
1110
1111    fn expected_value(&self) -> &Value {
1112        &self.expected_value
1113    }
1114
1115    fn depends_on(&self) -> &[String] {
1116        &self.depends_on
1117    }
1118
1119    fn add_result(&mut self, result: AssertionResult) {
1120        self.result = Some(result);
1121    }
1122}
1123
/// A single tool invocation.
///
/// Rust-only (no `#[pyclass]`); serde serializes the fields in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Tool name.
    pub name: String,
    /// Arguments passed to the tool, as an arbitrary JSON value.
    pub arguments: Value,
    /// Optional tool result value.
    pub result: Option<Value>,
    /// Optional call identifier.
    pub call_id: Option<String>,
}
1131
/// Optional token counts, each readable and writable from Python.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TokenUsage {
    /// Input token count, if known.
    #[pyo3(get, set)]
    pub input_tokens: Option<i64>,

    /// Output token count, if known.
    #[pyo3(get, set)]
    pub output_tokens: Option<i64>,

    /// Total token count, if known (stored as given, not computed from the other two).
    #[pyo3(get, set)]
    pub total_tokens: Option<i64>,
}
1144
1145#[pymethods]
1146impl TokenUsage {
1147    #[new]
1148    #[pyo3(signature = (input_tokens=None, output_tokens=None, total_tokens=None))]
1149    pub fn new(
1150        input_tokens: Option<i64>,
1151        output_tokens: Option<i64>,
1152        total_tokens: Option<i64>,
1153    ) -> Self {
1154        Self {
1155            input_tokens,
1156            output_tokens,
1157            total_tokens,
1158        }
1159    }
1160
1161    pub fn __str__(&self) -> String {
1162        serde_json::to_string_pretty(self).unwrap_or_default()
1163    }
1164}
1165
/// An assertion evaluated against an agent response.
///
/// `Tool*` variants check or extract tool-call information; `Response*` variants
/// extract values from the response. Build them from Python with the static
/// constructors (e.g. `AgentAssertion.tool_called("search")`).
///
/// Data-less variants are written as `Name {}` rather than unit variants —
/// presumably because PyO3 enums with data-carrying variants require it; confirm
/// before changing.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AgentAssertion {
    /// Check if a specific tool was called
    ToolCalled { name: String },

    /// Check if a specific tool was NOT called
    ToolNotCalled { name: String },

    /// Check if a tool was called with specific arguments (partial match)
    ToolCalledWithArgs {
        name: String,
        arguments: PyValueWrapper,
    },

    /// Check if tools were called in exact sequence
    ToolCallSequence { names: Vec<String> },

    /// Count tool calls (optionally filtered by name)
    ToolCallCount { name: Option<String> },

    /// Extract a tool argument value
    ToolArgument { name: String, argument_key: String },

    /// Extract a tool result value
    ToolResult { name: String },

    /// Get the text content of the response
    ResponseContent {},

    /// Get the model name
    ResponseModel {},

    /// Get the finish/stop reason
    ResponseFinishReason {},

    /// Get input token count
    ResponseInputTokens {},

    /// Get output token count
    ResponseOutputTokens {},

    /// Get total token count
    ResponseTotalTokens {},

    /// Extract a field from the raw (un-normalized) response at `path`
    ResponseField { path: String },
}
1214
1215impl Display for AgentAssertion {
1216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1217        let s = serde_json::to_string(self).unwrap_or_default();
1218        write!(f, "{}", s)
1219    }
1220}
1221
1222#[pymethods]
1223impl AgentAssertion {
1224    #[staticmethod]
1225    pub fn tool_called(name: &str) -> Self {
1226        AgentAssertion::ToolCalled {
1227            name: name.to_string(),
1228        }
1229    }
1230
1231    #[staticmethod]
1232    pub fn tool_not_called(name: &str) -> Self {
1233        AgentAssertion::ToolNotCalled {
1234            name: name.to_string(),
1235        }
1236    }
1237
1238    #[staticmethod]
1239    pub fn tool_called_with_args(
1240        name: &str,
1241        arguments: &Bound<'_, PyAny>,
1242    ) -> Result<Self, TypeError> {
1243        let arguments: Value = depythonize(arguments)?;
1244        Ok(AgentAssertion::ToolCalledWithArgs {
1245            name: name.to_string(),
1246            arguments: PyValueWrapper(arguments),
1247        })
1248    }
1249
1250    #[staticmethod]
1251    pub fn tool_call_sequence(names: Vec<String>) -> Self {
1252        AgentAssertion::ToolCallSequence { names }
1253    }
1254
1255    #[staticmethod]
1256    #[pyo3(signature = (name=None))]
1257    pub fn tool_call_count(name: Option<String>) -> Self {
1258        AgentAssertion::ToolCallCount { name }
1259    }
1260
1261    #[staticmethod]
1262    pub fn tool_argument(name: &str, argument_key: &str) -> Self {
1263        AgentAssertion::ToolArgument {
1264            name: name.to_string(),
1265            argument_key: argument_key.to_string(),
1266        }
1267    }
1268
1269    #[staticmethod]
1270    pub fn tool_result(name: &str) -> Self {
1271        AgentAssertion::ToolResult {
1272            name: name.to_string(),
1273        }
1274    }
1275
1276    #[staticmethod]
1277    pub fn response_content() -> Self {
1278        AgentAssertion::ResponseContent {}
1279    }
1280
1281    #[staticmethod]
1282    pub fn response_model() -> Self {
1283        AgentAssertion::ResponseModel {}
1284    }
1285
1286    #[staticmethod]
1287    pub fn response_finish_reason() -> Self {
1288        AgentAssertion::ResponseFinishReason {}
1289    }
1290
1291    #[staticmethod]
1292    pub fn response_input_tokens() -> Self {
1293        AgentAssertion::ResponseInputTokens {}
1294    }
1295
1296    #[staticmethod]
1297    pub fn response_output_tokens() -> Self {
1298        AgentAssertion::ResponseOutputTokens {}
1299    }
1300
1301    #[staticmethod]
1302    pub fn response_total_tokens() -> Self {
1303        AgentAssertion::ResponseTotalTokens {}
1304    }
1305
1306    #[staticmethod]
1307    pub fn response_field(path: &str) -> Self {
1308        AgentAssertion::ResponseField {
1309            path: path.to_string(),
1310        }
1311    }
1312
1313    pub fn __str__(&self) -> String {
1314        serde_json::to_string_pretty(self).unwrap_or_default()
1315    }
1316}
1317
/// A task that applies an [`AgentAssertion`] and compares the extracted value
/// with `expected_value` using `operator`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AgentAssertionTask {
    /// Task identifier (lower-cased by the Python constructor).
    #[pyo3(get, set)]
    pub id: String,

    /// What to check or extract from the agent response.
    #[pyo3(get, set)]
    pub assertion: AgentAssertion,

    /// How the extracted value is compared with `expected_value`.
    #[pyo3(get, set)]
    pub operator: ComparisonOperator,

    /// Expected value as JSON; exposed to Python through the `expected_value` getter.
    pub expected_value: Value,

    /// Optional human-readable description.
    #[pyo3(get, set)]
    #[serde(default)]
    pub description: Option<String>,

    /// Ids of tasks this one depends on (empty by default).
    #[pyo3(get, set)]
    #[serde(default)]
    pub depends_on: Vec<String>,

    /// Always `AgentAssertion`; filled in by default when absent from input files.
    #[serde(default = "default_agent_assertion_task_type")]
    #[pyo3(get)]
    pub task_type: EvaluationTaskType,

    /// Outcome set via `TaskAccessor::add_result`; omitted from output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<AssertionResult>,

    /// Defaults to `false`.
    // NOTE(review): presumably marks this task as a gating condition for dependents —
    // confirm against the evaluator.
    #[pyo3(get, set)]
    #[serde(default)]
    pub condition: bool,

    /// Optional provider.
    // NOTE(review): presumably used to interpret the provider-specific response
    // format — confirm.
    #[pyo3(get, set)]
    #[serde(default)]
    pub provider: Option<Provider>,
}
1355
1356#[pymethods]
1357impl AgentAssertionTask {
1358    #[new]
1359    #[pyo3(signature = (id, assertion, expected_value, operator, description=None, depends_on=None, condition=None, provider=None))]
1360    #[allow(clippy::too_many_arguments)]
1361    pub fn new(
1362        id: String,
1363        assertion: AgentAssertion,
1364        expected_value: &Bound<'_, PyAny>,
1365        operator: ComparisonOperator,
1366        description: Option<String>,
1367        depends_on: Option<Vec<String>>,
1368        condition: Option<bool>,
1369        provider: Option<Provider>,
1370    ) -> Result<Self, TypeError> {
1371        let expected_value = depythonize(expected_value)?;
1372
1373        Ok(Self {
1374            id: id.to_lowercase(),
1375            assertion,
1376            operator,
1377            expected_value,
1378            description,
1379            task_type: EvaluationTaskType::AgentAssertion,
1380            depends_on: depends_on.unwrap_or_default(),
1381            result: None,
1382            condition: condition.unwrap_or(false),
1383            provider,
1384        })
1385    }
1386
1387    pub fn __str__(&self) -> String {
1388        PyHelperFuncs::__str__(self)
1389    }
1390
1391    pub fn model_dump_json(&self) -> String {
1392        serde_json::to_string(self).unwrap_or_default()
1393    }
1394
1395    #[getter]
1396    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1397        let py_value = pythonize(py, &self.expected_value)?;
1398        Ok(py_value)
1399    }
1400
1401    #[getter]
1402    pub fn get_result<'py>(&self, py: Python<'py>) -> Result<Option<Bound<'py, PyAny>>, TypeError> {
1403        match &self.result {
1404            Some(result) => {
1405                let py_value = pythonize(py, result)?;
1406                Ok(Some(py_value))
1407            }
1408            None => Ok(None),
1409        }
1410    }
1411
1412    #[staticmethod]
1413    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1414        deserialize_from_path(path)
1415    }
1416}
1417
1418impl TaskAccessor for AgentAssertionTask {
1419    fn context_path(&self) -> Option<&str> {
1420        None
1421    }
1422
1423    fn item_context_path(&self) -> Option<&str> {
1424        None
1425    }
1426
1427    fn id(&self) -> &str {
1428        &self.id
1429    }
1430
1431    fn operator(&self) -> &ComparisonOperator {
1432        &self.operator
1433    }
1434
1435    fn task_type(&self) -> &EvaluationTaskType {
1436        &self.task_type
1437    }
1438
1439    fn expected_value(&self) -> &Value {
1440        &self.expected_value
1441    }
1442
1443    fn depends_on(&self) -> &[String] {
1444        &self.depends_on
1445    }
1446
1447    fn add_result(&mut self, result: AssertionResult) {
1448        self.result = Some(result);
1449    }
1450}
1451
/// Type-erased evaluation task: exactly one of the four concrete task kinds.
///
/// Implements `TaskAccessor` by delegating to the wrapped task and can be built
/// from any concrete task via `From`. Payloads are boxed — presumably to keep
/// the enum itself small; confirm before unboxing.
#[derive(Debug, Clone)]
pub enum EvaluationTask {
    Assertion(Box<AssertionTask>),
    LLMJudge(Box<LLMJudgeTask>),
    TraceAssertion(Box<TraceAssertionTask>),
    AgentAssertion(Box<AgentAssertionTask>),
}
1459
1460impl TaskAccessor for EvaluationTask {
1461    fn context_path(&self) -> Option<&str> {
1462        match self {
1463            EvaluationTask::Assertion(t) => t.context_path(),
1464            EvaluationTask::LLMJudge(t) => t.context_path(),
1465            EvaluationTask::TraceAssertion(t) => t.context_path(),
1466            EvaluationTask::AgentAssertion(t) => t.context_path(),
1467        }
1468    }
1469
1470    fn item_context_path(&self) -> Option<&str> {
1471        match self {
1472            EvaluationTask::Assertion(t) => t.item_context_path(),
1473            EvaluationTask::LLMJudge(t) => t.item_context_path(),
1474            EvaluationTask::TraceAssertion(t) => t.item_context_path(),
1475            EvaluationTask::AgentAssertion(t) => t.item_context_path(),
1476        }
1477    }
1478
1479    fn id(&self) -> &str {
1480        match self {
1481            EvaluationTask::Assertion(t) => t.id(),
1482            EvaluationTask::LLMJudge(t) => t.id(),
1483            EvaluationTask::TraceAssertion(t) => t.id(),
1484            EvaluationTask::AgentAssertion(t) => t.id(),
1485        }
1486    }
1487
1488    fn task_type(&self) -> &EvaluationTaskType {
1489        match self {
1490            EvaluationTask::Assertion(t) => t.task_type(),
1491            EvaluationTask::LLMJudge(t) => t.task_type(),
1492            EvaluationTask::TraceAssertion(t) => t.task_type(),
1493            EvaluationTask::AgentAssertion(t) => t.task_type(),
1494        }
1495    }
1496
1497    fn operator(&self) -> &ComparisonOperator {
1498        match self {
1499            EvaluationTask::Assertion(t) => t.operator(),
1500            EvaluationTask::LLMJudge(t) => t.operator(),
1501            EvaluationTask::TraceAssertion(t) => t.operator(),
1502            EvaluationTask::AgentAssertion(t) => t.operator(),
1503        }
1504    }
1505
1506    fn expected_value(&self) -> &Value {
1507        match self {
1508            EvaluationTask::Assertion(t) => t.expected_value(),
1509            EvaluationTask::LLMJudge(t) => t.expected_value(),
1510            EvaluationTask::TraceAssertion(t) => t.expected_value(),
1511            EvaluationTask::AgentAssertion(t) => t.expected_value(),
1512        }
1513    }
1514
1515    fn depends_on(&self) -> &[String] {
1516        match self {
1517            EvaluationTask::Assertion(t) => t.depends_on(),
1518            EvaluationTask::LLMJudge(t) => t.depends_on(),
1519            EvaluationTask::TraceAssertion(t) => t.depends_on(),
1520            EvaluationTask::AgentAssertion(t) => t.depends_on(),
1521        }
1522    }
1523
1524    fn add_result(&mut self, result: AssertionResult) {
1525        match self {
1526            EvaluationTask::Assertion(t) => t.add_result(result),
1527            EvaluationTask::LLMJudge(t) => t.add_result(result),
1528            EvaluationTask::TraceAssertion(t) => t.add_result(result),
1529            EvaluationTask::AgentAssertion(t) => t.add_result(result),
1530        }
1531    }
1532}
1533
1534pub struct EvaluationTasks(Vec<EvaluationTask>);
1535
1536impl EvaluationTasks {
1537    /// Creates a new empty builder
1538    pub fn new() -> Self {
1539        Self(Vec::new())
1540    }
1541
1542    /// Generic method that accepts anything implementing Into<EvaluationTask>
1543    pub fn add_task(mut self, task: impl Into<EvaluationTask>) -> Self {
1544        self.0.push(task.into());
1545        self
1546    }
1547
1548    /// Builds and returns the Vec<EvaluationTask>
1549    pub fn build(self) -> Vec<EvaluationTask> {
1550        self.0
1551    }
1552}
1553
1554impl From<AssertionTask> for EvaluationTask {
1555    fn from(task: AssertionTask) -> Self {
1556        EvaluationTask::Assertion(Box::new(task))
1557    }
1558}
1559
1560impl From<LLMJudgeTask> for EvaluationTask {
1561    fn from(task: LLMJudgeTask) -> Self {
1562        EvaluationTask::LLMJudge(Box::new(task))
1563    }
1564}
1565
1566impl From<TraceAssertionTask> for EvaluationTask {
1567    fn from(task: TraceAssertionTask) -> Self {
1568        EvaluationTask::TraceAssertion(Box::new(task))
1569    }
1570}
1571
1572impl From<AgentAssertionTask> for EvaluationTask {
1573    fn from(task: AgentAssertionTask) -> Self {
1574        EvaluationTask::AgentAssertion(Box::new(task))
1575    }
1576}
1577
1578impl Default for EvaluationTasks {
1579    fn default() -> Self {
1580        Self::new()
1581    }
1582}
1583
/// Operator used to compare a task's extracted value against its `expected_value`.
///
/// Exposed to Python as a simple enum with `eq, eq_int`, so declaration order
/// fixes each variant's integer value: append new variants instead of
/// reordering. The string forms come from `as_str` / `FromStr` and must be kept
/// in sync with this list. The evaluation semantics of each operator are
/// implemented elsewhere.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ComparisonOperator {
    // Equality, ordering, containment, prefix/suffix, matching, and length operators
    Equals,
    NotEqual,
    GreaterThan,
    GreaterThanOrEqual,
    LessThan,
    LessThanOrEqual,
    Contains,
    NotContains,
    StartsWith,
    EndsWith,
    Matches,
    // Length operators: `AssertionValue::to_actual` turns lists/strings into their length first
    HasLengthGreaterThan,
    HasLengthLessThan,
    HasLengthEqual,
    HasLengthGreaterThanOrEqual,
    HasLengthLessThanOrEqual,

    // Type Validation Operators
    IsNumeric,
    IsString,
    IsBoolean,
    IsNull,
    IsArray,
    IsObject,

    // Pattern & Format Validators
    IsEmail,
    IsUrl,
    IsUuid,
    IsIso8601,
    IsJson,
    MatchesRegex,

    // Numeric Range Operators
    InRange,
    NotInRange,
    IsPositive,
    IsNegative,
    IsZero,

    // Collection/Array Operators
    SequenceMatches,
    ContainsAll,
    ContainsAny,
    ContainsNone,
    IsEmpty,
    IsNotEmpty,
    HasUniqueItems,

    // String Operators
    IsAlphabetic,
    IsAlphanumeric,
    IsLowerCase,
    IsUpperCase,
    ContainsWord,

    // Comparison with Tolerance
    ApproximatelyEquals,
}
1647
1648impl Display for ComparisonOperator {
1649    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1650        write!(f, "{}", self.as_str())
1651    }
1652}
1653
1654impl FromStr for ComparisonOperator {
1655    type Err = TypeError;
1656
1657    fn from_str(s: &str) -> Result<Self, Self::Err> {
1658        match s {
1659            "Equals" => Ok(ComparisonOperator::Equals),
1660            "NotEqual" => Ok(ComparisonOperator::NotEqual),
1661            "GreaterThan" => Ok(ComparisonOperator::GreaterThan),
1662            "GreaterThanOrEqual" => Ok(ComparisonOperator::GreaterThanOrEqual),
1663            "LessThan" => Ok(ComparisonOperator::LessThan),
1664            "LessThanOrEqual" => Ok(ComparisonOperator::LessThanOrEqual),
1665            "Contains" => Ok(ComparisonOperator::Contains),
1666            "NotContains" => Ok(ComparisonOperator::NotContains),
1667            "StartsWith" => Ok(ComparisonOperator::StartsWith),
1668            "EndsWith" => Ok(ComparisonOperator::EndsWith),
1669            "Matches" => Ok(ComparisonOperator::Matches),
1670            "HasLengthEqual" => Ok(ComparisonOperator::HasLengthEqual),
1671            "HasLengthGreaterThan" => Ok(ComparisonOperator::HasLengthGreaterThan),
1672            "HasLengthLessThan" => Ok(ComparisonOperator::HasLengthLessThan),
1673            "HasLengthGreaterThanOrEqual" => Ok(ComparisonOperator::HasLengthGreaterThanOrEqual),
1674            "HasLengthLessThanOrEqual" => Ok(ComparisonOperator::HasLengthLessThanOrEqual),
1675
1676            // Type Validation
1677            "IsNumeric" => Ok(ComparisonOperator::IsNumeric),
1678            "IsString" => Ok(ComparisonOperator::IsString),
1679            "IsBoolean" => Ok(ComparisonOperator::IsBoolean),
1680            "IsNull" => Ok(ComparisonOperator::IsNull),
1681            "IsArray" => Ok(ComparisonOperator::IsArray),
1682            "IsObject" => Ok(ComparisonOperator::IsObject),
1683
1684            // Pattern & Format
1685            "IsEmail" => Ok(ComparisonOperator::IsEmail),
1686            "IsUrl" => Ok(ComparisonOperator::IsUrl),
1687            "IsUuid" => Ok(ComparisonOperator::IsUuid),
1688            "IsIso8601" => Ok(ComparisonOperator::IsIso8601),
1689            "IsJson" => Ok(ComparisonOperator::IsJson),
1690            "MatchesRegex" => Ok(ComparisonOperator::MatchesRegex),
1691
1692            // Numeric Range
1693            "InRange" => Ok(ComparisonOperator::InRange),
1694            "NotInRange" => Ok(ComparisonOperator::NotInRange),
1695            "IsPositive" => Ok(ComparisonOperator::IsPositive),
1696            "IsNegative" => Ok(ComparisonOperator::IsNegative),
1697            "IsZero" => Ok(ComparisonOperator::IsZero),
1698
1699            // Collection/Array
1700            "ContainsAll" => Ok(ComparisonOperator::ContainsAll),
1701            "ContainsAny" => Ok(ComparisonOperator::ContainsAny),
1702            "ContainsNone" => Ok(ComparisonOperator::ContainsNone),
1703            "IsEmpty" => Ok(ComparisonOperator::IsEmpty),
1704            "IsNotEmpty" => Ok(ComparisonOperator::IsNotEmpty),
1705            "HasUniqueItems" => Ok(ComparisonOperator::HasUniqueItems),
1706            "SequenceMatches" => Ok(ComparisonOperator::SequenceMatches),
1707
1708            // String
1709            "IsAlphabetic" => Ok(ComparisonOperator::IsAlphabetic),
1710            "IsAlphanumeric" => Ok(ComparisonOperator::IsAlphanumeric),
1711            "IsLowerCase" => Ok(ComparisonOperator::IsLowerCase),
1712            "IsUpperCase" => Ok(ComparisonOperator::IsUpperCase),
1713            "ContainsWord" => Ok(ComparisonOperator::ContainsWord),
1714
1715            // Tolerance
1716            "ApproximatelyEquals" => Ok(ComparisonOperator::ApproximatelyEquals),
1717
1718            _ => Err(TypeError::InvalidCompressionTypeError),
1719        }
1720    }
1721}
1722
1723impl ComparisonOperator {
1724    pub fn as_str(&self) -> &str {
1725        match self {
1726            ComparisonOperator::Equals => "Equals",
1727            ComparisonOperator::NotEqual => "NotEqual",
1728            ComparisonOperator::GreaterThan => "GreaterThan",
1729            ComparisonOperator::GreaterThanOrEqual => "GreaterThanOrEqual",
1730            ComparisonOperator::LessThan => "LessThan",
1731            ComparisonOperator::LessThanOrEqual => "LessThanOrEqual",
1732            ComparisonOperator::Contains => "Contains",
1733            ComparisonOperator::NotContains => "NotContains",
1734            ComparisonOperator::StartsWith => "StartsWith",
1735            ComparisonOperator::EndsWith => "EndsWith",
1736            ComparisonOperator::Matches => "Matches",
1737            ComparisonOperator::HasLengthEqual => "HasLengthEqual",
1738            ComparisonOperator::HasLengthGreaterThan => "HasLengthGreaterThan",
1739            ComparisonOperator::HasLengthLessThan => "HasLengthLessThan",
1740            ComparisonOperator::HasLengthGreaterThanOrEqual => "HasLengthGreaterThanOrEqual",
1741            ComparisonOperator::HasLengthLessThanOrEqual => "HasLengthLessThanOrEqual",
1742
1743            // Type Validation
1744            ComparisonOperator::IsNumeric => "IsNumeric",
1745            ComparisonOperator::IsString => "IsString",
1746            ComparisonOperator::IsBoolean => "IsBoolean",
1747            ComparisonOperator::IsNull => "IsNull",
1748            ComparisonOperator::IsArray => "IsArray",
1749            ComparisonOperator::IsObject => "IsObject",
1750
1751            // Pattern & Format
1752            ComparisonOperator::IsEmail => "IsEmail",
1753            ComparisonOperator::IsUrl => "IsUrl",
1754            ComparisonOperator::IsUuid => "IsUuid",
1755            ComparisonOperator::IsIso8601 => "IsIso8601",
1756            ComparisonOperator::IsJson => "IsJson",
1757            ComparisonOperator::MatchesRegex => "MatchesRegex",
1758
1759            // Numeric Range
1760            ComparisonOperator::InRange => "InRange",
1761            ComparisonOperator::NotInRange => "NotInRange",
1762            ComparisonOperator::IsPositive => "IsPositive",
1763            ComparisonOperator::IsNegative => "IsNegative",
1764            ComparisonOperator::IsZero => "IsZero",
1765
1766            // Collection/Array
1767            ComparisonOperator::ContainsAll => "ContainsAll",
1768            ComparisonOperator::ContainsAny => "ContainsAny",
1769            ComparisonOperator::ContainsNone => "ContainsNone",
1770            ComparisonOperator::IsEmpty => "IsEmpty",
1771            ComparisonOperator::IsNotEmpty => "IsNotEmpty",
1772            ComparisonOperator::HasUniqueItems => "HasUniqueItems",
1773            ComparisonOperator::SequenceMatches => "SequenceMatches",
1774
1775            // String
1776            ComparisonOperator::IsAlphabetic => "IsAlphabetic",
1777            ComparisonOperator::IsAlphanumeric => "IsAlphanumeric",
1778            ComparisonOperator::IsLowerCase => "IsLowerCase",
1779            ComparisonOperator::IsUpperCase => "IsUpperCase",
1780            ComparisonOperator::ContainsWord => "ContainsWord",
1781
1782            // Tolerance
1783            ComparisonOperator::ApproximatelyEquals => "ApproximatelyEquals",
1784        }
1785    }
1786}
1787
/// A JSON-like value (scalar or list) used when evaluating assertions.
///
/// Python objects are converted into it by `assertion_value_from_py`, where
/// `Null()` stands for Python `None`. `Null` is an empty tuple variant rather
/// than a unit variant — presumably because PyO3 enums with data-carrying
/// variants require it; confirm before changing.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AssertionValue {
    String(String),
    Number(f64),
    Integer(i64),
    Boolean(bool),
    List(Vec<AssertionValue>),
    Null(),
}
1798
1799impl AssertionValue {
1800    pub fn to_actual(self, comparison: &ComparisonOperator) -> AssertionValue {
1801        match comparison {
1802            ComparisonOperator::HasLengthEqual
1803            | ComparisonOperator::HasLengthGreaterThan
1804            | ComparisonOperator::HasLengthLessThan
1805            | ComparisonOperator::HasLengthGreaterThanOrEqual
1806            | ComparisonOperator::HasLengthLessThanOrEqual => match self {
1807                AssertionValue::List(arr) => AssertionValue::Integer(arr.len() as i64),
1808                AssertionValue::String(s) => AssertionValue::Integer(s.chars().count() as i64),
1809                _ => self,
1810            },
1811            _ => self,
1812        }
1813    }
1814
1815    pub fn to_serde_value(&self) -> Value {
1816        match self {
1817            AssertionValue::String(s) => Value::String(s.clone()),
1818            AssertionValue::Number(n) => Value::Number(serde_json::Number::from_f64(*n).unwrap()),
1819            AssertionValue::Integer(i) => Value::Number(serde_json::Number::from(*i)),
1820            AssertionValue::Boolean(b) => Value::Bool(*b),
1821            AssertionValue::List(arr) => {
1822                let json_arr: Vec<Value> = arr.iter().map(|v| v.to_serde_value()).collect();
1823                Value::Array(json_arr)
1824            }
1825            AssertionValue::Null() => Value::Null,
1826        }
1827    }
1828}
1829/// Converts a PyAny value to an AssertionValue
1830///
1831/// # Errors
1832///
1833/// Returns `EvaluationError::UnsupportedType` if the Python type cannot be converted
1834/// to an `AssertionValue`.
1835pub fn assertion_value_from_py(value: &Bound<'_, PyAny>) -> Result<AssertionValue, TypeError> {
1836    // Check None first as it's a common case
1837    if value.is_none() {
1838        return Ok(AssertionValue::Null());
1839    }
1840
1841    // Check bool before int (bool is subclass of int in Python)
1842    if value.is_instance_of::<PyBool>() {
1843        return Ok(AssertionValue::Boolean(value.extract()?));
1844    }
1845
1846    if value.is_instance_of::<PyString>() {
1847        return Ok(AssertionValue::String(value.extract()?));
1848    }
1849
1850    if value.is_instance_of::<PyInt>() {
1851        return Ok(AssertionValue::Integer(value.extract()?));
1852    }
1853
1854    if value.is_instance_of::<PyFloat>() {
1855        return Ok(AssertionValue::Number(value.extract()?));
1856    }
1857
1858    if value.is_instance_of::<PyList>() {
1859        // For list, we need to iterate, so one downcast is fine
1860        let list = value.cast::<PyList>()?; // Safe: we just checked
1861        let assertion_list = list
1862            .iter()
1863            .map(|item| assertion_value_from_py(&item))
1864            .collect::<Result<Vec<_>, _>>()?;
1865        return Ok(AssertionValue::List(assertion_list));
1866    }
1867
1868    // Return error for unsupported types
1869    Err(TypeError::UnsupportedType(
1870        value.get_type().name()?.to_string(),
1871    ))
1872}
1873
/// Kind of an evaluation task.
///
/// Exposed to Python as a simple enum with `eq, eq_int`, so declaration order
/// fixes the integer values: append new variants instead of reordering. The
/// string forms (`as_str`, `Display`, `FromStr`) are the variant names and must
/// be kept in sync with this list.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum EvaluationTaskType {
    Assertion,
    LLMJudge,
    Conditional,
    HumanValidation,
    TraceAssertion,
    AgentAssertion,
}
1884
1885impl Display for EvaluationTaskType {
1886    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1887        let task_type_str = match self {
1888            EvaluationTaskType::Assertion => "Assertion",
1889            EvaluationTaskType::LLMJudge => "LLMJudge",
1890            EvaluationTaskType::Conditional => "Conditional",
1891            EvaluationTaskType::HumanValidation => "HumanValidation",
1892            EvaluationTaskType::TraceAssertion => "TraceAssertion",
1893            EvaluationTaskType::AgentAssertion => "AgentAssertion",
1894        };
1895        write!(f, "{}", task_type_str)
1896    }
1897}
1898
1899impl FromStr for EvaluationTaskType {
1900    type Err = TypeError;
1901
1902    fn from_str(s: &str) -> Result<Self, Self::Err> {
1903        match s {
1904            "Assertion" => Ok(EvaluationTaskType::Assertion),
1905            "LLMJudge" => Ok(EvaluationTaskType::LLMJudge),
1906            "Conditional" => Ok(EvaluationTaskType::Conditional),
1907            "HumanValidation" => Ok(EvaluationTaskType::HumanValidation),
1908            "TraceAssertion" => Ok(EvaluationTaskType::TraceAssertion),
1909            "AgentAssertion" => Ok(EvaluationTaskType::AgentAssertion),
1910            _ => Err(TypeError::InvalidEvalType(s.to_string())),
1911        }
1912    }
1913}
1914
1915impl EvaluationTaskType {
1916    pub fn as_str(&self) -> &str {
1917        match self {
1918            EvaluationTaskType::Assertion => "Assertion",
1919            EvaluationTaskType::LLMJudge => "LLMJudge",
1920            EvaluationTaskType::Conditional => "Conditional",
1921            EvaluationTaskType::HumanValidation => "HumanValidation",
1922            EvaluationTaskType::TraceAssertion => "TraceAssertion",
1923            EvaluationTaskType::AgentAssertion => "AgentAssertion",
1924        }
1925    }
1926}
1927
/// A collection of evaluation tasks loaded from a JSON/YAML file.
///
/// From Python it supports `len()`, integer indexing and slicing, and acts as
/// its own iterator. `Deserialize` is implemented by hand (not derived).
#[pyclass]
#[derive(Debug, Serialize)]
pub struct TasksFile {
    /// The loaded tasks.
    pub tasks: Vec<TaskConfig>,

    /// Iteration cursor advanced by `__next__`.
    // NOTE(review): only `Serialize` is derived here, so this `serde(default)` presumably
    // has no effect (the hand-written deserializer decides the initial value) — confirm.
    #[serde(default)]
    index: usize,
}
1936
1937#[pymethods]
1938impl TasksFile {
1939    #[staticmethod]
1940    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1941        let tasks_file: TasksFile = deserialize_from_path(path)?;
1942        Ok(tasks_file)
1943    }
1944
1945    pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
1946        slf
1947    }
1948
1949    pub fn __next__<'py>(
1950        mut slf: PyRefMut<'py, Self>,
1951    ) -> Result<Option<Bound<'py, PyAny>>, TypeError> {
1952        let py = slf.py();
1953        if slf.index < slf.tasks.len() {
1954            let task = slf.tasks[slf.index].clone().into_bound_py_any(py)?;
1955            slf.index += 1;
1956            Ok(Some(task))
1957        } else {
1958            Ok(None)
1959        }
1960    }
1961
1962    fn __getitem__<'py>(
1963        &self,
1964        py: Python<'py>,
1965        index: &Bound<'py, PyAny>,
1966    ) -> Result<Bound<'py, PyAny>, TypeError> {
1967        if let Ok(i) = index.extract::<isize>() {
1968            let len = self.tasks.len() as isize;
1969            let actual_index = if i < 0 { len + i } else { i };
1970
1971            if actual_index < 0 || actual_index >= len {
1972                return Err(TypeError::IndexOutOfBounds {
1973                    index: i,
1974                    length: self.tasks.len(),
1975                });
1976            }
1977
1978            Ok(self.tasks[actual_index as usize]
1979                .clone()
1980                .into_bound_py_any(py)?)
1981        } else if let Ok(slice) = index.cast::<PySlice>() {
1982            let indices = slice.indices(self.tasks.len() as isize)?;
1983            let result = PyList::empty(py);
1984
1985            let mut i = indices.start;
1986            while (indices.step > 0 && i < indices.stop) || (indices.step < 0 && i > indices.stop) {
1987                result.append(self.tasks[i as usize].clone().into_bound_py_any(py)?)?;
1988                i += indices.step;
1989            }
1990
1991            Ok(result.into_bound_py_any(py)?)
1992        } else {
1993            Err(TypeError::IndexOrSliceExpected)
1994        }
1995    }
1996
1997    fn __len__(&self) -> usize {
1998        self.tasks.len()
1999    }
2000
2001    fn __str__(&self) -> String {
2002        PyHelperFuncs::__str__(self)
2003    }
2004}
2005
2006#[derive(Debug, Serialize, Clone)]
2007#[allow(clippy::large_enum_variant)]
2008pub enum TaskConfig {
2009    Assertion(AssertionTask),
2010    #[serde(rename = "LLMJudge")]
2011    LLMJudge(Box<LLMJudgeTask>),
2012    TraceAssertion(TraceAssertionTask),
2013    AgentAssertion(AgentAssertionTask),
2014}
2015
2016impl TaskConfig {
2017    fn into_bound_py_any<'py>(self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
2018        match self {
2019            TaskConfig::Assertion(task) => Ok(task.into_bound_py_any(py)?),
2020            TaskConfig::LLMJudge(task) => Ok(task.into_bound_py_any(py)?),
2021            TaskConfig::TraceAssertion(task) => Ok(task.into_bound_py_any(py)?),
2022            TaskConfig::AgentAssertion(task) => Ok(task.into_bound_py_any(py)?),
2023        }
2024    }
2025}
2026
2027impl<'de> Deserialize<'de> for TasksFile {
2028    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
2029    where
2030        D: serde::Deserializer<'de>,
2031    {
2032        #[derive(Deserialize)]
2033        #[serde(untagged)]
2034        enum TasksFileRaw {
2035            Direct(Vec<TaskConfigRaw>),
2036            Wrapped { tasks: Vec<TaskConfigRaw> },
2037        }
2038
2039        #[derive(Deserialize)]
2040        struct TaskConfigRaw {
2041            task_type: EvaluationTaskType,
2042            #[serde(flatten)]
2043            data: Value,
2044        }
2045
2046        let raw = TasksFileRaw::deserialize(deserializer)?;
2047        let raw_tasks = match raw {
2048            TasksFileRaw::Direct(tasks) => tasks,
2049            TasksFileRaw::Wrapped { tasks } => tasks,
2050        };
2051
2052        let mut tasks = Vec::new();
2053
2054        for task_raw in raw_tasks {
2055            let task_config = match task_raw.task_type {
2056                EvaluationTaskType::Assertion => {
2057                    let mut task: AssertionTask =
2058                        serde_json::from_value(task_raw.data).map_err(|e| {
2059                            error!("Failed to deserialize AssertionTask: {}", e);
2060                            serde::de::Error::custom(e.to_string())
2061                        })?;
2062                    task.task_type = EvaluationTaskType::Assertion;
2063                    TaskConfig::Assertion(task)
2064                }
2065                EvaluationTaskType::LLMJudge => {
2066                    let mut task: LLMJudgeTask =
2067                        serde_json::from_value(task_raw.data).map_err(|e| {
2068                            error!("Failed to deserialize LLMJudgeTask: {}", e);
2069                            serde::de::Error::custom(e.to_string())
2070                        })?;
2071                    task.task_type = EvaluationTaskType::LLMJudge;
2072                    TaskConfig::LLMJudge(Box::new(task))
2073                }
2074                EvaluationTaskType::TraceAssertion => {
2075                    let mut task: TraceAssertionTask = serde_json::from_value(task_raw.data)
2076                        .map_err(|e| {
2077                            error!("Failed to deserialize TraceAssertionTask: {}", e);
2078                            serde::de::Error::custom(e.to_string())
2079                        })?;
2080                    task.task_type = EvaluationTaskType::TraceAssertion;
2081                    TaskConfig::TraceAssertion(task)
2082                }
2083                EvaluationTaskType::AgentAssertion => {
2084                    let mut task: AgentAssertionTask = serde_json::from_value(task_raw.data)
2085                        .map_err(|e| {
2086                            error!("Failed to deserialize AgentAssertionTask: {}", e);
2087                            serde::de::Error::custom(e.to_string())
2088                        })?;
2089                    task.task_type = EvaluationTaskType::AgentAssertion;
2090                    TaskConfig::AgentAssertion(task)
2091                }
2092                _ => {
2093                    return Err(serde::de::Error::custom(format!(
2094                        "Unknown task_type: {}",
2095                        task_raw.task_type
2096                    )))
2097                }
2098            };
2099            tasks.push(task_config);
2100        }
2101
2102        Ok(TasksFile { tasks, index: 0 })
2103    }
2104}