Skip to main content

scouter_types/genai/
eval.rs

1use crate::error::TypeError;
2use crate::genai::traits::TaskAccessor;
3use crate::PyHelperFuncs;
4use core::fmt::Debug;
5use potato_head::prompt_types::Prompt;
6use pyo3::prelude::*;
7use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PySlice, PyString};
8use pyo3::IntoPyObjectExt;
9use pythonize::{depythonize, pythonize};
10use serde::de::DeserializeOwned;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::fmt::Display;
15use std::path::PathBuf;
16use std::str::FromStr;
17use tracing::error;
18
19pub fn deserialize_from_path<T: DeserializeOwned>(path: PathBuf) -> Result<T, TypeError> {
20    let content = std::fs::read_to_string(&path)?;
21
22    let extension = path
23        .extension()
24        .and_then(|ext| ext.to_str())
25        .ok_or_else(|| TypeError::Error(format!("Invalid file path: {:?}", path)))?;
26
27    let item = match extension.to_lowercase().as_str() {
28        "json" => serde_json::from_str(&content)?,
29        "yaml" | "yml" => serde_yaml::from_str(&content)?,
30        _ => {
31            return Err(TypeError::Error(format!(
32                "Unsupported file extension '{}'. Expected .json, .yaml, or .yml",
33                extension
34            )))
35        }
36    };
37
38    Ok(item)
39}
40
41// Default functions for task types during deserialization
/// Serde default for `AssertionTask::task_type`: returns
/// `EvaluationTaskType::Assertion` when the field is missing from the input.
fn default_assertion_task_type() -> EvaluationTaskType {
    EvaluationTaskType::Assertion
}
45
/// Serde default for `TraceAssertionTask::task_type`: returns
/// `EvaluationTaskType::TraceAssertion` when the field is missing from the input.
fn default_trace_assertion_task_type() -> EvaluationTaskType {
    EvaluationTaskType::TraceAssertion
}
49
/// Returns `EvaluationTaskType::AgentAssertion`.
/// NOTE(review): presumably a serde default for an agent assertion task's
/// `task_type` field, mirroring the two functions above — confirm at the use site.
fn default_agent_assertion_task_type() -> EvaluationTaskType {
    EvaluationTaskType::AgentAssertion
}
53
/// Outcome of evaluating a single assertion.
///
/// `passed` and `message` are exposed to Python directly; `actual` and
/// `expected` are JSON values converted through getters in `#[pymethods]`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssertionResult {
    /// Whether the assertion passed.
    #[pyo3(get)]
    pub passed: bool,
    /// The observed value, stored as JSON.
    pub actual: Value,

    /// Message attached to the result.
    #[pyo3(get)]
    pub message: String,

    /// The expected value, stored as JSON.
    pub expected: Value,
}
66
67impl AssertionResult {
68    pub fn new(passed: bool, actual: Value, message: String, expected: Value) -> Self {
69        Self {
70            passed,
71            actual,
72            message,
73            expected,
74        }
75    }
76    pub fn to_metric_value(&self) -> f64 {
77        if self.passed {
78            1.0
79        } else {
80            0.0
81        }
82    }
83}
84
85#[pymethods]
86impl AssertionResult {
87    pub fn __str__(&self) -> String {
88        PyHelperFuncs::__str__(self)
89    }
90
91    #[getter]
92    pub fn get_actual<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
93        let py_value = pythonize(py, &self.actual)?;
94        Ok(py_value)
95    }
96
97    #[getter]
98    pub fn get_expected<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
99        let py_value = pythonize(py, &self.expected)?;
100        Ok(py_value)
101    }
102}
103
/// A collection of [`AssertionResult`]s keyed by string.
/// NOTE(review): keys are presumably task ids — confirm against the producer.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssertionResults {
    /// Results by key; readable from Python.
    #[pyo3(get)]
    pub results: HashMap<String, AssertionResult>,
}
110
111#[pymethods]
112impl AssertionResults {
113    pub fn __str__(&self) -> String {
114        PyHelperFuncs::__str__(self)
115    }
116
117    pub fn __getitem__(&self, key: &str) -> Result<AssertionResult, TypeError> {
118        if let Some(result) = self.results.get(key) {
119            Ok(result.clone())
120        } else {
121            Err(TypeError::KeyNotFound {
122                key: key.to_string(),
123            })
124        }
125    }
126}
127
128#[pyclass]
129#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
130pub struct AssertionTask {
131    #[pyo3(get, set)]
132    pub id: String,
133
134    #[pyo3(get, set)]
135    #[serde(default)]
136    pub context_path: Option<String>,
137
138    #[pyo3(get, set)]
139    #[serde(default)]
140    pub item_context_path: Option<String>,
141
142    #[pyo3(get, set)]
143    pub operator: ComparisonOperator,
144
145    pub expected_value: Value,
146
147    #[pyo3(get, set)]
148    #[serde(default)]
149    pub description: Option<String>,
150
151    #[pyo3(get, set)]
152    #[serde(default)]
153    pub depends_on: Vec<String>,
154
155    #[serde(default = "default_assertion_task_type")]
156    #[pyo3(get)]
157    pub task_type: EvaluationTaskType,
158
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub result: Option<AssertionResult>,
161
162    #[serde(default)]
163    pub condition: bool,
164}
165
166#[pymethods]
167impl AssertionTask {
168    #[new]
169    /// Creates a new AssertionTask
170    /// # Examples
171    /// ```python
172    ///
173    /// # assumed passed context at runtime
174    /// # context = {
175    /// #     "response": {
176    /// #         "user": {
177    /// #             "age": 25
178    /// #         }
179    /// #     }
180    /// # }
181    ///
182    /// task = AssertionTask(
183    ///     id="Check User Age",
184    ///     context_path="response.user.age",
185    ///     operator=ComparisonOperator.GREATER_THAN,
186    ///     expected_value=18,
187    ///     description="Check if user is an adult"
188    /// )
189    ///
190    /// # assumed passed context at runtime
191    /// # context = {
192    /// #     "user": {
193    /// #         "age": 25
194    /// #     }
195    /// # }
196    ///
197    /// task = AssertionTask(
198    ///     id="Check User Age",
199    ///     context_path="user.age",
200    ///     operator=ComparisonOperator.GREATER_THAN,
201    ///     expected_value=18,
202    ///     description="Check if user is an adult"
203    /// )
204    ///
205    ///  /// # assume non-map context at runtime
206    /// # context = 25
207    ///
208    /// task = AssertionTask(
209    ///     id="Check User Age",
210    ///     operator=ComparisonOperator.GREATER_THAN,
211    ///     expected_value=18,
212    ///     description="Check if user is an adult"
213    /// )
214    /// ```
215    /// # Arguments
216    /// * `context_path`: The path to the field to be asserted
217    /// * `operator`: The comparison operator to use
218    /// * `expected_value`: The expected value for the assertion
219    /// * `description`: Optional description for the assertion
220    /// # Returns
221    /// A new AssertionTask object
222    #[allow(clippy::too_many_arguments)]
223    #[pyo3(signature = (id, context_path, expected_value, operator, item_context_path=None, description=None, depends_on=None, condition=None))]
224    pub fn new(
225        id: String,
226        context_path: Option<String>,
227        expected_value: &Bound<'_, PyAny>,
228        operator: ComparisonOperator,
229        item_context_path: Option<String>,
230        description: Option<String>,
231        depends_on: Option<Vec<String>>,
232        condition: Option<bool>,
233    ) -> Result<Self, TypeError> {
234        let expected_value = depythonize(expected_value)?;
235        let condition = condition.unwrap_or(false);
236
237        Ok(Self {
238            id: id.to_lowercase(),
239            context_path,
240            item_context_path,
241            operator,
242            expected_value,
243            description,
244            task_type: EvaluationTaskType::Assertion,
245            depends_on: depends_on.unwrap_or_default(),
246            result: None,
247            condition,
248        })
249    }
250
251    #[getter]
252    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
253        let py_value = pythonize(py, &self.expected_value)?;
254        Ok(py_value)
255    }
256
257    #[staticmethod]
258    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
259        deserialize_from_path(path)
260    }
261}
262
// This inherent impl is currently empty; the Python-facing methods are in
// the `#[pymethods]` block.
impl AssertionTask {}
264
265impl TaskAccessor for AssertionTask {
266    fn context_path(&self) -> Option<&str> {
267        self.context_path.as_deref()
268    }
269
270    fn item_context_path(&self) -> Option<&str> {
271        self.item_context_path.as_deref()
272    }
273
274    fn id(&self) -> &str {
275        &self.id
276    }
277
278    fn operator(&self) -> &ComparisonOperator {
279        &self.operator
280    }
281
282    fn task_type(&self) -> &EvaluationTaskType {
283        &self.task_type
284    }
285
286    fn expected_value(&self) -> &Value {
287        &self.expected_value
288    }
289
290    fn depends_on(&self) -> &[String] {
291        &self.depends_on
292    }
293
294    fn add_result(&mut self, result: AssertionResult) {
295        self.result = Some(result);
296    }
297}
298
/// Helper conversions used when comparing JSON values.
pub trait ValueExt {
    /// Convert value to length for HasLength comparisons.
    /// Returns `None` when the value has no notion of length.
    fn to_length(&self) -> Option<i64>;

    /// Extract numeric value for comparison.
    /// Returns `None` when the value is not numeric.
    fn as_numeric(&self) -> Option<f64>;

    /// Check if value is truthy
    fn is_truthy(&self) -> bool;
}
309
310impl ValueExt for Value {
311    fn to_length(&self) -> Option<i64> {
312        match self {
313            Value::Array(arr) => Some(arr.len() as i64),
314            Value::String(s) => Some(s.chars().count() as i64),
315            Value::Object(obj) => Some(obj.len() as i64),
316            _ => None,
317        }
318    }
319
320    fn as_numeric(&self) -> Option<f64> {
321        match self {
322            Value::Number(n) => n.as_f64(),
323            _ => None,
324        }
325    }
326
327    fn is_truthy(&self) -> bool {
328        match self {
329            Value::Null => false,
330            Value::Bool(b) => *b,
331            Value::Number(n) => n.as_f64() != Some(0.0),
332            Value::String(s) => !s.is_empty(),
333            Value::Array(arr) => !arr.is_empty(),
334            Value::Object(obj) => !obj.is_empty(),
335        }
336    }
337}
338
339/// Primary class for defining an LLM as a Judge in evaluation workflows
340#[pyclass]
341#[derive(Debug, Serialize, Clone, PartialEq)]
342pub struct LLMJudgeTask {
343    #[pyo3(get, set)]
344    pub id: String,
345
346    #[pyo3(get)]
347    pub prompt: Prompt,
348
349    #[pyo3(get)]
350    #[serde(default)]
351    pub context_path: Option<String>,
352
353    pub expected_value: Value,
354
355    #[pyo3(get)]
356    pub operator: ComparisonOperator,
357
358    #[pyo3(get)]
359    pub task_type: EvaluationTaskType,
360
361    #[pyo3(get, set)]
362    #[serde(default)]
363    pub depends_on: Vec<String>,
364
365    #[pyo3(get, set)]
366    #[serde(default)]
367    pub max_retries: Option<u32>,
368
369    #[serde(skip_serializing_if = "Option::is_none")]
370    pub result: Option<AssertionResult>,
371
372    #[serde(default)]
373    pub description: Option<String>,
374
375    #[pyo3(get, set)]
376    #[serde(default)]
377    pub condition: bool,
378}
379
/// How a prompt is specified in a task configuration file.
///
/// Untagged: the input is matched against the variants in declaration
/// order, so a map with a `path` key is tried before an inline prompt.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum PromptConfig {
    /// Prompt stored in a separate file at `path`.
    Path { path: String },
    /// Prompt written inline; boxed to keep the enum small.
    Inline(Box<Prompt>),
}
386
387#[derive(Debug, Deserialize)]
388struct LLMJudgeTaskConfig {
389    pub id: String,
390    pub prompt: PromptConfig,
391    pub expected_value: Value,
392    pub operator: ComparisonOperator,
393    pub context_path: Option<String>,
394    pub description: Option<String>,
395    pub depends_on: Vec<String>,
396    pub max_retries: Option<u32>,
397    pub condition: bool,
398}
399
400impl LLMJudgeTaskConfig {
401    pub fn into_task(self) -> Result<LLMJudgeTask, TypeError> {
402        let prompt = match self.prompt {
403            PromptConfig::Path { path } => {
404                Prompt::from_path(PathBuf::from(path)).inspect_err(|e| {
405                    error!("Failed to deserialize Prompt from path: {}", e);
406                })?
407            }
408            PromptConfig::Inline(prompt) => *prompt,
409        };
410
411        Ok(LLMJudgeTask {
412            id: self.id.to_lowercase(),
413            prompt,
414            expected_value: self.expected_value,
415            operator: self.operator,
416            context_path: self.context_path,
417            description: self.description,
418            depends_on: self.depends_on,
419            max_retries: self.max_retries.or(Some(3)),
420            task_type: EvaluationTaskType::LLMJudge,
421            result: None,
422            condition: self.condition,
423        })
424    }
425}
426
427#[derive(Debug, Deserialize)]
428struct LLMJudgeTaskInternal {
429    pub id: String,
430    pub prompt: Prompt,
431    pub context_path: Option<String>,
432    pub expected_value: Value,
433    pub operator: ComparisonOperator,
434    pub task_type: EvaluationTaskType,
435    pub depends_on: Vec<String>,
436    pub max_retries: Option<u32>,
437    pub result: Option<AssertionResult>,
438    pub description: Option<String>,
439    pub condition: bool,
440}
441impl LLMJudgeTaskInternal {
442    pub fn into_task(self) -> LLMJudgeTask {
443        LLMJudgeTask {
444            id: self.id.to_lowercase(),
445            prompt: self.prompt,
446            context_path: self.context_path,
447            expected_value: self.expected_value,
448            operator: self.operator,
449            task_type: self.task_type,
450            depends_on: self.depends_on,
451            max_retries: self.max_retries.or(Some(3)),
452            result: self.result,
453            description: self.description,
454            condition: self.condition,
455        }
456    }
457}
458
/// The two input shapes accepted when deserializing an `LLMJudgeTask`.
///
/// Untagged: variants are tried in declaration order, so the full form is
/// attempted first and the configuration form is the fallback.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum LLMJudgeFormat {
    /// Full serialized task; boxed to keep the enum small.
    Full(Box<LLMJudgeTaskInternal>),
    /// User-authored configuration, where the prompt may be a file path.
    Generic(LLMJudgeTaskConfig),
}
465
466impl<'de> Deserialize<'de> for LLMJudgeTask {
467    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
468    where
469        D: serde::Deserializer<'de>,
470    {
471        let format = LLMJudgeFormat::deserialize(deserializer)?;
472
473        match format {
474            LLMJudgeFormat::Generic(config) => config.into_task().map_err(serde::de::Error::custom),
475            LLMJudgeFormat::Full(internal) => Ok(internal.into_task()),
476        }
477    }
478}
479
480#[pymethods]
481impl LLMJudgeTask {
482    /// Creates a new LLMJudgeTask
483    /// # Examples
484    /// ```python
485    /// task = LLMJudgeTask(
486    ///     id="Sentiment Analysis Judge",
487    ///     prompt=prompt_object,
488    ///     expected_value="Positive",
489    ///     operator=ComparisonOperator.EQUALS
490    /// )
491    /// # Arguments
492    /// * `id: The id of the judge task
493    /// * `prompt`: The prompt object to be used for evaluation
494    /// * `expected_value`: The expected value for the judgement
495    /// * `context_path`: Optional context path to extract from the context for evaluation
496    /// * `operator`: The comparison operator to use
497    /// * `depends_on`: Optional list of task IDs this task depends on
498    /// * `max_retries`: Optional maximum number of retries for this task (defaults to 3 if not provided)
499    /// # Returns
500    /// A new LLMJudgeTask object
501    #[new]
502    #[pyo3(signature = (id, prompt, expected_value,  context_path,operator, description=None, depends_on=None, max_retries=None, condition=None))]
503    #[allow(clippy::too_many_arguments)]
504    pub fn new(
505        id: &str,
506        prompt: Prompt,
507        expected_value: &Bound<'_, PyAny>,
508        context_path: Option<String>,
509        operator: ComparisonOperator,
510        description: Option<String>,
511        depends_on: Option<Vec<String>>,
512        max_retries: Option<u32>,
513        condition: Option<bool>,
514    ) -> Result<Self, TypeError> {
515        let expected_value = depythonize(expected_value)?;
516
517        Ok(Self {
518            id: id.to_lowercase(),
519            prompt,
520            expected_value,
521            operator,
522            task_type: EvaluationTaskType::LLMJudge,
523            depends_on: depends_on.unwrap_or_default(),
524            max_retries: max_retries.or(Some(3)),
525            context_path,
526            result: None,
527            description,
528            condition: condition.unwrap_or(false),
529        })
530    }
531
532    pub fn __str__(&self) -> String {
533        // serialize the struct to a string
534        PyHelperFuncs::__str__(self)
535    }
536
537    #[getter]
538    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
539        let py_value = pythonize(py, &self.expected_value)?;
540        Ok(py_value)
541    }
542
543    #[staticmethod]
544    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
545        deserialize_from_path(path)
546    }
547}
548
549impl LLMJudgeTask {
550    /// Creates a new LLMJudgeTask with Rust types
551    /// # Arguments
552    /// * `id: The id of the judge task
553    /// * `prompt`: The prompt object to be used for evaluation
554    /// * `expected_value`: The expected value for the judgement
555    /// * `context_path`: Optional context path to extract from the context for evaluation
556    /// * `operator`: The comparison operator to use
557    /// * `depends_on`: Optional list of task IDs this task depends on
558    /// * `max_retries`: Optional maximum number of retries for this task (defaults to 3 if not provided)
559    /// # Returns
560    /// A new LLMJudgeTask object
561    #[allow(clippy::too_many_arguments)]
562    pub fn new_rs(
563        id: &str,
564        prompt: Prompt,
565        expected_value: Value,
566        context_path: Option<String>,
567        operator: ComparisonOperator,
568        depends_on: Option<Vec<String>>,
569        max_retries: Option<u32>,
570        description: Option<String>,
571        condition: Option<bool>,
572    ) -> Self {
573        Self {
574            id: id.to_lowercase(),
575            prompt,
576            expected_value,
577            operator,
578            task_type: EvaluationTaskType::LLMJudge,
579            depends_on: depends_on.unwrap_or_default(),
580            max_retries: max_retries.or(Some(3)),
581            context_path,
582            result: None,
583            description,
584            condition: condition.unwrap_or(false),
585        }
586    }
587}
588
589impl TaskAccessor for LLMJudgeTask {
590    fn context_path(&self) -> Option<&str> {
591        self.context_path.as_deref()
592    }
593
594    fn item_context_path(&self) -> Option<&str> {
595        None
596    }
597
598    fn id(&self) -> &str {
599        &self.id
600    }
601    fn task_type(&self) -> &EvaluationTaskType {
602        &self.task_type
603    }
604
605    fn operator(&self) -> &ComparisonOperator {
606        &self.operator
607    }
608
609    fn expected_value(&self) -> &Value {
610        &self.expected_value
611    }
612    fn depends_on(&self) -> &[String] {
613        &self.depends_on
614    }
615    fn add_result(&mut self, result: AssertionResult) {
616        self.result = Some(result);
617    }
618}
619
/// Status of a span, used by `SpanFilter::WithStatus`.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SpanStatus {
    Ok,
    Error,
    Unset,
}
627
/// Newtype around a JSON [`Value`] so it can cross the Python boundary via
/// the `IntoPyObject` / `FromPyObject` impls below.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PyValueWrapper(pub Value);
630
631impl<'py> IntoPyObject<'py> for PyValueWrapper {
632    type Target = PyAny;
633    type Output = Bound<'py, Self::Target>;
634    type Error = TypeError;
635
636    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
637        pythonize(py, &self.0).map_err(TypeError::from)
638    }
639}
640
641impl<'a, 'py> FromPyObject<'a, 'py> for PyValueWrapper {
642    type Error = TypeError;
643
644    fn extract(ob: pyo3::Borrowed<'a, 'py, pyo3::PyAny>) -> Result<Self, Self::Error> {
645        let value: Value = depythonize(&ob)?;
646        Ok(PyValueWrapper(value))
647    }
648}
649
/// Filter configuration for selecting spans to assert on.
///
/// Filters can be combined with the `And` / `Or` variants, which nest
/// arbitrarily since they hold further `SpanFilter`s.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SpanFilter {
    /// Match spans by exact name
    ByName { name: String },

    /// Match spans by name pattern (regex)
    ByNamePattern { pattern: String },

    /// Match spans with specific attribute key
    WithAttribute { key: String },

    /// Match spans with specific attribute key-value pair
    WithAttributeValue { key: String, value: PyValueWrapper },

    /// Match spans by status code
    WithStatus { status: SpanStatus },

    /// Match spans with duration constraints.
    /// Both bounds are optional; per the field names they are in milliseconds.
    WithDuration {
        min_ms: Option<f64>,
        max_ms: Option<f64>,
    },

    /// Match a sequence of span names in order
    Sequence { names: Vec<String> },

    /// Combine multiple filters with AND logic
    And { filters: Vec<SpanFilter> },

    /// Combine multiple filters with OR logic
    Or { filters: Vec<SpanFilter> },
}
684
685#[pymethods]
686impl SpanFilter {
687    #[staticmethod]
688    pub fn by_name(name: String) -> Self {
689        SpanFilter::ByName { name }
690    }
691
692    #[staticmethod]
693    pub fn by_name_pattern(pattern: String) -> Self {
694        SpanFilter::ByNamePattern { pattern }
695    }
696
697    #[staticmethod]
698    pub fn with_attribute(key: String) -> Self {
699        SpanFilter::WithAttribute { key }
700    }
701
702    #[staticmethod]
703    pub fn with_attribute_value(key: String, value: &Bound<'_, PyAny>) -> Result<Self, TypeError> {
704        let value = PyValueWrapper(depythonize(value)?);
705        Ok(SpanFilter::WithAttributeValue { key, value })
706    }
707
708    #[staticmethod]
709    pub fn with_status(status: SpanStatus) -> Self {
710        SpanFilter::WithStatus { status }
711    }
712
713    #[staticmethod]
714    #[pyo3(signature = (min_ms=None, max_ms=None))]
715    pub fn with_duration(min_ms: Option<f64>, max_ms: Option<f64>) -> Self {
716        SpanFilter::WithDuration { min_ms, max_ms }
717    }
718
719    #[staticmethod]
720    pub fn sequence(names: Vec<String>) -> Self {
721        SpanFilter::Sequence { names }
722    }
723
724    pub fn and_(&self, other: SpanFilter) -> Self {
725        match self {
726            SpanFilter::And { filters } => {
727                let mut new_filters = filters.clone();
728                new_filters.push(other);
729                SpanFilter::And {
730                    filters: new_filters,
731                }
732            }
733            _ => SpanFilter::And {
734                filters: vec![self.clone(), other],
735            },
736        }
737    }
738
739    pub fn or_(&self, other: SpanFilter) -> Self {
740        match self {
741            SpanFilter::Or { filters } => {
742                let mut new_filters = filters.clone();
743                new_filters.push(other);
744                SpanFilter::Or {
745                    filters: new_filters,
746                }
747            }
748            _ => SpanFilter::Or {
749                filters: vec![self.clone(), other],
750            },
751        }
752    }
753}
754
/// Aggregation applied by `TraceAssertion::SpanAggregation` to an attribute
/// across the filtered spans.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AggregationType {
    Count,
    Sum,
    Average,
    Min,
    Max,
    First,
    Last,
}
766
/// Unified assertion target that can operate on traces or filtered spans.
///
/// NOTE(review): the trace-level variants are declared with empty braces
/// (`{}`) rather than as unit variants — presumably a PyO3 requirement for
/// enums that mix in struct variants; confirm before changing.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TraceAssertion {
    /// Check if spans exist in a specific order
    SpanSequence { span_names: Vec<String> },

    /// Check if all specified span names exist (order doesn't matter)
    SpanSet { span_names: Vec<String> },

    /// Count spans matching a filter
    SpanCount { filter: SpanFilter },

    /// Check if any span matching filter exists
    SpanExists { filter: SpanFilter },

    /// Get attribute value from span(s) matching filter
    SpanAttribute {
        filter: SpanFilter,
        attribute_key: String,
    },

    /// Get duration of span(s) matching filter
    SpanDuration { filter: SpanFilter },

    /// Aggregate a numeric attribute across filtered spans
    SpanAggregation {
        filter: SpanFilter,
        attribute_key: String,
        aggregation: AggregationType,
    },

    /// Check total duration of entire trace
    TraceDuration {},

    /// Count total spans in trace
    TraceSpanCount {},

    /// Count spans with errors in trace
    TraceErrorCount {},

    /// Count unique services in trace
    TraceServiceCount {},

    /// Get maximum depth of span tree
    TraceMaxDepth {},

    /// Get trace-level attribute
    TraceAttribute { attribute_key: String },
}
817
818// implement to_string for TraceAssertion
819impl Display for TraceAssertion {
820    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
821        // serde serialize to string
822        let s = serde_json::to_string(self).unwrap_or_default();
823        write!(f, "{}", s)
824    }
825}
826
827#[pymethods]
828impl TraceAssertion {
829    #[staticmethod]
830    pub fn span_sequence(span_names: Vec<String>) -> Self {
831        TraceAssertion::SpanSequence { span_names }
832    }
833
834    #[staticmethod]
835    pub fn span_set(span_names: Vec<String>) -> Self {
836        TraceAssertion::SpanSet { span_names }
837    }
838
839    #[staticmethod]
840    pub fn span_count(filter: SpanFilter) -> Self {
841        TraceAssertion::SpanCount { filter }
842    }
843
844    #[staticmethod]
845    pub fn span_exists(filter: SpanFilter) -> Self {
846        TraceAssertion::SpanExists { filter }
847    }
848
849    #[staticmethod]
850    pub fn span_attribute(filter: SpanFilter, attribute_key: String) -> Self {
851        TraceAssertion::SpanAttribute {
852            filter,
853            attribute_key,
854        }
855    }
856
857    #[staticmethod]
858    pub fn span_duration(filter: SpanFilter) -> Self {
859        TraceAssertion::SpanDuration { filter }
860    }
861
862    #[staticmethod]
863    pub fn span_aggregation(
864        filter: SpanFilter,
865        attribute_key: String,
866        aggregation: AggregationType,
867    ) -> Self {
868        TraceAssertion::SpanAggregation {
869            filter,
870            attribute_key,
871            aggregation,
872        }
873    }
874
875    #[staticmethod]
876    pub fn trace_duration() -> Self {
877        TraceAssertion::TraceDuration {}
878    }
879
880    #[staticmethod]
881    pub fn trace_span_count() -> Self {
882        TraceAssertion::TraceSpanCount {}
883    }
884
885    #[staticmethod]
886    pub fn trace_error_count() -> Self {
887        TraceAssertion::TraceErrorCount {}
888    }
889
890    #[staticmethod]
891    pub fn trace_service_count() -> Self {
892        TraceAssertion::TraceServiceCount {}
893    }
894
895    #[staticmethod]
896    pub fn trace_max_depth() -> Self {
897        TraceAssertion::TraceMaxDepth {}
898    }
899
900    #[staticmethod]
901    pub fn trace_attribute(attribute_key: String) -> Self {
902        TraceAssertion::TraceAttribute { attribute_key }
903    }
904
905    pub fn model_dump_json(&self) -> String {
906        serde_json::to_string(self).unwrap_or_default()
907    }
908}
909
/// A task that evaluates a [`TraceAssertion`] and compares its outcome
/// against `expected_value` using `operator`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TraceAssertionTask {
    /// Identifier of the task.
    #[pyo3(get, set)]
    pub id: String,

    /// What to measure on the trace or its spans.
    #[pyo3(get, set)]
    pub assertion: TraceAssertion,

    /// Comparison operator to apply.
    #[pyo3(get, set)]
    pub operator: ComparisonOperator,

    /// Expected value as JSON; exposed to Python through the
    /// `expected_value` getter in `#[pymethods]`.
    pub expected_value: Value,

    /// Optional human-readable description.
    #[pyo3(get, set)]
    #[serde(default)]
    pub description: Option<String>,

    /// Ids of the tasks this task depends on.
    #[pyo3(get, set)]
    #[serde(default)]
    pub depends_on: Vec<String>,

    /// Task kind; defaults to `EvaluationTaskType::TraceAssertion` when absent.
    #[serde(default = "default_trace_assertion_task_type")]
    #[pyo3(get)]
    pub task_type: EvaluationTaskType,

    /// Result attached after evaluation; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<AssertionResult>,

    /// Condition flag; defaults to `false`.
    #[pyo3(get, set)]
    #[serde(default)]
    pub condition: bool,
}
943
944#[pymethods]
945impl TraceAssertionTask {
946    /// Creates a new TraceAssertionTask
947    ///
948    /// # Examples
949    /// ```python
950    /// # Check execution order of spans
951    /// task = TraceAssertionTask(
952    ///     id="verify_agent_workflow",
953    ///     assertion=TraceAssertion.span_sequence(["call_tool", "run_agent", "double_check"]),
954    ///     operator=ComparisonOperator.SequenceMatches,
955    ///     expected_value=True
956    /// )
957    ///
958    /// # Check all required spans exist
959    /// task = TraceAssertionTask(
960    ///     id="verify_required_steps",
961    ///     assertion=TraceAssertion.span_set(["call_tool", "run_agent", "double_check"]),
962    ///     operator=ComparisonOperator.ContainsAll,
963    ///     expected_value=True
964    /// )
965    ///
966    /// # Check total trace duration
967    /// task = TraceAssertionTask(
968    ///     id="verify_performance",
969    ///     assertion=TraceAssertion.trace_duration(),
970    ///     operator=ComparisonOperator.LessThan,
971    ///     expected_value=5000.0
972    /// )
973    ///
974    /// # Check count of specific spans
975    /// task = TraceAssertionTask(
976    ///     id="verify_retry_count",
977    ///     assertion=TraceAssertion.span_count(
978    ///         SpanFilter.by_name("retry_operation")
979    ///     ),
980    ///     operator=ComparisonOperator.LessThanOrEqual,
981    ///     expected_value=3
982    /// )
983    ///
984    /// # Check span attribute
985    /// task = TraceAssertionTask(
986    ///     id="verify_model_used",
987    ///     assertion=TraceAssertion.span_attribute(
988    ///         SpanFilter.by_name("llm.generate"),
989    ///         "model"
990    ///     ),
991    ///     operator=ComparisonOperator.Equals,
992    ///     expected_value="gpt-4"
993    /// )
994    /// ```
995    #[new]
996    /// Creates a new TraceAssertionTask
997    #[pyo3(signature = (id, assertion, expected_value, operator, description=None, depends_on=None, condition=None))]
998    pub fn new(
999        id: String,
1000        assertion: TraceAssertion,
1001        expected_value: &Bound<'_, PyAny>,
1002        operator: ComparisonOperator,
1003        description: Option<String>,
1004        depends_on: Option<Vec<String>>,
1005        condition: Option<bool>,
1006    ) -> Result<Self, TypeError> {
1007        let expected_value = depythonize(expected_value)?;
1008
1009        Ok(Self {
1010            id: id.to_lowercase(),
1011            assertion,
1012            operator,
1013            expected_value,
1014            description,
1015            task_type: EvaluationTaskType::TraceAssertion,
1016            depends_on: depends_on.unwrap_or_default(),
1017            result: None,
1018            condition: condition.unwrap_or(false),
1019        })
1020    }
1021
1022    pub fn __str__(&self) -> String {
1023        // serialize the struct to a string
1024        PyHelperFuncs::__str__(self)
1025    }
1026
1027    #[getter]
1028    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1029        let py_value = pythonize(py, &self.expected_value)?;
1030        Ok(py_value)
1031    }
1032
1033    #[staticmethod]
1034    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1035        deserialize_from_path(path)
1036    }
1037}
1038
/// `TaskAccessor` implementation for `TraceAssertionTask`: plain field access, with no
/// context paths.
impl TaskAccessor for TraceAssertionTask {
    /// Always `None` for this task type.
    fn context_path(&self) -> Option<&str> {
        None
    }

    /// Always `None` for this task type.
    fn item_context_path(&self) -> Option<&str> {
        None
    }

    /// Returns the task id.
    fn id(&self) -> &str {
        &self.id
    }

    /// Returns the comparison operator of this task.
    fn operator(&self) -> &ComparisonOperator {
        &self.operator
    }

    /// Returns the stored task type.
    fn task_type(&self) -> &EvaluationTaskType {
        &self.task_type
    }

    /// Returns the expected value as JSON.
    fn expected_value(&self) -> &Value {
        &self.expected_value
    }

    /// Returns the ids listed in `depends_on`.
    fn depends_on(&self) -> &[String] {
        &self.depends_on
    }

    /// Stores `result` on the task, replacing any previously stored result.
    fn add_result(&mut self, result: AssertionResult) {
        self.result = Some(result);
    }
}
1072
/// A single tool call: the tool's name, its arguments, and optionally its result and call id.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Name of the tool.
    pub name: String,
    /// Arguments of the call, held as an arbitrary JSON value.
    pub arguments: Value,
    /// Result of the call as JSON, if one is present.
    pub result: Option<Value>,
    /// Identifier of the call, if one is present.
    pub call_id: Option<String>,
}
1080
/// Token counts for a response. Every count is optional and readable/writable from Python.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TokenUsage {
    /// Input token count, if known.
    #[pyo3(get, set)]
    pub input_tokens: Option<i64>,

    /// Output token count, if known.
    #[pyo3(get, set)]
    pub output_tokens: Option<i64>,

    /// Total token count, if known. Stored as given; it is not derived from the other fields.
    #[pyo3(get, set)]
    pub total_tokens: Option<i64>,
}
1093
#[pymethods]
impl TokenUsage {
    /// Creates a `TokenUsage`. Each count may be omitted, in which case it is `None`.
    #[new]
    #[pyo3(signature = (input_tokens=None, output_tokens=None, total_tokens=None))]
    pub fn new(
        input_tokens: Option<i64>,
        output_tokens: Option<i64>,
        total_tokens: Option<i64>,
    ) -> Self {
        Self {
            input_tokens,
            output_tokens,
            total_tokens,
        }
    }

    /// Returns the value as pretty-printed JSON, or an empty string if serialization fails.
    pub fn __str__(&self) -> String {
        serde_json::to_string_pretty(self).unwrap_or_default()
    }
}
1114
/// What to check on, or extract from, an agent response: tool-call checks, tool
/// argument/result extraction, and response-level values.
///
/// NOTE(review): the field-less variants are declared with empty braces (`{}`) rather than
/// as unit variants — presumably a requirement of exposing this data-carrying enum through
/// pyo3; confirm before changing.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AgentAssertion {
    /// Check if a specific tool was called
    ToolCalled { name: String },

    /// Check if a specific tool was NOT called
    ToolNotCalled { name: String },

    /// Check if a tool was called with specific arguments (partial match)
    ToolCalledWithArgs {
        name: String,
        arguments: PyValueWrapper,
    },

    /// Check if tools were called in exact sequence
    ToolCallSequence { names: Vec<String> },

    /// Count tool calls (optionally filtered by name)
    ToolCallCount { name: Option<String> },

    /// Extract a tool argument value
    ToolArgument { name: String, argument_key: String },

    /// Extract a tool result value
    ToolResult { name: String },

    /// Get the text content of the response
    ResponseContent {},

    /// Get the model name
    ResponseModel {},

    /// Get the finish/stop reason
    ResponseFinishReason {},

    /// Get input token count
    ResponseInputTokens {},

    /// Get output token count
    ResponseOutputTokens {},

    /// Get total token count
    ResponseTotalTokens {},

    /// Extract a field from the raw (un-normalized) response via context_path
    ResponseField { path: String },
}
1163
1164impl Display for AgentAssertion {
1165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1166        let s = serde_json::to_string(self).unwrap_or_default();
1167        write!(f, "{}", s)
1168    }
1169}
1170
1171#[pymethods]
1172impl AgentAssertion {
1173    #[staticmethod]
1174    pub fn tool_called(name: &str) -> Self {
1175        AgentAssertion::ToolCalled {
1176            name: name.to_string(),
1177        }
1178    }
1179
1180    #[staticmethod]
1181    pub fn tool_not_called(name: &str) -> Self {
1182        AgentAssertion::ToolNotCalled {
1183            name: name.to_string(),
1184        }
1185    }
1186
1187    #[staticmethod]
1188    pub fn tool_called_with_args(
1189        name: &str,
1190        arguments: &Bound<'_, PyAny>,
1191    ) -> Result<Self, TypeError> {
1192        let arguments: Value = depythonize(arguments)?;
1193        Ok(AgentAssertion::ToolCalledWithArgs {
1194            name: name.to_string(),
1195            arguments: PyValueWrapper(arguments),
1196        })
1197    }
1198
1199    #[staticmethod]
1200    pub fn tool_call_sequence(names: Vec<String>) -> Self {
1201        AgentAssertion::ToolCallSequence { names }
1202    }
1203
1204    #[staticmethod]
1205    #[pyo3(signature = (name=None))]
1206    pub fn tool_call_count(name: Option<String>) -> Self {
1207        AgentAssertion::ToolCallCount { name }
1208    }
1209
1210    #[staticmethod]
1211    pub fn tool_argument(name: &str, argument_key: &str) -> Self {
1212        AgentAssertion::ToolArgument {
1213            name: name.to_string(),
1214            argument_key: argument_key.to_string(),
1215        }
1216    }
1217
1218    #[staticmethod]
1219    pub fn tool_result(name: &str) -> Self {
1220        AgentAssertion::ToolResult {
1221            name: name.to_string(),
1222        }
1223    }
1224
1225    #[staticmethod]
1226    pub fn response_content() -> Self {
1227        AgentAssertion::ResponseContent {}
1228    }
1229
1230    #[staticmethod]
1231    pub fn response_model() -> Self {
1232        AgentAssertion::ResponseModel {}
1233    }
1234
1235    #[staticmethod]
1236    pub fn response_finish_reason() -> Self {
1237        AgentAssertion::ResponseFinishReason {}
1238    }
1239
1240    #[staticmethod]
1241    pub fn response_input_tokens() -> Self {
1242        AgentAssertion::ResponseInputTokens {}
1243    }
1244
1245    #[staticmethod]
1246    pub fn response_output_tokens() -> Self {
1247        AgentAssertion::ResponseOutputTokens {}
1248    }
1249
1250    #[staticmethod]
1251    pub fn response_total_tokens() -> Self {
1252        AgentAssertion::ResponseTotalTokens {}
1253    }
1254
1255    #[staticmethod]
1256    pub fn response_field(path: &str) -> Self {
1257        AgentAssertion::ResponseField {
1258            path: path.to_string(),
1259        }
1260    }
1261
1262    pub fn __str__(&self) -> String {
1263        serde_json::to_string_pretty(self).unwrap_or_default()
1264    }
1265}
1266
/// An evaluation task that applies an `AgentAssertion` and compares the outcome with
/// `expected_value` using `operator`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AgentAssertionTask {
    /// Task identifier. The Python constructor stores it lowercased.
    #[pyo3(get, set)]
    pub id: String,

    /// The agent assertion this task uses.
    #[pyo3(get, set)]
    pub assertion: AgentAssertion,

    /// Operator used for the comparison against `expected_value`.
    #[pyo3(get, set)]
    pub operator: ComparisonOperator,

    /// Expected value as JSON. Not exposed as a plain pyo3 field; Python reads it through
    /// the `get_expected_value` getter.
    pub expected_value: Value,

    /// Optional description; absent in input means `None`.
    #[pyo3(get, set)]
    #[serde(default)]
    pub description: Option<String>,

    /// Ids listed as dependencies of this task; absent in input means an empty list.
    #[pyo3(get, set)]
    #[serde(default)]
    pub depends_on: Vec<String>,

    /// Task type. When absent in input it is filled by `default_agent_assertion_task_type`.
    #[serde(default = "default_agent_assertion_task_type")]
    #[pyo3(get)]
    pub task_type: EvaluationTaskType,

    /// Result stored via `add_result`; left out of serialized output while it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<AssertionResult>,

    /// Condition flag; absent in input means `false`.
    #[pyo3(get, set)]
    #[serde(default)]
    pub condition: bool,
}
1300
1301#[pymethods]
1302impl AgentAssertionTask {
1303    #[new]
1304    #[pyo3(signature = (id, assertion, expected_value, operator, description=None, depends_on=None, condition=None))]
1305    pub fn new(
1306        id: String,
1307        assertion: AgentAssertion,
1308        expected_value: &Bound<'_, PyAny>,
1309        operator: ComparisonOperator,
1310        description: Option<String>,
1311        depends_on: Option<Vec<String>>,
1312        condition: Option<bool>,
1313    ) -> Result<Self, TypeError> {
1314        let expected_value = depythonize(expected_value)?;
1315
1316        Ok(Self {
1317            id: id.to_lowercase(),
1318            assertion,
1319            operator,
1320            expected_value,
1321            description,
1322            task_type: EvaluationTaskType::AgentAssertion,
1323            depends_on: depends_on.unwrap_or_default(),
1324            result: None,
1325            condition: condition.unwrap_or(false),
1326        })
1327    }
1328
1329    pub fn __str__(&self) -> String {
1330        PyHelperFuncs::__str__(self)
1331    }
1332
1333    pub fn model_dump_json(&self) -> String {
1334        serde_json::to_string(self).unwrap_or_default()
1335    }
1336
1337    #[getter]
1338    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1339        let py_value = pythonize(py, &self.expected_value)?;
1340        Ok(py_value)
1341    }
1342
1343    #[getter]
1344    pub fn get_result<'py>(&self, py: Python<'py>) -> Result<Option<Bound<'py, PyAny>>, TypeError> {
1345        match &self.result {
1346            Some(result) => {
1347                let py_value = pythonize(py, result)?;
1348                Ok(Some(py_value))
1349            }
1350            None => Ok(None),
1351        }
1352    }
1353
1354    #[staticmethod]
1355    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1356        deserialize_from_path(path)
1357    }
1358}
1359
/// `TaskAccessor` implementation for `AgentAssertionTask`: plain field access, with no
/// context paths.
impl TaskAccessor for AgentAssertionTask {
    /// Always `None` for this task type.
    fn context_path(&self) -> Option<&str> {
        None
    }

    /// Always `None` for this task type.
    fn item_context_path(&self) -> Option<&str> {
        None
    }

    /// Returns the task id.
    fn id(&self) -> &str {
        &self.id
    }

    /// Returns the comparison operator of this task.
    fn operator(&self) -> &ComparisonOperator {
        &self.operator
    }

    /// Returns the stored task type.
    fn task_type(&self) -> &EvaluationTaskType {
        &self.task_type
    }

    /// Returns the expected value as JSON.
    fn expected_value(&self) -> &Value {
        &self.expected_value
    }

    /// Returns the ids listed in `depends_on`.
    fn depends_on(&self) -> &[String] {
        &self.depends_on
    }

    /// Stores `result` on the task, replacing any previously stored result.
    fn add_result(&mut self, result: AssertionResult) {
        self.result = Some(result);
    }
}
1393
/// One evaluation task of any supported kind. Each variant holds its task boxed.
#[derive(Debug, Clone)]
pub enum EvaluationTask {
    /// An `AssertionTask`.
    Assertion(Box<AssertionTask>),
    /// An `LLMJudgeTask`.
    LLMJudge(Box<LLMJudgeTask>),
    /// A `TraceAssertionTask`.
    TraceAssertion(Box<TraceAssertionTask>),
    /// An `AgentAssertionTask`.
    AgentAssertion(Box<AgentAssertionTask>),
}
1401
1402impl TaskAccessor for EvaluationTask {
1403    fn context_path(&self) -> Option<&str> {
1404        match self {
1405            EvaluationTask::Assertion(t) => t.context_path(),
1406            EvaluationTask::LLMJudge(t) => t.context_path(),
1407            EvaluationTask::TraceAssertion(t) => t.context_path(),
1408            EvaluationTask::AgentAssertion(t) => t.context_path(),
1409        }
1410    }
1411
1412    fn item_context_path(&self) -> Option<&str> {
1413        match self {
1414            EvaluationTask::Assertion(t) => t.item_context_path(),
1415            EvaluationTask::LLMJudge(t) => t.item_context_path(),
1416            EvaluationTask::TraceAssertion(t) => t.item_context_path(),
1417            EvaluationTask::AgentAssertion(t) => t.item_context_path(),
1418        }
1419    }
1420
1421    fn id(&self) -> &str {
1422        match self {
1423            EvaluationTask::Assertion(t) => t.id(),
1424            EvaluationTask::LLMJudge(t) => t.id(),
1425            EvaluationTask::TraceAssertion(t) => t.id(),
1426            EvaluationTask::AgentAssertion(t) => t.id(),
1427        }
1428    }
1429
1430    fn task_type(&self) -> &EvaluationTaskType {
1431        match self {
1432            EvaluationTask::Assertion(t) => t.task_type(),
1433            EvaluationTask::LLMJudge(t) => t.task_type(),
1434            EvaluationTask::TraceAssertion(t) => t.task_type(),
1435            EvaluationTask::AgentAssertion(t) => t.task_type(),
1436        }
1437    }
1438
1439    fn operator(&self) -> &ComparisonOperator {
1440        match self {
1441            EvaluationTask::Assertion(t) => t.operator(),
1442            EvaluationTask::LLMJudge(t) => t.operator(),
1443            EvaluationTask::TraceAssertion(t) => t.operator(),
1444            EvaluationTask::AgentAssertion(t) => t.operator(),
1445        }
1446    }
1447
1448    fn expected_value(&self) -> &Value {
1449        match self {
1450            EvaluationTask::Assertion(t) => t.expected_value(),
1451            EvaluationTask::LLMJudge(t) => t.expected_value(),
1452            EvaluationTask::TraceAssertion(t) => t.expected_value(),
1453            EvaluationTask::AgentAssertion(t) => t.expected_value(),
1454        }
1455    }
1456
1457    fn depends_on(&self) -> &[String] {
1458        match self {
1459            EvaluationTask::Assertion(t) => t.depends_on(),
1460            EvaluationTask::LLMJudge(t) => t.depends_on(),
1461            EvaluationTask::TraceAssertion(t) => t.depends_on(),
1462            EvaluationTask::AgentAssertion(t) => t.depends_on(),
1463        }
1464    }
1465
1466    fn add_result(&mut self, result: AssertionResult) {
1467        match self {
1468            EvaluationTask::Assertion(t) => t.add_result(result),
1469            EvaluationTask::LLMJudge(t) => t.add_result(result),
1470            EvaluationTask::TraceAssertion(t) => t.add_result(result),
1471            EvaluationTask::AgentAssertion(t) => t.add_result(result),
1472        }
1473    }
1474}
1475
/// Builder that accumulates `EvaluationTask`s in insertion order; see `add_task` and `build`.
pub struct EvaluationTasks(Vec<EvaluationTask>);
1477
impl EvaluationTasks {
    /// Creates a new empty builder
    pub fn new() -> Self {
        Self(Vec::new())
    }

    /// Generic method that accepts anything implementing Into<EvaluationTask>
    ///
    /// Consumes the builder and returns it with the converted task appended, so calls
    /// can be chained.
    pub fn add_task(mut self, task: impl Into<EvaluationTask>) -> Self {
        self.0.push(task.into());
        self
    }

    /// Builds and returns the Vec<EvaluationTask>
    ///
    /// Consumes the builder; tasks come back in the order they were added.
    pub fn build(self) -> Vec<EvaluationTask> {
        self.0
    }
}
1495
1496impl From<AssertionTask> for EvaluationTask {
1497    fn from(task: AssertionTask) -> Self {
1498        EvaluationTask::Assertion(Box::new(task))
1499    }
1500}
1501
1502impl From<LLMJudgeTask> for EvaluationTask {
1503    fn from(task: LLMJudgeTask) -> Self {
1504        EvaluationTask::LLMJudge(Box::new(task))
1505    }
1506}
1507
1508impl From<TraceAssertionTask> for EvaluationTask {
1509    fn from(task: TraceAssertionTask) -> Self {
1510        EvaluationTask::TraceAssertion(Box::new(task))
1511    }
1512}
1513
1514impl From<AgentAssertionTask> for EvaluationTask {
1515    fn from(task: AgentAssertionTask) -> Self {
1516        EvaluationTask::AgentAssertion(Box::new(task))
1517    }
1518}
1519
/// The default builder is the empty builder returned by `EvaluationTasks::new`.
impl Default for EvaluationTasks {
    fn default() -> Self {
        Self::new()
    }
}
1525
/// Operators a task can use to compare an actual value with its expected value.
///
/// The string form of each variant (see `as_str` / `FromStr`) is exactly the variant name.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ComparisonOperator {
    // General comparison, string-matching, and length operators
    Equals,
    NotEqual,
    GreaterThan,
    GreaterThanOrEqual,
    LessThan,
    LessThanOrEqual,
    Contains,
    NotContains,
    StartsWith,
    EndsWith,
    Matches,
    HasLengthGreaterThan,
    HasLengthLessThan,
    HasLengthEqual,
    HasLengthGreaterThanOrEqual,
    HasLengthLessThanOrEqual,

    // Type Validation Operators
    IsNumeric,
    IsString,
    IsBoolean,
    IsNull,
    IsArray,
    IsObject,

    // Pattern & Format Validators
    IsEmail,
    IsUrl,
    IsUuid,
    IsIso8601,
    IsJson,
    MatchesRegex,

    // Numeric Range Operators
    InRange,
    NotInRange,
    IsPositive,
    IsNegative,
    IsZero,

    // Collection/Array Operators
    SequenceMatches,
    ContainsAll,
    ContainsAny,
    ContainsNone,
    IsEmpty,
    IsNotEmpty,
    HasUniqueItems,

    // String Operators
    IsAlphabetic,
    IsAlphanumeric,
    IsLowerCase,
    IsUpperCase,
    ContainsWord,

    // Comparison with Tolerance
    ApproximatelyEquals,
}
1589
1590impl Display for ComparisonOperator {
1591    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1592        write!(f, "{}", self.as_str())
1593    }
1594}
1595
1596impl FromStr for ComparisonOperator {
1597    type Err = TypeError;
1598
1599    fn from_str(s: &str) -> Result<Self, Self::Err> {
1600        match s {
1601            "Equals" => Ok(ComparisonOperator::Equals),
1602            "NotEqual" => Ok(ComparisonOperator::NotEqual),
1603            "GreaterThan" => Ok(ComparisonOperator::GreaterThan),
1604            "GreaterThanOrEqual" => Ok(ComparisonOperator::GreaterThanOrEqual),
1605            "LessThan" => Ok(ComparisonOperator::LessThan),
1606            "LessThanOrEqual" => Ok(ComparisonOperator::LessThanOrEqual),
1607            "Contains" => Ok(ComparisonOperator::Contains),
1608            "NotContains" => Ok(ComparisonOperator::NotContains),
1609            "StartsWith" => Ok(ComparisonOperator::StartsWith),
1610            "EndsWith" => Ok(ComparisonOperator::EndsWith),
1611            "Matches" => Ok(ComparisonOperator::Matches),
1612            "HasLengthEqual" => Ok(ComparisonOperator::HasLengthEqual),
1613            "HasLengthGreaterThan" => Ok(ComparisonOperator::HasLengthGreaterThan),
1614            "HasLengthLessThan" => Ok(ComparisonOperator::HasLengthLessThan),
1615            "HasLengthGreaterThanOrEqual" => Ok(ComparisonOperator::HasLengthGreaterThanOrEqual),
1616            "HasLengthLessThanOrEqual" => Ok(ComparisonOperator::HasLengthLessThanOrEqual),
1617
1618            // Type Validation
1619            "IsNumeric" => Ok(ComparisonOperator::IsNumeric),
1620            "IsString" => Ok(ComparisonOperator::IsString),
1621            "IsBoolean" => Ok(ComparisonOperator::IsBoolean),
1622            "IsNull" => Ok(ComparisonOperator::IsNull),
1623            "IsArray" => Ok(ComparisonOperator::IsArray),
1624            "IsObject" => Ok(ComparisonOperator::IsObject),
1625
1626            // Pattern & Format
1627            "IsEmail" => Ok(ComparisonOperator::IsEmail),
1628            "IsUrl" => Ok(ComparisonOperator::IsUrl),
1629            "IsUuid" => Ok(ComparisonOperator::IsUuid),
1630            "IsIso8601" => Ok(ComparisonOperator::IsIso8601),
1631            "IsJson" => Ok(ComparisonOperator::IsJson),
1632            "MatchesRegex" => Ok(ComparisonOperator::MatchesRegex),
1633
1634            // Numeric Range
1635            "InRange" => Ok(ComparisonOperator::InRange),
1636            "NotInRange" => Ok(ComparisonOperator::NotInRange),
1637            "IsPositive" => Ok(ComparisonOperator::IsPositive),
1638            "IsNegative" => Ok(ComparisonOperator::IsNegative),
1639            "IsZero" => Ok(ComparisonOperator::IsZero),
1640
1641            // Collection/Array
1642            "ContainsAll" => Ok(ComparisonOperator::ContainsAll),
1643            "ContainsAny" => Ok(ComparisonOperator::ContainsAny),
1644            "ContainsNone" => Ok(ComparisonOperator::ContainsNone),
1645            "IsEmpty" => Ok(ComparisonOperator::IsEmpty),
1646            "IsNotEmpty" => Ok(ComparisonOperator::IsNotEmpty),
1647            "HasUniqueItems" => Ok(ComparisonOperator::HasUniqueItems),
1648            "SequenceMatches" => Ok(ComparisonOperator::SequenceMatches),
1649
1650            // String
1651            "IsAlphabetic" => Ok(ComparisonOperator::IsAlphabetic),
1652            "IsAlphanumeric" => Ok(ComparisonOperator::IsAlphanumeric),
1653            "IsLowerCase" => Ok(ComparisonOperator::IsLowerCase),
1654            "IsUpperCase" => Ok(ComparisonOperator::IsUpperCase),
1655            "ContainsWord" => Ok(ComparisonOperator::ContainsWord),
1656
1657            // Tolerance
1658            "ApproximatelyEquals" => Ok(ComparisonOperator::ApproximatelyEquals),
1659
1660            _ => Err(TypeError::InvalidCompressionTypeError),
1661        }
1662    }
1663}
1664
1665impl ComparisonOperator {
1666    pub fn as_str(&self) -> &str {
1667        match self {
1668            ComparisonOperator::Equals => "Equals",
1669            ComparisonOperator::NotEqual => "NotEqual",
1670            ComparisonOperator::GreaterThan => "GreaterThan",
1671            ComparisonOperator::GreaterThanOrEqual => "GreaterThanOrEqual",
1672            ComparisonOperator::LessThan => "LessThan",
1673            ComparisonOperator::LessThanOrEqual => "LessThanOrEqual",
1674            ComparisonOperator::Contains => "Contains",
1675            ComparisonOperator::NotContains => "NotContains",
1676            ComparisonOperator::StartsWith => "StartsWith",
1677            ComparisonOperator::EndsWith => "EndsWith",
1678            ComparisonOperator::Matches => "Matches",
1679            ComparisonOperator::HasLengthEqual => "HasLengthEqual",
1680            ComparisonOperator::HasLengthGreaterThan => "HasLengthGreaterThan",
1681            ComparisonOperator::HasLengthLessThan => "HasLengthLessThan",
1682            ComparisonOperator::HasLengthGreaterThanOrEqual => "HasLengthGreaterThanOrEqual",
1683            ComparisonOperator::HasLengthLessThanOrEqual => "HasLengthLessThanOrEqual",
1684
1685            // Type Validation
1686            ComparisonOperator::IsNumeric => "IsNumeric",
1687            ComparisonOperator::IsString => "IsString",
1688            ComparisonOperator::IsBoolean => "IsBoolean",
1689            ComparisonOperator::IsNull => "IsNull",
1690            ComparisonOperator::IsArray => "IsArray",
1691            ComparisonOperator::IsObject => "IsObject",
1692
1693            // Pattern & Format
1694            ComparisonOperator::IsEmail => "IsEmail",
1695            ComparisonOperator::IsUrl => "IsUrl",
1696            ComparisonOperator::IsUuid => "IsUuid",
1697            ComparisonOperator::IsIso8601 => "IsIso8601",
1698            ComparisonOperator::IsJson => "IsJson",
1699            ComparisonOperator::MatchesRegex => "MatchesRegex",
1700
1701            // Numeric Range
1702            ComparisonOperator::InRange => "InRange",
1703            ComparisonOperator::NotInRange => "NotInRange",
1704            ComparisonOperator::IsPositive => "IsPositive",
1705            ComparisonOperator::IsNegative => "IsNegative",
1706            ComparisonOperator::IsZero => "IsZero",
1707
1708            // Collection/Array
1709            ComparisonOperator::ContainsAll => "ContainsAll",
1710            ComparisonOperator::ContainsAny => "ContainsAny",
1711            ComparisonOperator::ContainsNone => "ContainsNone",
1712            ComparisonOperator::IsEmpty => "IsEmpty",
1713            ComparisonOperator::IsNotEmpty => "IsNotEmpty",
1714            ComparisonOperator::HasUniqueItems => "HasUniqueItems",
1715            ComparisonOperator::SequenceMatches => "SequenceMatches",
1716
1717            // String
1718            ComparisonOperator::IsAlphabetic => "IsAlphabetic",
1719            ComparisonOperator::IsAlphanumeric => "IsAlphanumeric",
1720            ComparisonOperator::IsLowerCase => "IsLowerCase",
1721            ComparisonOperator::IsUpperCase => "IsUpperCase",
1722            ComparisonOperator::ContainsWord => "ContainsWord",
1723
1724            // Tolerance
1725            ComparisonOperator::ApproximatelyEquals => "ApproximatelyEquals",
1726        }
1727    }
1728}
1729
/// A value taking part in an assertion: a scalar, a list of such values, or null.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AssertionValue {
    /// A string.
    String(String),
    /// A floating-point number.
    Number(f64),
    /// A signed 64-bit integer.
    Integer(i64),
    /// A boolean.
    Boolean(bool),
    /// A list of nested values.
    List(Vec<AssertionValue>),
    /// The null value (what `assertion_value_from_py` produces for Python `None`).
    Null(),
}
1740
1741impl AssertionValue {
1742    pub fn to_actual(self, comparison: &ComparisonOperator) -> AssertionValue {
1743        match comparison {
1744            ComparisonOperator::HasLengthEqual
1745            | ComparisonOperator::HasLengthGreaterThan
1746            | ComparisonOperator::HasLengthLessThan
1747            | ComparisonOperator::HasLengthGreaterThanOrEqual
1748            | ComparisonOperator::HasLengthLessThanOrEqual => match self {
1749                AssertionValue::List(arr) => AssertionValue::Integer(arr.len() as i64),
1750                AssertionValue::String(s) => AssertionValue::Integer(s.chars().count() as i64),
1751                _ => self,
1752            },
1753            _ => self,
1754        }
1755    }
1756
1757    pub fn to_serde_value(&self) -> Value {
1758        match self {
1759            AssertionValue::String(s) => Value::String(s.clone()),
1760            AssertionValue::Number(n) => Value::Number(serde_json::Number::from_f64(*n).unwrap()),
1761            AssertionValue::Integer(i) => Value::Number(serde_json::Number::from(*i)),
1762            AssertionValue::Boolean(b) => Value::Bool(*b),
1763            AssertionValue::List(arr) => {
1764                let json_arr: Vec<Value> = arr.iter().map(|v| v.to_serde_value()).collect();
1765                Value::Array(json_arr)
1766            }
1767            AssertionValue::Null() => Value::Null,
1768        }
1769    }
1770}
1771/// Converts a PyAny value to an AssertionValue
1772///
1773/// # Errors
1774///
1775/// Returns `EvaluationError::UnsupportedType` if the Python type cannot be converted
1776/// to an `AssertionValue`.
1777pub fn assertion_value_from_py(value: &Bound<'_, PyAny>) -> Result<AssertionValue, TypeError> {
1778    // Check None first as it's a common case
1779    if value.is_none() {
1780        return Ok(AssertionValue::Null());
1781    }
1782
1783    // Check bool before int (bool is subclass of int in Python)
1784    if value.is_instance_of::<PyBool>() {
1785        return Ok(AssertionValue::Boolean(value.extract()?));
1786    }
1787
1788    if value.is_instance_of::<PyString>() {
1789        return Ok(AssertionValue::String(value.extract()?));
1790    }
1791
1792    if value.is_instance_of::<PyInt>() {
1793        return Ok(AssertionValue::Integer(value.extract()?));
1794    }
1795
1796    if value.is_instance_of::<PyFloat>() {
1797        return Ok(AssertionValue::Number(value.extract()?));
1798    }
1799
1800    if value.is_instance_of::<PyList>() {
1801        // For list, we need to iterate, so one downcast is fine
1802        let list = value.cast::<PyList>()?; // Safe: we just checked
1803        let assertion_list = list
1804            .iter()
1805            .map(|item| assertion_value_from_py(&item))
1806            .collect::<Result<Vec<_>, _>>()?;
1807        return Ok(AssertionValue::List(assertion_list));
1808    }
1809
1810    // Return error for unsupported types
1811    Err(TypeError::UnsupportedType(
1812        value.get_type().name()?.to_string(),
1813    ))
1814}
1815
/// The kind of an evaluation task.
///
/// The string form of each variant (see `as_str`, `Display`, `FromStr`) is exactly the
/// variant name.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum EvaluationTaskType {
    Assertion,
    LLMJudge,
    Conditional,
    HumanValidation,
    TraceAssertion,
    AgentAssertion,
}
1826
1827impl Display for EvaluationTaskType {
1828    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1829        let task_type_str = match self {
1830            EvaluationTaskType::Assertion => "Assertion",
1831            EvaluationTaskType::LLMJudge => "LLMJudge",
1832            EvaluationTaskType::Conditional => "Conditional",
1833            EvaluationTaskType::HumanValidation => "HumanValidation",
1834            EvaluationTaskType::TraceAssertion => "TraceAssertion",
1835            EvaluationTaskType::AgentAssertion => "AgentAssertion",
1836        };
1837        write!(f, "{}", task_type_str)
1838    }
1839}
1840
1841impl FromStr for EvaluationTaskType {
1842    type Err = TypeError;
1843
1844    fn from_str(s: &str) -> Result<Self, Self::Err> {
1845        match s {
1846            "Assertion" => Ok(EvaluationTaskType::Assertion),
1847            "LLMJudge" => Ok(EvaluationTaskType::LLMJudge),
1848            "Conditional" => Ok(EvaluationTaskType::Conditional),
1849            "HumanValidation" => Ok(EvaluationTaskType::HumanValidation),
1850            "TraceAssertion" => Ok(EvaluationTaskType::TraceAssertion),
1851            "AgentAssertion" => Ok(EvaluationTaskType::AgentAssertion),
1852            _ => Err(TypeError::InvalidEvalType(s.to_string())),
1853        }
1854    }
1855}
1856
1857impl EvaluationTaskType {
1858    pub fn as_str(&self) -> &str {
1859        match self {
1860            EvaluationTaskType::Assertion => "Assertion",
1861            EvaluationTaskType::LLMJudge => "LLMJudge",
1862            EvaluationTaskType::Conditional => "Conditional",
1863            EvaluationTaskType::HumanValidation => "HumanValidation",
1864            EvaluationTaskType::TraceAssertion => "TraceAssertion",
1865            EvaluationTaskType::AgentAssertion => "AgentAssertion",
1866        }
1867    }
1868}
1869
/// A collection of task configurations loaded from a file. From Python it can be iterated,
/// indexed, sliced, and measured with `len()`.
#[pyclass]
#[derive(Debug, Serialize)]
pub struct TasksFile {
    /// The task configurations held by this file.
    pub tasks: Vec<TaskConfig>,

    // Iteration cursor: `__next__` returns `tasks[index]` and then advances it.
    // NOTE(review): only `Serialize` is derived here, so this cursor is part of the
    // serialized output and `#[serde(default)]` looks unused — confirm whether
    // `#[serde(skip)]` was intended.
    #[serde(default)]
    index: usize,
}
1878
1879#[pymethods]
1880impl TasksFile {
1881    #[staticmethod]
1882    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1883        let tasks_file: TasksFile = deserialize_from_path(path)?;
1884        Ok(tasks_file)
1885    }
1886
1887    pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
1888        slf
1889    }
1890
1891    pub fn __next__<'py>(
1892        mut slf: PyRefMut<'py, Self>,
1893    ) -> Result<Option<Bound<'py, PyAny>>, TypeError> {
1894        let py = slf.py();
1895        if slf.index < slf.tasks.len() {
1896            let task = slf.tasks[slf.index].clone().into_bound_py_any(py)?;
1897            slf.index += 1;
1898            Ok(Some(task))
1899        } else {
1900            Ok(None)
1901        }
1902    }
1903
1904    fn __getitem__<'py>(
1905        &self,
1906        py: Python<'py>,
1907        index: &Bound<'py, PyAny>,
1908    ) -> Result<Bound<'py, PyAny>, TypeError> {
1909        if let Ok(i) = index.extract::<isize>() {
1910            let len = self.tasks.len() as isize;
1911            let actual_index = if i < 0 { len + i } else { i };
1912
1913            if actual_index < 0 || actual_index >= len {
1914                return Err(TypeError::IndexOutOfBounds {
1915                    index: i,
1916                    length: self.tasks.len(),
1917                });
1918            }
1919
1920            Ok(self.tasks[actual_index as usize]
1921                .clone()
1922                .into_bound_py_any(py)?)
1923        } else if let Ok(slice) = index.cast::<PySlice>() {
1924            let indices = slice.indices(self.tasks.len() as isize)?;
1925            let result = PyList::empty(py);
1926
1927            let mut i = indices.start;
1928            while (indices.step > 0 && i < indices.stop) || (indices.step < 0 && i > indices.stop) {
1929                result.append(self.tasks[i as usize].clone().into_bound_py_any(py)?)?;
1930                i += indices.step;
1931            }
1932
1933            Ok(result.into_bound_py_any(py)?)
1934        } else {
1935            Err(TypeError::IndexOrSliceExpected)
1936        }
1937    }
1938
1939    fn __len__(&self) -> usize {
1940        self.tasks.len()
1941    }
1942
1943    fn __str__(&self) -> String {
1944        PyHelperFuncs::__str__(self)
1945    }
1946}
1947
/// One task entry held by `TasksFile`, tagged by task kind.
#[derive(Debug, Serialize, Clone)]
pub enum TaskConfig {
    /// An `AssertionTask`.
    Assertion(AssertionTask),
    /// An `LLMJudgeTask`, held boxed.
    // NOTE(review): the rename equals the variant name and no `rename_all` is set on this
    // enum, so it appears to be a no-op — confirm before removing.
    #[serde(rename = "LLMJudge")]
    LLMJudge(Box<LLMJudgeTask>),
    /// A `TraceAssertionTask`.
    TraceAssertion(TraceAssertionTask),
    /// An `AgentAssertionTask`.
    AgentAssertion(AgentAssertionTask),
}
1956
1957impl TaskConfig {
1958    fn into_bound_py_any<'py>(self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1959        match self {
1960            TaskConfig::Assertion(task) => Ok(task.into_bound_py_any(py)?),
1961            TaskConfig::LLMJudge(task) => Ok(task.into_bound_py_any(py)?),
1962            TaskConfig::TraceAssertion(task) => Ok(task.into_bound_py_any(py)?),
1963            TaskConfig::AgentAssertion(task) => Ok(task.into_bound_py_any(py)?),
1964        }
1965    }
1966}
1967
1968impl<'de> Deserialize<'de> for TasksFile {
1969    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1970    where
1971        D: serde::Deserializer<'de>,
1972    {
1973        #[derive(Deserialize)]
1974        #[serde(untagged)]
1975        enum TasksFileRaw {
1976            Direct(Vec<TaskConfigRaw>),
1977            Wrapped { tasks: Vec<TaskConfigRaw> },
1978        }
1979
1980        #[derive(Deserialize)]
1981        struct TaskConfigRaw {
1982            task_type: EvaluationTaskType,
1983            #[serde(flatten)]
1984            data: Value,
1985        }
1986
1987        let raw = TasksFileRaw::deserialize(deserializer)?;
1988        let raw_tasks = match raw {
1989            TasksFileRaw::Direct(tasks) => tasks,
1990            TasksFileRaw::Wrapped { tasks } => tasks,
1991        };
1992
1993        let mut tasks = Vec::new();
1994
1995        for task_raw in raw_tasks {
1996            let task_config = match task_raw.task_type {
1997                EvaluationTaskType::Assertion => {
1998                    let mut task: AssertionTask =
1999                        serde_json::from_value(task_raw.data).map_err(|e| {
2000                            error!("Failed to deserialize AssertionTask: {}", e);
2001                            serde::de::Error::custom(e.to_string())
2002                        })?;
2003                    task.task_type = EvaluationTaskType::Assertion;
2004                    TaskConfig::Assertion(task)
2005                }
2006                EvaluationTaskType::LLMJudge => {
2007                    let mut task: LLMJudgeTask =
2008                        serde_json::from_value(task_raw.data).map_err(|e| {
2009                            error!("Failed to deserialize LLMJudgeTask: {}", e);
2010                            serde::de::Error::custom(e.to_string())
2011                        })?;
2012                    task.task_type = EvaluationTaskType::LLMJudge;
2013                    TaskConfig::LLMJudge(Box::new(task))
2014                }
2015                EvaluationTaskType::TraceAssertion => {
2016                    let mut task: TraceAssertionTask = serde_json::from_value(task_raw.data)
2017                        .map_err(|e| {
2018                            error!("Failed to deserialize TraceAssertionTask: {}", e);
2019                            serde::de::Error::custom(e.to_string())
2020                        })?;
2021                    task.task_type = EvaluationTaskType::TraceAssertion;
2022                    TaskConfig::TraceAssertion(task)
2023                }
2024                EvaluationTaskType::AgentAssertion => {
2025                    let mut task: AgentAssertionTask = serde_json::from_value(task_raw.data)
2026                        .map_err(|e| {
2027                            error!("Failed to deserialize AgentAssertionTask: {}", e);
2028                            serde::de::Error::custom(e.to_string())
2029                        })?;
2030                    task.task_type = EvaluationTaskType::AgentAssertion;
2031                    TaskConfig::AgentAssertion(task)
2032                }
2033                _ => {
2034                    return Err(serde::de::Error::custom(format!(
2035                        "Unknown task_type: {}",
2036                        task_raw.task_type
2037                    )))
2038                }
2039            };
2040            tasks.push(task_config);
2041        }
2042
2043        Ok(TasksFile { tasks, index: 0 })
2044    }
2045}