Skip to main content

scouter_types/genai/
eval.rs

1use crate::error::TypeError;
2use crate::genai::traits::TaskAccessor;
3use crate::PyHelperFuncs;
4use core::fmt::Debug;
5use potato_head::prompt_types::Prompt;
6use pyo3::prelude::*;
7use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PySlice, PyString};
8use pyo3::IntoPyObjectExt;
9use pythonize::{depythonize, pythonize};
10use serde::de::DeserializeOwned;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::fmt::Display;
15use std::path::PathBuf;
16use std::str::FromStr;
17use tracing::error;
18
19pub fn deserialize_from_path<T: DeserializeOwned>(path: PathBuf) -> Result<T, TypeError> {
20    let content = std::fs::read_to_string(&path)?;
21
22    let extension = path
23        .extension()
24        .and_then(|ext| ext.to_str())
25        .ok_or_else(|| TypeError::Error(format!("Invalid file path: {:?}", path)))?;
26
27    let item = match extension.to_lowercase().as_str() {
28        "json" => serde_json::from_str(&content)?,
29        "yaml" | "yml" => serde_yaml::from_str(&content)?,
30        _ => {
31            return Err(TypeError::Error(format!(
32                "Unsupported file extension '{}'. Expected .json, .yaml, or .yml",
33                extension
34            )))
35        }
36    };
37
38    Ok(item)
39}
40
41// Default functions for task types during deserialization
42fn default_assertion_task_type() -> EvaluationTaskType {
43    EvaluationTaskType::Assertion
44}
45
46fn default_trace_assertion_task_type() -> EvaluationTaskType {
47    EvaluationTaskType::TraceAssertion
48}
49
50#[pyclass]
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52pub struct AssertionResult {
53    #[pyo3(get)]
54    pub passed: bool,
55    pub actual: Value,
56
57    #[pyo3(get)]
58    pub message: String,
59
60    pub expected: Value,
61}
62
63impl AssertionResult {
64    pub fn new(passed: bool, actual: Value, message: String, expected: Value) -> Self {
65        Self {
66            passed,
67            actual,
68            message,
69            expected,
70        }
71    }
72    pub fn to_metric_value(&self) -> f64 {
73        if self.passed {
74            1.0
75        } else {
76            0.0
77        }
78    }
79}
80
81#[pymethods]
82impl AssertionResult {
83    pub fn __str__(&self) -> String {
84        PyHelperFuncs::__str__(self)
85    }
86
87    #[getter]
88    pub fn get_actual<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
89        let py_value = pythonize(py, &self.actual)?;
90        Ok(py_value)
91    }
92
93    #[getter]
94    pub fn get_expected<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
95        let py_value = pythonize(py, &self.expected)?;
96        Ok(py_value)
97    }
98}
99
100#[pyclass]
101#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
102pub struct AssertionResults {
103    #[pyo3(get)]
104    pub results: HashMap<String, AssertionResult>,
105}
106
107#[pymethods]
108impl AssertionResults {
109    pub fn __str__(&self) -> String {
110        PyHelperFuncs::__str__(self)
111    }
112
113    pub fn __getitem__(&self, key: &str) -> Result<AssertionResult, TypeError> {
114        if let Some(result) = self.results.get(key) {
115            Ok(result.clone())
116        } else {
117            Err(TypeError::KeyNotFound {
118                key: key.to_string(),
119            })
120        }
121    }
122}
123
124#[pyclass]
125#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
126pub struct AssertionTask {
127    #[pyo3(get, set)]
128    pub id: String,
129
130    #[pyo3(get, set)]
131    #[serde(default)]
132    pub context_path: Option<String>,
133
134    #[pyo3(get, set)]
135    #[serde(default)]
136    pub item_context_path: Option<String>,
137
138    #[pyo3(get, set)]
139    pub operator: ComparisonOperator,
140
141    pub expected_value: Value,
142
143    #[pyo3(get, set)]
144    #[serde(default)]
145    pub description: Option<String>,
146
147    #[pyo3(get, set)]
148    #[serde(default)]
149    pub depends_on: Vec<String>,
150
151    #[serde(default = "default_assertion_task_type")]
152    #[pyo3(get)]
153    pub task_type: EvaluationTaskType,
154
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub result: Option<AssertionResult>,
157
158    #[serde(default)]
159    pub condition: bool,
160}
161
162#[pymethods]
163impl AssertionTask {
164    #[new]
165    /// Creates a new AssertionTask
166    /// # Examples
167    /// ```python
168    ///
169    /// # assumed passed context at runtime
170    /// # context = {
171    /// #     "response": {
172    /// #         "user": {
173    /// #             "age": 25
174    /// #         }
175    /// #     }
176    /// # }
177    ///
178    /// task = AssertionTask(
179    ///     id="Check User Age",
180    ///     context_path="response.user.age",
181    ///     operator=ComparisonOperator.GREATER_THAN,
182    ///     expected_value=18,
183    ///     description="Check if user is an adult"
184    /// )
185    ///
186    /// # assumed passed context at runtime
187    /// # context = {
188    /// #     "user": {
189    /// #         "age": 25
190    /// #     }
191    /// # }
192    ///
193    /// task = AssertionTask(
194    ///     id="Check User Age",
195    ///     context_path="user.age",
196    ///     operator=ComparisonOperator.GREATER_THAN,
197    ///     expected_value=18,
198    ///     description="Check if user is an adult"
199    /// )
200    ///
201    ///  /// # assume non-map context at runtime
202    /// # context = 25
203    ///
204    /// task = AssertionTask(
205    ///     id="Check User Age",
206    ///     operator=ComparisonOperator.GREATER_THAN,
207    ///     expected_value=18,
208    ///     description="Check if user is an adult"
209    /// )
210    /// ```
211    /// # Arguments
212    /// * `context_path`: The path to the field to be asserted
213    /// * `operator`: The comparison operator to use
214    /// * `expected_value`: The expected value for the assertion
215    /// * `description`: Optional description for the assertion
216    /// # Returns
217    /// A new AssertionTask object
218    #[allow(clippy::too_many_arguments)]
219    #[pyo3(signature = (id, context_path, expected_value, operator, item_context_path=None, description=None, depends_on=None, condition=None))]
220    pub fn new(
221        id: String,
222        context_path: Option<String>,
223        expected_value: &Bound<'_, PyAny>,
224        operator: ComparisonOperator,
225        item_context_path: Option<String>,
226        description: Option<String>,
227        depends_on: Option<Vec<String>>,
228        condition: Option<bool>,
229    ) -> Result<Self, TypeError> {
230        let expected_value = depythonize(expected_value)?;
231        let condition = condition.unwrap_or(false);
232
233        Ok(Self {
234            id: id.to_lowercase(),
235            context_path,
236            item_context_path,
237            operator,
238            expected_value,
239            description,
240            task_type: EvaluationTaskType::Assertion,
241            depends_on: depends_on.unwrap_or_default(),
242            result: None,
243            condition,
244        })
245    }
246
247    #[getter]
248    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
249        let py_value = pythonize(py, &self.expected_value)?;
250        Ok(py_value)
251    }
252
253    #[staticmethod]
254    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
255        deserialize_from_path(path)
256    }
257}
258
259impl AssertionTask {}
260
261impl TaskAccessor for AssertionTask {
262    fn context_path(&self) -> Option<&str> {
263        self.context_path.as_deref()
264    }
265
266    fn item_context_path(&self) -> Option<&str> {
267        self.item_context_path.as_deref()
268    }
269
270    fn id(&self) -> &str {
271        &self.id
272    }
273
274    fn operator(&self) -> &ComparisonOperator {
275        &self.operator
276    }
277
278    fn task_type(&self) -> &EvaluationTaskType {
279        &self.task_type
280    }
281
282    fn expected_value(&self) -> &Value {
283        &self.expected_value
284    }
285
286    fn depends_on(&self) -> &[String] {
287        &self.depends_on
288    }
289
290    fn add_result(&mut self, result: AssertionResult) {
291        self.result = Some(result);
292    }
293}
294
/// Comparison helpers layered on top of a JSON-like value.
pub trait ValueExt {
    /// Length of the value for HasLength-style comparisons, or `None` when the
    /// value has no notion of length.
    fn to_length(&self) -> Option<i64>;

    /// The value as an `f64`, or `None` when it is not numeric.
    fn as_numeric(&self) -> Option<f64>;

    /// Whether the value counts as truthy.
    fn is_truthy(&self) -> bool;
}
305
306impl ValueExt for Value {
307    fn to_length(&self) -> Option<i64> {
308        match self {
309            Value::Array(arr) => Some(arr.len() as i64),
310            Value::String(s) => Some(s.chars().count() as i64),
311            Value::Object(obj) => Some(obj.len() as i64),
312            _ => None,
313        }
314    }
315
316    fn as_numeric(&self) -> Option<f64> {
317        match self {
318            Value::Number(n) => n.as_f64(),
319            _ => None,
320        }
321    }
322
323    fn is_truthy(&self) -> bool {
324        match self {
325            Value::Null => false,
326            Value::Bool(b) => *b,
327            Value::Number(n) => n.as_f64() != Some(0.0),
328            Value::String(s) => !s.is_empty(),
329            Value::Array(arr) => !arr.is_empty(),
330            Value::Object(obj) => !obj.is_empty(),
331        }
332    }
333}
334
335/// Primary class for defining an LLM as a Judge in evaluation workflows
336#[pyclass]
337#[derive(Debug, Serialize, Clone, PartialEq)]
338pub struct LLMJudgeTask {
339    #[pyo3(get, set)]
340    pub id: String,
341
342    #[pyo3(get)]
343    pub prompt: Prompt,
344
345    #[pyo3(get)]
346    #[serde(default)]
347    pub context_path: Option<String>,
348
349    pub expected_value: Value,
350
351    #[pyo3(get)]
352    pub operator: ComparisonOperator,
353
354    #[pyo3(get)]
355    pub task_type: EvaluationTaskType,
356
357    #[pyo3(get, set)]
358    #[serde(default)]
359    pub depends_on: Vec<String>,
360
361    #[pyo3(get, set)]
362    #[serde(default)]
363    pub max_retries: Option<u32>,
364
365    #[serde(skip_serializing_if = "Option::is_none")]
366    pub result: Option<AssertionResult>,
367
368    #[serde(default)]
369    pub description: Option<String>,
370
371    #[pyo3(get, set)]
372    #[serde(default)]
373    pub condition: bool,
374}
375
376#[derive(Debug, Deserialize)]
377#[serde(untagged)]
378pub enum PromptConfig {
379    Path { path: String },
380    Inline(Box<Prompt>),
381}
382
383#[derive(Debug, Deserialize)]
384struct LLMJudgeTaskConfig {
385    pub id: String,
386    pub prompt: PromptConfig,
387    pub expected_value: Value,
388    pub operator: ComparisonOperator,
389    pub context_path: Option<String>,
390    pub description: Option<String>,
391    pub depends_on: Vec<String>,
392    pub max_retries: Option<u32>,
393    pub condition: bool,
394}
395
396impl LLMJudgeTaskConfig {
397    pub fn into_task(self) -> Result<LLMJudgeTask, TypeError> {
398        let prompt = match self.prompt {
399            PromptConfig::Path { path } => {
400                Prompt::from_path(PathBuf::from(path)).inspect_err(|e| {
401                    error!("Failed to deserialize Prompt from path: {}", e);
402                })?
403            }
404            PromptConfig::Inline(prompt) => *prompt,
405        };
406
407        Ok(LLMJudgeTask {
408            id: self.id.to_lowercase(),
409            prompt,
410            expected_value: self.expected_value,
411            operator: self.operator,
412            context_path: self.context_path,
413            description: self.description,
414            depends_on: self.depends_on,
415            max_retries: self.max_retries.or(Some(3)),
416            task_type: EvaluationTaskType::LLMJudge,
417            result: None,
418            condition: self.condition,
419        })
420    }
421}
422
423#[derive(Debug, Deserialize)]
424struct LLMJudgeTaskInternal {
425    pub id: String,
426    pub prompt: Prompt,
427    pub context_path: Option<String>,
428    pub expected_value: Value,
429    pub operator: ComparisonOperator,
430    pub task_type: EvaluationTaskType,
431    pub depends_on: Vec<String>,
432    pub max_retries: Option<u32>,
433    pub result: Option<AssertionResult>,
434    pub description: Option<String>,
435    pub condition: bool,
436}
437impl LLMJudgeTaskInternal {
438    pub fn into_task(self) -> LLMJudgeTask {
439        LLMJudgeTask {
440            id: self.id.to_lowercase(),
441            prompt: self.prompt,
442            context_path: self.context_path,
443            expected_value: self.expected_value,
444            operator: self.operator,
445            task_type: self.task_type,
446            depends_on: self.depends_on,
447            max_retries: self.max_retries.or(Some(3)),
448            result: self.result,
449            description: self.description,
450            condition: self.condition,
451        }
452    }
453}
454
455#[derive(Debug, Deserialize)]
456#[serde(untagged)]
457enum LLMJudgeFormat {
458    Full(Box<LLMJudgeTaskInternal>),
459    Generic(LLMJudgeTaskConfig),
460}
461
462impl<'de> Deserialize<'de> for LLMJudgeTask {
463    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
464    where
465        D: serde::Deserializer<'de>,
466    {
467        let format = LLMJudgeFormat::deserialize(deserializer)?;
468
469        match format {
470            LLMJudgeFormat::Generic(config) => config.into_task().map_err(serde::de::Error::custom),
471            LLMJudgeFormat::Full(internal) => Ok(internal.into_task()),
472        }
473    }
474}
475
476#[pymethods]
477impl LLMJudgeTask {
478    /// Creates a new LLMJudgeTask
479    /// # Examples
480    /// ```python
481    /// task = LLMJudgeTask(
482    ///     id="Sentiment Analysis Judge",
483    ///     prompt=prompt_object,
484    ///     expected_value="Positive",
485    ///     operator=ComparisonOperator.EQUALS
486    /// )
487    /// # Arguments
488    /// * `id: The id of the judge task
489    /// * `prompt`: The prompt object to be used for evaluation
490    /// * `expected_value`: The expected value for the judgement
491    /// * `context_path`: Optional context path to extract from the context for evaluation
492    /// * `operator`: The comparison operator to use
493    /// * `depends_on`: Optional list of task IDs this task depends on
494    /// * `max_retries`: Optional maximum number of retries for this task (defaults to 3 if not provided)
495    /// # Returns
496    /// A new LLMJudgeTask object
497    #[new]
498    #[pyo3(signature = (id, prompt, expected_value,  context_path,operator, description=None, depends_on=None, max_retries=None, condition=None))]
499    #[allow(clippy::too_many_arguments)]
500    pub fn new(
501        id: &str,
502        prompt: Prompt,
503        expected_value: &Bound<'_, PyAny>,
504        context_path: Option<String>,
505        operator: ComparisonOperator,
506        description: Option<String>,
507        depends_on: Option<Vec<String>>,
508        max_retries: Option<u32>,
509        condition: Option<bool>,
510    ) -> Result<Self, TypeError> {
511        let expected_value = depythonize(expected_value)?;
512
513        Ok(Self {
514            id: id.to_lowercase(),
515            prompt,
516            expected_value,
517            operator,
518            task_type: EvaluationTaskType::LLMJudge,
519            depends_on: depends_on.unwrap_or_default(),
520            max_retries: max_retries.or(Some(3)),
521            context_path,
522            result: None,
523            description,
524            condition: condition.unwrap_or(false),
525        })
526    }
527
528    pub fn __str__(&self) -> String {
529        // serialize the struct to a string
530        PyHelperFuncs::__str__(self)
531    }
532
533    #[getter]
534    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
535        let py_value = pythonize(py, &self.expected_value)?;
536        Ok(py_value)
537    }
538
539    #[staticmethod]
540    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
541        deserialize_from_path(path)
542    }
543}
544
545impl LLMJudgeTask {
546    /// Creates a new LLMJudgeTask with Rust types
547    /// # Arguments
548    /// * `id: The id of the judge task
549    /// * `prompt`: The prompt object to be used for evaluation
550    /// * `expected_value`: The expected value for the judgement
551    /// * `context_path`: Optional context path to extract from the context for evaluation
552    /// * `operator`: The comparison operator to use
553    /// * `depends_on`: Optional list of task IDs this task depends on
554    /// * `max_retries`: Optional maximum number of retries for this task (defaults to 3 if not provided)
555    /// # Returns
556    /// A new LLMJudgeTask object
557    #[allow(clippy::too_many_arguments)]
558    pub fn new_rs(
559        id: &str,
560        prompt: Prompt,
561        expected_value: Value,
562        context_path: Option<String>,
563        operator: ComparisonOperator,
564        depends_on: Option<Vec<String>>,
565        max_retries: Option<u32>,
566        description: Option<String>,
567        condition: Option<bool>,
568    ) -> Self {
569        Self {
570            id: id.to_lowercase(),
571            prompt,
572            expected_value,
573            operator,
574            task_type: EvaluationTaskType::LLMJudge,
575            depends_on: depends_on.unwrap_or_default(),
576            max_retries: max_retries.or(Some(3)),
577            context_path,
578            result: None,
579            description,
580            condition: condition.unwrap_or(false),
581        }
582    }
583}
584
585impl TaskAccessor for LLMJudgeTask {
586    fn context_path(&self) -> Option<&str> {
587        self.context_path.as_deref()
588    }
589
590    fn item_context_path(&self) -> Option<&str> {
591        None
592    }
593
594    fn id(&self) -> &str {
595        &self.id
596    }
597    fn task_type(&self) -> &EvaluationTaskType {
598        &self.task_type
599    }
600
601    fn operator(&self) -> &ComparisonOperator {
602        &self.operator
603    }
604
605    fn expected_value(&self) -> &Value {
606        &self.expected_value
607    }
608    fn depends_on(&self) -> &[String] {
609        &self.depends_on
610    }
611    fn add_result(&mut self, result: AssertionResult) {
612        self.result = Some(result);
613    }
614}
615
616#[pyclass(eq, eq_int)]
617#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
618pub enum SpanStatus {
619    Ok,
620    Error,
621    Unset,
622}
623
624#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
625pub struct PyValueWrapper(pub Value);
626
627impl<'py> IntoPyObject<'py> for PyValueWrapper {
628    type Target = PyAny; // the Python type
629    type Output = Bound<'py, Self::Target>;
630    type Error = std::convert::Infallible;
631
632    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
633        Ok(pythonize(py, &self.0).unwrap())
634    }
635}
636
637impl<'a, 'py> FromPyObject<'a, 'py> for PyValueWrapper {
638    type Error = TypeError;
639
640    fn extract(ob: pyo3::Borrowed<'a, 'py, pyo3::PyAny>) -> Result<Self, Self::Error> {
641        let value: Value = depythonize(&ob)?;
642        Ok(PyValueWrapper(value))
643    }
644}
645
646/// Filter configuration for selecting spans to assert on
647#[pyclass(eq)]
648#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
649pub enum SpanFilter {
650    /// Match spans by exact name
651    ByName { name: String },
652
653    /// Match spans by name pattern (regex)
654    ByNamePattern { pattern: String },
655
656    /// Match spans with specific attribute key
657    WithAttribute { key: String },
658
659    /// Match spans with specific attribute key-value pair
660    WithAttributeValue { key: String, value: PyValueWrapper },
661
662    /// Match spans by status code
663    WithStatus { status: SpanStatus },
664
665    /// Match spans with duration constraints
666    WithDuration {
667        min_ms: Option<f64>,
668        max_ms: Option<f64>,
669    },
670
671    /// Match a sequence of span names in order
672    Sequence { names: Vec<String> },
673
674    /// Combine multiple filters with AND logic
675    And { filters: Vec<SpanFilter> },
676
677    /// Combine multiple filters with OR logic
678    Or { filters: Vec<SpanFilter> },
679}
680
681#[pymethods]
682impl SpanFilter {
683    #[staticmethod]
684    pub fn by_name(name: String) -> Self {
685        SpanFilter::ByName { name }
686    }
687
688    #[staticmethod]
689    pub fn by_name_pattern(pattern: String) -> Self {
690        SpanFilter::ByNamePattern { pattern }
691    }
692
693    #[staticmethod]
694    pub fn with_attribute(key: String) -> Self {
695        SpanFilter::WithAttribute { key }
696    }
697
698    #[staticmethod]
699    pub fn with_attribute_value(key: String, value: &Bound<'_, PyAny>) -> Result<Self, TypeError> {
700        let value = PyValueWrapper(depythonize(value).unwrap());
701        Ok(SpanFilter::WithAttributeValue { key, value })
702    }
703
704    #[staticmethod]
705    pub fn with_status(status: SpanStatus) -> Self {
706        SpanFilter::WithStatus { status }
707    }
708
709    #[staticmethod]
710    #[pyo3(signature = (min_ms=None, max_ms=None))]
711    pub fn with_duration(min_ms: Option<f64>, max_ms: Option<f64>) -> Self {
712        SpanFilter::WithDuration { min_ms, max_ms }
713    }
714
715    #[staticmethod]
716    pub fn sequence(names: Vec<String>) -> Self {
717        SpanFilter::Sequence { names }
718    }
719
720    pub fn and_(&self, other: SpanFilter) -> Self {
721        match self {
722            SpanFilter::And { filters } => {
723                let mut new_filters = filters.clone();
724                new_filters.push(other);
725                SpanFilter::And {
726                    filters: new_filters,
727                }
728            }
729            _ => SpanFilter::And {
730                filters: vec![self.clone(), other],
731            },
732        }
733    }
734
735    pub fn or_(&self, other: SpanFilter) -> Self {
736        match self {
737            SpanFilter::Or { filters } => {
738                let mut new_filters = filters.clone();
739                new_filters.push(other);
740                SpanFilter::Or {
741                    filters: new_filters,
742                }
743            }
744            _ => SpanFilter::Or {
745                filters: vec![self.clone(), other],
746            },
747        }
748    }
749}
750
751#[pyclass(eq, eq_int)]
752#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
753pub enum AggregationType {
754    Count,
755    Sum,
756    Average,
757    Min,
758    Max,
759    First,
760    Last,
761}
762
763/// Unified assertion target that can operate on traces or filtered spans
764#[pyclass(eq)]
765#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
766pub enum TraceAssertion {
767    /// Check if spans exist in a specific order
768    SpanSequence { span_names: Vec<String> },
769
770    /// Check if all specified span names exist (order doesn't matter)
771    SpanSet { span_names: Vec<String> },
772
773    /// Count spans matching a filter
774    SpanCount { filter: SpanFilter },
775
776    /// Check if any span matching filter exists
777    SpanExists { filter: SpanFilter },
778
779    /// Get attribute value from span(s) matching filter
780    SpanAttribute {
781        filter: SpanFilter,
782        attribute_key: String,
783    },
784
785    /// Get duration of span(s) matching filter
786    SpanDuration { filter: SpanFilter },
787
788    /// Aggregate a numeric attribute across filtered spans
789    SpanAggregation {
790        filter: SpanFilter,
791        attribute_key: String,
792        aggregation: AggregationType,
793    },
794
795    /// Check total duration of entire trace
796    TraceDuration {},
797
798    /// Count total spans in trace
799    TraceSpanCount {},
800
801    /// Count spans with errors in trace
802    TraceErrorCount {},
803
804    /// Count unique services in trace
805    TraceServiceCount {},
806
807    /// Get maximum depth of span tree
808    TraceMaxDepth {},
809
810    /// Get trace-level attribute
811    TraceAttribute { attribute_key: String },
812}
813
814// implement to_string for TraceAssertion
815impl Display for TraceAssertion {
816    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
817        // serde serialize to string
818        let s = serde_json::to_string(self).unwrap_or_default();
819        write!(f, "{}", s)
820    }
821}
822
823#[pymethods]
824impl TraceAssertion {
825    #[staticmethod]
826    pub fn span_sequence(span_names: Vec<String>) -> Self {
827        TraceAssertion::SpanSequence { span_names }
828    }
829
830    #[staticmethod]
831    pub fn span_set(span_names: Vec<String>) -> Self {
832        TraceAssertion::SpanSet { span_names }
833    }
834
835    #[staticmethod]
836    pub fn span_count(filter: SpanFilter) -> Self {
837        TraceAssertion::SpanCount { filter }
838    }
839
840    #[staticmethod]
841    pub fn span_exists(filter: SpanFilter) -> Self {
842        TraceAssertion::SpanExists { filter }
843    }
844
845    #[staticmethod]
846    pub fn span_attribute(filter: SpanFilter, attribute_key: String) -> Self {
847        TraceAssertion::SpanAttribute {
848            filter,
849            attribute_key,
850        }
851    }
852
853    #[staticmethod]
854    pub fn span_duration(filter: SpanFilter) -> Self {
855        TraceAssertion::SpanDuration { filter }
856    }
857
858    #[staticmethod]
859    pub fn span_aggregation(
860        filter: SpanFilter,
861        attribute_key: String,
862        aggregation: AggregationType,
863    ) -> Self {
864        TraceAssertion::SpanAggregation {
865            filter,
866            attribute_key,
867            aggregation,
868        }
869    }
870
871    #[staticmethod]
872    pub fn trace_duration() -> Self {
873        TraceAssertion::TraceDuration {}
874    }
875
876    #[staticmethod]
877    pub fn trace_span_count() -> Self {
878        TraceAssertion::TraceSpanCount {}
879    }
880
881    #[staticmethod]
882    pub fn trace_error_count() -> Self {
883        TraceAssertion::TraceErrorCount {}
884    }
885
886    #[staticmethod]
887    pub fn trace_service_count() -> Self {
888        TraceAssertion::TraceServiceCount {}
889    }
890
891    #[staticmethod]
892    pub fn trace_max_depth() -> Self {
893        TraceAssertion::TraceMaxDepth {}
894    }
895
896    #[staticmethod]
897    pub fn trace_attribute(attribute_key: String) -> Self {
898        TraceAssertion::TraceAttribute { attribute_key }
899    }
900
901    pub fn model_dump_json(&self) -> String {
902        serde_json::to_string(self).unwrap_or_default()
903    }
904}
905
906#[pyclass]
907#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
908pub struct TraceAssertionTask {
909    #[pyo3(get, set)]
910    pub id: String,
911
912    #[pyo3(get, set)]
913    pub assertion: TraceAssertion,
914
915    #[pyo3(get, set)]
916    pub operator: ComparisonOperator,
917
918    pub expected_value: Value,
919
920    #[pyo3(get, set)]
921    #[serde(default)]
922    pub description: Option<String>,
923
924    #[pyo3(get, set)]
925    #[serde(default)]
926    pub depends_on: Vec<String>,
927
928    #[serde(default = "default_trace_assertion_task_type")]
929    #[pyo3(get)]
930    pub task_type: EvaluationTaskType,
931
932    #[serde(skip_serializing_if = "Option::is_none")]
933    pub result: Option<AssertionResult>,
934
935    #[pyo3(get, set)]
936    #[serde(default)]
937    pub condition: bool,
938}
939
940#[pymethods]
941impl TraceAssertionTask {
942    /// Creates a new TraceAssertionTask
943    ///
944    /// # Examples
945    /// ```python
946    /// # Check execution order of spans
947    /// task = TraceAssertionTask(
948    ///     id="verify_agent_workflow",
949    ///     assertion=TraceAssertion.span_sequence(["call_tool", "run_agent", "double_check"]),
950    ///     operator=ComparisonOperator.SequenceMatches,
951    ///     expected_value=True
952    /// )
953    ///
954    /// # Check all required spans exist
955    /// task = TraceAssertionTask(
956    ///     id="verify_required_steps",
957    ///     assertion=TraceAssertion.span_set(["call_tool", "run_agent", "double_check"]),
958    ///     operator=ComparisonOperator.ContainsAll,
959    ///     expected_value=True
960    /// )
961    ///
962    /// # Check total trace duration
963    /// task = TraceAssertionTask(
964    ///     id="verify_performance",
965    ///     assertion=TraceAssertion.trace_duration(),
966    ///     operator=ComparisonOperator.LessThan,
967    ///     expected_value=5000.0
968    /// )
969    ///
970    /// # Check count of specific spans
971    /// task = TraceAssertionTask(
972    ///     id="verify_retry_count",
973    ///     assertion=TraceAssertion.span_count(
974    ///         SpanFilter.by_name("retry_operation")
975    ///     ),
976    ///     operator=ComparisonOperator.LessThanOrEqual,
977    ///     expected_value=3
978    /// )
979    ///
980    /// # Check span attribute
981    /// task = TraceAssertionTask(
982    ///     id="verify_model_used",
983    ///     assertion=TraceAssertion.span_attribute(
984    ///         SpanFilter.by_name("llm.generate"),
985    ///         "model"
986    ///     ),
987    ///     operator=ComparisonOperator.Equals,
988    ///     expected_value="gpt-4"
989    /// )
990    /// ```
991    #[new]
992    /// Creates a new TraceAssertionTask
993    #[pyo3(signature = (id, assertion, expected_value, operator, description=None, depends_on=None, condition=None))]
994    pub fn new(
995        id: String,
996        assertion: TraceAssertion,
997        expected_value: &Bound<'_, PyAny>,
998        operator: ComparisonOperator,
999        description: Option<String>,
1000        depends_on: Option<Vec<String>>,
1001        condition: Option<bool>,
1002    ) -> Result<Self, TypeError> {
1003        let expected_value = depythonize(expected_value)?;
1004
1005        Ok(Self {
1006            id: id.to_lowercase(),
1007            assertion,
1008            operator,
1009            expected_value,
1010            description,
1011            task_type: EvaluationTaskType::TraceAssertion,
1012            depends_on: depends_on.unwrap_or_default(),
1013            result: None,
1014            condition: condition.unwrap_or(false),
1015        })
1016    }
1017
1018    pub fn __str__(&self) -> String {
1019        // serialize the struct to a string
1020        PyHelperFuncs::__str__(self)
1021    }
1022
1023    #[getter]
1024    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1025        let py_value = pythonize(py, &self.expected_value)?;
1026        Ok(py_value)
1027    }
1028
1029    #[staticmethod]
1030    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1031        deserialize_from_path(path)
1032    }
1033}
1034
1035impl TaskAccessor for TraceAssertionTask {
1036    fn context_path(&self) -> Option<&str> {
1037        None
1038    }
1039
1040    fn item_context_path(&self) -> Option<&str> {
1041        None
1042    }
1043
1044    fn id(&self) -> &str {
1045        &self.id
1046    }
1047
1048    fn operator(&self) -> &ComparisonOperator {
1049        &self.operator
1050    }
1051
1052    fn task_type(&self) -> &EvaluationTaskType {
1053        &self.task_type
1054    }
1055
1056    fn expected_value(&self) -> &Value {
1057        &self.expected_value
1058    }
1059
1060    fn depends_on(&self) -> &[String] {
1061        &self.depends_on
1062    }
1063
1064    fn add_result(&mut self, result: AssertionResult) {
1065        self.result = Some(result);
1066    }
1067}
1068
/// A single evaluation task of any supported kind.
///
/// Each variant holds its concrete task behind a `Box`; `TaskAccessor` is implemented
/// for this enum by forwarding to the wrapped task.
#[derive(Debug, Clone)]
pub enum EvaluationTask {
    /// Wraps an [`AssertionTask`].
    Assertion(Box<AssertionTask>),
    /// Wraps an [`LLMJudgeTask`].
    LLMJudge(Box<LLMJudgeTask>),
    /// Wraps a [`TraceAssertionTask`].
    TraceAssertion(Box<TraceAssertionTask>),
}
1075
1076impl TaskAccessor for EvaluationTask {
1077    fn context_path(&self) -> Option<&str> {
1078        match self {
1079            EvaluationTask::Assertion(t) => t.context_path(),
1080            EvaluationTask::LLMJudge(t) => t.context_path(),
1081            EvaluationTask::TraceAssertion(t) => t.context_path(),
1082        }
1083    }
1084
1085    fn item_context_path(&self) -> Option<&str> {
1086        match self {
1087            EvaluationTask::Assertion(t) => t.item_context_path(),
1088            EvaluationTask::LLMJudge(t) => t.item_context_path(),
1089            EvaluationTask::TraceAssertion(t) => t.item_context_path(),
1090        }
1091    }
1092
1093    fn id(&self) -> &str {
1094        match self {
1095            EvaluationTask::Assertion(t) => t.id(),
1096            EvaluationTask::LLMJudge(t) => t.id(),
1097            EvaluationTask::TraceAssertion(t) => t.id(),
1098        }
1099    }
1100
1101    fn task_type(&self) -> &EvaluationTaskType {
1102        match self {
1103            EvaluationTask::Assertion(t) => t.task_type(),
1104            EvaluationTask::LLMJudge(t) => t.task_type(),
1105            EvaluationTask::TraceAssertion(t) => t.task_type(),
1106        }
1107    }
1108
1109    fn operator(&self) -> &ComparisonOperator {
1110        match self {
1111            EvaluationTask::Assertion(t) => t.operator(),
1112            EvaluationTask::LLMJudge(t) => t.operator(),
1113            EvaluationTask::TraceAssertion(t) => t.operator(),
1114        }
1115    }
1116
1117    fn expected_value(&self) -> &Value {
1118        match self {
1119            EvaluationTask::Assertion(t) => t.expected_value(),
1120            EvaluationTask::LLMJudge(t) => t.expected_value(),
1121            EvaluationTask::TraceAssertion(t) => t.expected_value(),
1122        }
1123    }
1124
1125    fn depends_on(&self) -> &[String] {
1126        match self {
1127            EvaluationTask::Assertion(t) => t.depends_on(),
1128            EvaluationTask::LLMJudge(t) => t.depends_on(),
1129            EvaluationTask::TraceAssertion(t) => t.depends_on(),
1130        }
1131    }
1132
1133    fn add_result(&mut self, result: AssertionResult) {
1134        match self {
1135            EvaluationTask::Assertion(t) => t.add_result(result),
1136            EvaluationTask::LLMJudge(t) => t.add_result(result),
1137            EvaluationTask::TraceAssertion(t) => t.add_result(result),
1138        }
1139    }
1140}
1141
/// Builder that accumulates [`EvaluationTask`] values in insertion order.
///
/// Tasks are added with `add_task` and the collected `Vec` is returned by `build`.
pub struct EvaluationTasks(Vec<EvaluationTask>);
1143
1144impl EvaluationTasks {
1145    /// Creates a new empty builder
1146    pub fn new() -> Self {
1147        Self(Vec::new())
1148    }
1149
1150    /// Generic method that accepts anything implementing Into<EvaluationTask>
1151    pub fn add_task(mut self, task: impl Into<EvaluationTask>) -> Self {
1152        self.0.push(task.into());
1153        self
1154    }
1155
1156    /// Builds and returns the Vec<EvaluationTask>
1157    pub fn build(self) -> Vec<EvaluationTask> {
1158        self.0
1159    }
1160}
1161
1162impl From<AssertionTask> for EvaluationTask {
1163    fn from(task: AssertionTask) -> Self {
1164        EvaluationTask::Assertion(Box::new(task))
1165    }
1166}
1167
1168impl From<LLMJudgeTask> for EvaluationTask {
1169    fn from(task: LLMJudgeTask) -> Self {
1170        EvaluationTask::LLMJudge(Box::new(task))
1171    }
1172}
1173
1174impl From<TraceAssertionTask> for EvaluationTask {
1175    fn from(task: TraceAssertionTask) -> Self {
1176        EvaluationTask::TraceAssertion(Box::new(task))
1177    }
1178}
1179
1180impl Default for EvaluationTasks {
1181    fn default() -> Self {
1182        Self::new()
1183    }
1184}
1185
/// Operator an evaluation task applies to its value.
///
/// Exposed to Python as a class with equality support. The textual name of each variant
/// (used by `as_str`, `Display` and `FromStr`) is exactly the variant identifier.
/// The evaluation semantics of each operator live outside this definition.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ComparisonOperator {
    // General comparison, string and length operators
    Equals,
    NotEqual,
    GreaterThan,
    GreaterThanOrEqual,
    LessThan,
    LessThanOrEqual,
    Contains,
    NotContains,
    StartsWith,
    EndsWith,
    Matches,
    // For the `HasLength*` operators, `AssertionValue::to_actual` replaces a list or
    // string with its length before comparison.
    HasLengthGreaterThan,
    HasLengthLessThan,
    HasLengthEqual,
    HasLengthGreaterThanOrEqual,
    HasLengthLessThanOrEqual,

    // Type Validation Operators
    IsNumeric,
    IsString,
    IsBoolean,
    IsNull,
    IsArray,
    IsObject,

    // Pattern & Format Validators
    IsEmail,
    IsUrl,
    IsUuid,
    IsIso8601,
    IsJson,
    MatchesRegex,

    // Numeric Range Operators
    InRange,
    NotInRange,
    IsPositive,
    IsNegative,
    IsZero,

    // Collection/Array Operators
    SequenceMatches,
    ContainsAll,
    ContainsAny,
    ContainsNone,
    IsEmpty,
    IsNotEmpty,
    HasUniqueItems,

    // String Operators
    IsAlphabetic,
    IsAlphanumeric,
    IsLowerCase,
    IsUpperCase,
    ContainsWord,

    // Comparison with Tolerance
    ApproximatelyEquals,
}
1249
1250impl Display for ComparisonOperator {
1251    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1252        write!(f, "{}", self.as_str())
1253    }
1254}
1255
1256impl FromStr for ComparisonOperator {
1257    type Err = TypeError;
1258
1259    fn from_str(s: &str) -> Result<Self, Self::Err> {
1260        match s {
1261            "Equals" => Ok(ComparisonOperator::Equals),
1262            "NotEqual" => Ok(ComparisonOperator::NotEqual),
1263            "GreaterThan" => Ok(ComparisonOperator::GreaterThan),
1264            "GreaterThanOrEqual" => Ok(ComparisonOperator::GreaterThanOrEqual),
1265            "LessThan" => Ok(ComparisonOperator::LessThan),
1266            "LessThanOrEqual" => Ok(ComparisonOperator::LessThanOrEqual),
1267            "Contains" => Ok(ComparisonOperator::Contains),
1268            "NotContains" => Ok(ComparisonOperator::NotContains),
1269            "StartsWith" => Ok(ComparisonOperator::StartsWith),
1270            "EndsWith" => Ok(ComparisonOperator::EndsWith),
1271            "Matches" => Ok(ComparisonOperator::Matches),
1272            "HasLengthEqual" => Ok(ComparisonOperator::HasLengthEqual),
1273            "HasLengthGreaterThan" => Ok(ComparisonOperator::HasLengthGreaterThan),
1274            "HasLengthLessThan" => Ok(ComparisonOperator::HasLengthLessThan),
1275            "HasLengthGreaterThanOrEqual" => Ok(ComparisonOperator::HasLengthGreaterThanOrEqual),
1276            "HasLengthLessThanOrEqual" => Ok(ComparisonOperator::HasLengthLessThanOrEqual),
1277
1278            // Type Validation
1279            "IsNumeric" => Ok(ComparisonOperator::IsNumeric),
1280            "IsString" => Ok(ComparisonOperator::IsString),
1281            "IsBoolean" => Ok(ComparisonOperator::IsBoolean),
1282            "IsNull" => Ok(ComparisonOperator::IsNull),
1283            "IsArray" => Ok(ComparisonOperator::IsArray),
1284            "IsObject" => Ok(ComparisonOperator::IsObject),
1285
1286            // Pattern & Format
1287            "IsEmail" => Ok(ComparisonOperator::IsEmail),
1288            "IsUrl" => Ok(ComparisonOperator::IsUrl),
1289            "IsUuid" => Ok(ComparisonOperator::IsUuid),
1290            "IsIso8601" => Ok(ComparisonOperator::IsIso8601),
1291            "IsJson" => Ok(ComparisonOperator::IsJson),
1292            "MatchesRegex" => Ok(ComparisonOperator::MatchesRegex),
1293
1294            // Numeric Range
1295            "InRange" => Ok(ComparisonOperator::InRange),
1296            "NotInRange" => Ok(ComparisonOperator::NotInRange),
1297            "IsPositive" => Ok(ComparisonOperator::IsPositive),
1298            "IsNegative" => Ok(ComparisonOperator::IsNegative),
1299            "IsZero" => Ok(ComparisonOperator::IsZero),
1300
1301            // Collection/Array
1302            "ContainsAll" => Ok(ComparisonOperator::ContainsAll),
1303            "ContainsAny" => Ok(ComparisonOperator::ContainsAny),
1304            "ContainsNone" => Ok(ComparisonOperator::ContainsNone),
1305            "IsEmpty" => Ok(ComparisonOperator::IsEmpty),
1306            "IsNotEmpty" => Ok(ComparisonOperator::IsNotEmpty),
1307            "HasUniqueItems" => Ok(ComparisonOperator::HasUniqueItems),
1308            "SequenceMatches" => Ok(ComparisonOperator::SequenceMatches),
1309
1310            // String
1311            "IsAlphabetic" => Ok(ComparisonOperator::IsAlphabetic),
1312            "IsAlphanumeric" => Ok(ComparisonOperator::IsAlphanumeric),
1313            "IsLowerCase" => Ok(ComparisonOperator::IsLowerCase),
1314            "IsUpperCase" => Ok(ComparisonOperator::IsUpperCase),
1315            "ContainsWord" => Ok(ComparisonOperator::ContainsWord),
1316
1317            // Tolerance
1318            "ApproximatelyEquals" => Ok(ComparisonOperator::ApproximatelyEquals),
1319
1320            _ => Err(TypeError::InvalidCompressionTypeError),
1321        }
1322    }
1323}
1324
1325impl ComparisonOperator {
1326    pub fn as_str(&self) -> &str {
1327        match self {
1328            ComparisonOperator::Equals => "Equals",
1329            ComparisonOperator::NotEqual => "NotEqual",
1330            ComparisonOperator::GreaterThan => "GreaterThan",
1331            ComparisonOperator::GreaterThanOrEqual => "GreaterThanOrEqual",
1332            ComparisonOperator::LessThan => "LessThan",
1333            ComparisonOperator::LessThanOrEqual => "LessThanOrEqual",
1334            ComparisonOperator::Contains => "Contains",
1335            ComparisonOperator::NotContains => "NotContains",
1336            ComparisonOperator::StartsWith => "StartsWith",
1337            ComparisonOperator::EndsWith => "EndsWith",
1338            ComparisonOperator::Matches => "Matches",
1339            ComparisonOperator::HasLengthEqual => "HasLengthEqual",
1340            ComparisonOperator::HasLengthGreaterThan => "HasLengthGreaterThan",
1341            ComparisonOperator::HasLengthLessThan => "HasLengthLessThan",
1342            ComparisonOperator::HasLengthGreaterThanOrEqual => "HasLengthGreaterThanOrEqual",
1343            ComparisonOperator::HasLengthLessThanOrEqual => "HasLengthLessThanOrEqual",
1344
1345            // Type Validation
1346            ComparisonOperator::IsNumeric => "IsNumeric",
1347            ComparisonOperator::IsString => "IsString",
1348            ComparisonOperator::IsBoolean => "IsBoolean",
1349            ComparisonOperator::IsNull => "IsNull",
1350            ComparisonOperator::IsArray => "IsArray",
1351            ComparisonOperator::IsObject => "IsObject",
1352
1353            // Pattern & Format
1354            ComparisonOperator::IsEmail => "IsEmail",
1355            ComparisonOperator::IsUrl => "IsUrl",
1356            ComparisonOperator::IsUuid => "IsUuid",
1357            ComparisonOperator::IsIso8601 => "IsIso8601",
1358            ComparisonOperator::IsJson => "IsJson",
1359            ComparisonOperator::MatchesRegex => "MatchesRegex",
1360
1361            // Numeric Range
1362            ComparisonOperator::InRange => "InRange",
1363            ComparisonOperator::NotInRange => "NotInRange",
1364            ComparisonOperator::IsPositive => "IsPositive",
1365            ComparisonOperator::IsNegative => "IsNegative",
1366            ComparisonOperator::IsZero => "IsZero",
1367
1368            // Collection/Array
1369            ComparisonOperator::ContainsAll => "ContainsAll",
1370            ComparisonOperator::ContainsAny => "ContainsAny",
1371            ComparisonOperator::ContainsNone => "ContainsNone",
1372            ComparisonOperator::IsEmpty => "IsEmpty",
1373            ComparisonOperator::IsNotEmpty => "IsNotEmpty",
1374            ComparisonOperator::HasUniqueItems => "HasUniqueItems",
1375            ComparisonOperator::SequenceMatches => "SequenceMatches",
1376
1377            // String
1378            ComparisonOperator::IsAlphabetic => "IsAlphabetic",
1379            ComparisonOperator::IsAlphanumeric => "IsAlphanumeric",
1380            ComparisonOperator::IsLowerCase => "IsLowerCase",
1381            ComparisonOperator::IsUpperCase => "IsUpperCase",
1382            ComparisonOperator::ContainsWord => "ContainsWord",
1383
1384            // Tolerance
1385            ComparisonOperator::ApproximatelyEquals => "ApproximatelyEquals",
1386        }
1387    }
1388}
1389
/// A value taking part in an assertion, restricted to the shapes that
/// `assertion_value_from_py` can build from Python objects.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AssertionValue {
    /// Text value.
    String(String),
    /// Floating-point value.
    Number(f64),
    /// Signed 64-bit integer value.
    Integer(i64),
    /// Boolean value.
    Boolean(bool),
    /// Ordered list of nested values.
    List(Vec<AssertionValue>),
    /// Absence of a value (Python `None`, JSON `null`).
    Null(),
}
1400
1401impl AssertionValue {
1402    pub fn to_actual(self, comparison: &ComparisonOperator) -> AssertionValue {
1403        match comparison {
1404            ComparisonOperator::HasLengthEqual
1405            | ComparisonOperator::HasLengthGreaterThan
1406            | ComparisonOperator::HasLengthLessThan
1407            | ComparisonOperator::HasLengthGreaterThanOrEqual
1408            | ComparisonOperator::HasLengthLessThanOrEqual => match self {
1409                AssertionValue::List(arr) => AssertionValue::Integer(arr.len() as i64),
1410                AssertionValue::String(s) => AssertionValue::Integer(s.chars().count() as i64),
1411                _ => self,
1412            },
1413            _ => self,
1414        }
1415    }
1416
1417    pub fn to_serde_value(&self) -> Value {
1418        match self {
1419            AssertionValue::String(s) => Value::String(s.clone()),
1420            AssertionValue::Number(n) => Value::Number(serde_json::Number::from_f64(*n).unwrap()),
1421            AssertionValue::Integer(i) => Value::Number(serde_json::Number::from(*i)),
1422            AssertionValue::Boolean(b) => Value::Bool(*b),
1423            AssertionValue::List(arr) => {
1424                let json_arr: Vec<Value> = arr.iter().map(|v| v.to_serde_value()).collect();
1425                Value::Array(json_arr)
1426            }
1427            AssertionValue::Null() => Value::Null,
1428        }
1429    }
1430}
1431/// Converts a PyAny value to an AssertionValue
1432///
1433/// # Errors
1434///
1435/// Returns `EvaluationError::UnsupportedType` if the Python type cannot be converted
1436/// to an `AssertionValue`.
1437pub fn assertion_value_from_py(value: &Bound<'_, PyAny>) -> Result<AssertionValue, TypeError> {
1438    // Check None first as it's a common case
1439    if value.is_none() {
1440        return Ok(AssertionValue::Null());
1441    }
1442
1443    // Check bool before int (bool is subclass of int in Python)
1444    if value.is_instance_of::<PyBool>() {
1445        return Ok(AssertionValue::Boolean(value.extract()?));
1446    }
1447
1448    if value.is_instance_of::<PyString>() {
1449        return Ok(AssertionValue::String(value.extract()?));
1450    }
1451
1452    if value.is_instance_of::<PyInt>() {
1453        return Ok(AssertionValue::Integer(value.extract()?));
1454    }
1455
1456    if value.is_instance_of::<PyFloat>() {
1457        return Ok(AssertionValue::Number(value.extract()?));
1458    }
1459
1460    if value.is_instance_of::<PyList>() {
1461        // For list, we need to iterate, so one downcast is fine
1462        let list = value.cast::<PyList>()?; // Safe: we just checked
1463        let assertion_list = list
1464            .iter()
1465            .map(|item| assertion_value_from_py(&item))
1466            .collect::<Result<Vec<_>, _>>()?;
1467        return Ok(AssertionValue::List(assertion_list));
1468    }
1469
1470    // Return error for unsupported types
1471    Err(TypeError::UnsupportedType(
1472        value.get_type().name()?.to_string(),
1473    ))
1474}
1475
/// Discriminates the kinds of evaluation task.
///
/// Exposed to Python as a class with equality support. The textual name of each variant
/// (used by `as_str`, `Display` and `FromStr`) is exactly the variant identifier.
/// When a `TasksFile` is deserialized, only `Assertion`, `LLMJudge` and `TraceAssertion`
/// are accepted; the other variants are rejected there.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum EvaluationTaskType {
    Assertion,
    LLMJudge,
    Conditional,
    HumanValidation,
    TraceAssertion,
}
1485
1486impl Display for EvaluationTaskType {
1487    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1488        let task_type_str = match self {
1489            EvaluationTaskType::Assertion => "Assertion",
1490            EvaluationTaskType::LLMJudge => "LLMJudge",
1491            EvaluationTaskType::Conditional => "Conditional",
1492            EvaluationTaskType::HumanValidation => "HumanValidation",
1493            EvaluationTaskType::TraceAssertion => "TraceAssertion",
1494        };
1495        write!(f, "{}", task_type_str)
1496    }
1497}
1498
1499impl FromStr for EvaluationTaskType {
1500    type Err = TypeError;
1501
1502    fn from_str(s: &str) -> Result<Self, Self::Err> {
1503        match s {
1504            "Assertion" => Ok(EvaluationTaskType::Assertion),
1505            "LLMJudge" => Ok(EvaluationTaskType::LLMJudge),
1506            "Conditional" => Ok(EvaluationTaskType::Conditional),
1507            "HumanValidation" => Ok(EvaluationTaskType::HumanValidation),
1508            "TraceAssertion" => Ok(EvaluationTaskType::TraceAssertion),
1509            _ => Err(TypeError::InvalidEvalType(s.to_string())),
1510        }
1511    }
1512}
1513
1514impl EvaluationTaskType {
1515    pub fn as_str(&self) -> &str {
1516        match self {
1517            EvaluationTaskType::Assertion => "Assertion",
1518            EvaluationTaskType::LLMJudge => "LLMJudge",
1519            EvaluationTaskType::Conditional => "Conditional",
1520            EvaluationTaskType::HumanValidation => "HumanValidation",
1521            EvaluationTaskType::TraceAssertion => "TraceAssertion",
1522        }
1523    }
1524}
1525
/// A collection of task configurations loaded from a file and exposed to Python.
///
/// Its Python methods provide iteration, indexing/slicing, `len()` and `str()`.
#[pyclass]
#[derive(Debug, Serialize)]
pub struct TasksFile {
    /// The parsed tasks, in file order.
    pub tasks: Vec<TaskConfig>,

    /// Cursor used by `__next__`: position of the next task to hand out.
    /// Set to 0 when the file is deserialized.
    #[serde(default)]
    index: usize,
}
1534
1535#[pymethods]
1536impl TasksFile {
1537    #[staticmethod]
1538    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1539        let tasks_file: TasksFile = deserialize_from_path(path)?;
1540        Ok(tasks_file)
1541    }
1542
1543    pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
1544        slf
1545    }
1546
1547    pub fn __next__<'py>(
1548        mut slf: PyRefMut<'py, Self>,
1549    ) -> Result<Option<Bound<'py, PyAny>>, TypeError> {
1550        let py = slf.py();
1551        if slf.index < slf.tasks.len() {
1552            let task = slf.tasks[slf.index].clone().into_bound_py_any(py)?;
1553            slf.index += 1;
1554            Ok(Some(task))
1555        } else {
1556            Ok(None)
1557        }
1558    }
1559
1560    fn __getitem__<'py>(
1561        &self,
1562        py: Python<'py>,
1563        index: &Bound<'py, PyAny>,
1564    ) -> Result<Bound<'py, PyAny>, TypeError> {
1565        if let Ok(i) = index.extract::<isize>() {
1566            let len = self.tasks.len() as isize;
1567            let actual_index = if i < 0 { len + i } else { i };
1568
1569            if actual_index < 0 || actual_index >= len {
1570                return Err(TypeError::IndexOutOfBounds {
1571                    index: i,
1572                    length: self.tasks.len(),
1573                });
1574            }
1575
1576            Ok(self.tasks[actual_index as usize]
1577                .clone()
1578                .into_bound_py_any(py)?)
1579        } else if let Ok(slice) = index.cast::<PySlice>() {
1580            let indices = slice.indices(self.tasks.len() as isize)?;
1581            let result = PyList::empty(py);
1582
1583            let mut i = indices.start;
1584            while (indices.step > 0 && i < indices.stop) || (indices.step < 0 && i > indices.stop) {
1585                result.append(self.tasks[i as usize].clone().into_bound_py_any(py)?)?;
1586                i += indices.step;
1587            }
1588
1589            Ok(result.into_bound_py_any(py)?)
1590        } else {
1591            Err(TypeError::IndexOrSliceExpected)
1592        }
1593    }
1594
1595    fn __len__(&self) -> usize {
1596        self.tasks.len()
1597    }
1598
1599    fn __str__(&self) -> String {
1600        PyHelperFuncs::__str__(self)
1601    }
1602}
1603
/// One parsed entry of a `TasksFile`, tagged by the kind of task it holds.
#[derive(Debug, Serialize, Clone)]
pub enum TaskConfig {
    /// An assertion task.
    Assertion(AssertionTask),
    /// An LLM judge task, held behind a `Box`.
    #[serde(rename = "LLMJudge")]
    LLMJudge(Box<LLMJudgeTask>),
    /// A trace assertion task.
    TraceAssertion(TraceAssertionTask),
}
1611
1612impl TaskConfig {
1613    fn into_bound_py_any<'py>(self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1614        match self {
1615            TaskConfig::Assertion(task) => Ok(task.into_bound_py_any(py)?),
1616            TaskConfig::LLMJudge(task) => Ok(task.into_bound_py_any(py)?),
1617            TaskConfig::TraceAssertion(task) => Ok(task.into_bound_py_any(py)?),
1618        }
1619    }
1620}
1621
1622impl<'de> Deserialize<'de> for TasksFile {
1623    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1624    where
1625        D: serde::Deserializer<'de>,
1626    {
1627        #[derive(Deserialize)]
1628        #[serde(untagged)]
1629        enum TasksFileRaw {
1630            Direct(Vec<TaskConfigRaw>),
1631            Wrapped { tasks: Vec<TaskConfigRaw> },
1632        }
1633
1634        #[derive(Deserialize)]
1635        struct TaskConfigRaw {
1636            task_type: EvaluationTaskType,
1637            #[serde(flatten)]
1638            data: Value,
1639        }
1640
1641        let raw = TasksFileRaw::deserialize(deserializer)?;
1642        let raw_tasks = match raw {
1643            TasksFileRaw::Direct(tasks) => tasks,
1644            TasksFileRaw::Wrapped { tasks } => tasks,
1645        };
1646
1647        let mut tasks = Vec::new();
1648
1649        for task_raw in raw_tasks {
1650            let task_config = match task_raw.task_type {
1651                EvaluationTaskType::Assertion => {
1652                    let mut task: AssertionTask =
1653                        serde_json::from_value(task_raw.data).map_err(|e| {
1654                            error!("Failed to deserialize AssertionTask: {}", e);
1655                            serde::de::Error::custom(e.to_string())
1656                        })?;
1657                    task.task_type = EvaluationTaskType::Assertion;
1658                    TaskConfig::Assertion(task)
1659                }
1660                EvaluationTaskType::LLMJudge => {
1661                    let mut task: LLMJudgeTask =
1662                        serde_json::from_value(task_raw.data).map_err(|e| {
1663                            error!("Failed to deserialize LLMJudgeTask: {}", e);
1664                            serde::de::Error::custom(e.to_string())
1665                        })?;
1666                    task.task_type = EvaluationTaskType::LLMJudge;
1667                    TaskConfig::LLMJudge(Box::new(task))
1668                }
1669                EvaluationTaskType::TraceAssertion => {
1670                    let mut task: TraceAssertionTask = serde_json::from_value(task_raw.data)
1671                        .map_err(|e| {
1672                            error!("Failed to deserialize TraceAssertionTask: {}", e);
1673                            serde::de::Error::custom(e.to_string())
1674                        })?;
1675                    task.task_type = EvaluationTaskType::TraceAssertion;
1676                    TaskConfig::TraceAssertion(task)
1677                }
1678                _ => {
1679                    return Err(serde::de::Error::custom(format!(
1680                        "Unknown task_type: {}",
1681                        task_raw.task_type
1682                    )))
1683                }
1684            };
1685            tasks.push(task_config);
1686        }
1687
1688        Ok(TasksFile { tasks, index: 0 })
1689    }
1690}