scouter_types/genai/
eval.rs

1use crate::error::TypeError;
2use crate::genai::traits::TaskAccessor;
3use crate::PyHelperFuncs;
4use core::fmt::Debug;
5use potato_head::prompt_types::Prompt;
6use pyo3::prelude::*;
7use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString};
8use pythonize::{depythonize, pythonize};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::fmt::Display;
12use std::str::FromStr;
13
/// Outcome of evaluating a single task.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssertionResult {
    /// Whether the check succeeded.
    pub passed: bool,
    /// The value that was observed when the check ran.
    pub actual: Value,
    /// Free-form message describing the outcome.
    pub message: String,
    /// The value the observation was compared against.
    pub expected: Value,
}
21
22impl AssertionResult {
23    pub fn new(passed: bool, actual: Value, message: String, expected: Value) -> Self {
24        Self {
25            passed,
26            actual,
27            message,
28            expected,
29        }
30    }
31    pub fn to_metric_value(&self) -> f64 {
32        if self.passed {
33            1.0
34        } else {
35            0.0
36        }
37    }
38}
39
/// A deterministic check: a value taken from the runtime context (optionally
/// located via `field_path`) is compared against `expected_value` using `operator`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssertionTask {
    /// Task identifier; the Python constructor lower-cases it.
    #[pyo3(get, set)]
    pub id: String,

    /// Path of the field to check. The constructor examples use dotted paths
    /// (e.g. `"response.user.age"`) and leave it unset for non-map contexts.
    #[pyo3(get, set)]
    pub field_path: Option<String>,

    /// Operator used for the comparison.
    #[pyo3(get, set)]
    pub operator: ComparisonOperator,

    /// Expected value stored as JSON. Exposed to Python through the
    /// `get_expected_value` getter (property `expected_value`), not `#[pyo3(get)]`.
    pub expected_value: Value,

    /// Optional human-readable description.
    #[pyo3(get, set)]
    pub description: Option<String>,

    /// Ids of the tasks this task depends on.
    #[pyo3(get, set)]
    pub depends_on: Vec<String>,

    /// Always `EvaluationTaskType::Assertion` when built through the constructor.
    pub task_type: EvaluationTaskType,

    /// Filled in by `TaskAccessor::add_result`; omitted from serialized output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<AssertionResult>,

    /// Flag set from the constructor's `condition` argument (defaults to `false`).
    /// Not exposed as a Python property.
    /// NOTE(review): its effect is decided by the evaluation engine elsewhere — confirm semantics.
    pub condition: bool,
}
67
68#[pymethods]
69impl AssertionTask {
70    #[new]
71    /// Creates a new AssertionTask
72    /// # Examples
73    /// ```python
74    ///
75    /// # assumed passed context at runtime
76    /// # context = {
77    /// #     "response": {
78    /// #         "user": {
79    /// #             "age": 25
80    /// #         }
81    /// #     }
82    /// # }
83    ///
84    /// task = AssertionTask(
85    ///     id="Check User Age",
86    ///     field_path="response.user.age",
87    ///     operator=ComparisonOperator.GREATER_THAN,
88    ///     expected_value=18,
89    ///     description="Check if user is an adult"
90    /// )
91    ///
92    /// # assumed passed context at runtime
93    /// # context = {
94    /// #     "user": {
95    /// #         "age": 25
96    /// #     }
97    /// # }
98    ///
99    /// task = AssertionTask(
100    ///     id="Check User Age",
101    ///     field_path="user.age",
102    ///     operator=ComparisonOperator.GREATER_THAN,
103    ///     expected_value=18,
104    ///     description="Check if user is an adult"
105    /// )
106    ///
107    ///  /// # assume non-map context at runtime
108    /// # context = 25
109    ///
110    /// task = AssertionTask(
111    ///     id="Check User Age",
112    ///     operator=ComparisonOperator.GREATER_THAN,
113    ///     expected_value=18,
114    ///     description="Check if user is an adult"
115    /// )
116    /// ```
117    /// # Arguments
118    /// * `field_path`: The path to the field to be asserted
119    /// * `operator`: The comparison operator to use
120    /// * `expected_value`: The expected value for the assertion
121    /// * `description`: Optional description for the assertion
122    /// # Returns
123    /// A new AssertionTask object
124    #[pyo3(signature = (id, field_path, expected_value, operator, description=None, depends_on=None, condition=None))]
125    pub fn new(
126        id: String,
127        field_path: Option<String>,
128        expected_value: &Bound<'_, PyAny>,
129        operator: ComparisonOperator,
130        description: Option<String>,
131        depends_on: Option<Vec<String>>,
132        condition: Option<bool>,
133    ) -> Result<Self, TypeError> {
134        let expected_value = depythonize(expected_value)?;
135        let condition = condition.unwrap_or(false);
136
137        Ok(Self {
138            id: id.to_lowercase(),
139            field_path,
140            operator,
141            expected_value,
142            description,
143            task_type: EvaluationTaskType::Assertion,
144            depends_on: depends_on.unwrap_or_default(),
145            result: None,
146            condition,
147        })
148    }
149
150    #[getter]
151    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
152        let py_value = pythonize(py, &self.expected_value)?;
153        Ok(py_value)
154    }
155}
156
157impl TaskAccessor for AssertionTask {
158    fn field_path(&self) -> Option<&str> {
159        self.field_path.as_deref()
160    }
161
162    fn id(&self) -> &str {
163        &self.id
164    }
165
166    fn operator(&self) -> &ComparisonOperator {
167        &self.operator
168    }
169
170    fn task_type(&self) -> &EvaluationTaskType {
171        &self.task_type
172    }
173
174    fn expected_value(&self) -> &Value {
175        &self.expected_value
176    }
177
178    fn depends_on(&self) -> &[String] {
179        &self.depends_on
180    }
181
182    fn add_result(&mut self, result: AssertionResult) {
183        self.result = Some(result);
184    }
185}
186
/// Extra queries on values used when evaluating comparisons.
pub trait ValueExt {
    /// Returns whether the value should be treated as truthy.
    fn is_truthy(&self) -> bool;

    /// Returns the value as an `f64` when it is numeric, for numeric comparisons.
    fn as_numeric(&self) -> Option<f64>;

    /// Returns the value's length for the `HasLength*` comparisons, if it has one.
    fn to_length(&self) -> Option<i64>;
}
197
198impl ValueExt for Value {
199    fn to_length(&self) -> Option<i64> {
200        match self {
201            Value::Array(arr) => Some(arr.len() as i64),
202            Value::String(s) => Some(s.chars().count() as i64),
203            Value::Object(obj) => Some(obj.len() as i64),
204            _ => None,
205        }
206    }
207
208    fn as_numeric(&self) -> Option<f64> {
209        match self {
210            Value::Number(n) => n.as_f64(),
211            _ => None,
212        }
213    }
214
215    fn is_truthy(&self) -> bool {
216        match self {
217            Value::Null => false,
218            Value::Bool(b) => *b,
219            Value::Number(n) => n.as_f64() != Some(0.0),
220            Value::String(s) => !s.is_empty(),
221            Value::Array(arr) => !arr.is_empty(),
222            Value::Object(obj) => !obj.is_empty(),
223        }
224    }
225}
226
/// Primary class for defining an LLM as a Judge in evaluation workflows.
///
/// The judge's answer (presumably obtained by running `prompt` — confirm against
/// the executor) is compared against `expected_value` using `operator`.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMJudgeTask {
    /// Task identifier; the constructors lower-case it.
    #[pyo3(get, set)]
    pub id: String,

    /// Prompt used for the evaluation.
    #[pyo3(get)]
    pub prompt: Prompt,

    /// Optional field path to extract from the context for evaluation.
    #[pyo3(get)]
    pub field_path: Option<String>,

    /// Expected judgement stored as JSON. Exposed to Python through the
    /// `get_expected_value` getter (property `expected_value`).
    pub expected_value: Value,

    /// Operator used to compare the judgement with `expected_value`.
    #[pyo3(get)]
    pub operator: ComparisonOperator,

    /// Always `EvaluationTaskType::LLMJudge` when built by the constructors.
    pub task_type: EvaluationTaskType,

    /// Ids of the tasks this task depends on.
    #[pyo3(get, set)]
    pub depends_on: Vec<String>,

    /// Maximum number of retries; the constructors default it to `Some(3)`.
    #[pyo3(get, set)]
    pub max_retries: Option<u32>,

    /// Filled in by `TaskAccessor::add_result`; omitted from serialized output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<AssertionResult>,

    /// Optional description; not exposed as a Python property.
    pub description: Option<String>,

    /// Flag set from the constructor's `condition` argument (defaults to `false`).
    /// NOTE(review): its effect is decided by the evaluation engine elsewhere — confirm semantics.
    pub condition: bool,
}
260
261#[pymethods]
262impl LLMJudgeTask {
263    /// Creates a new LLMJudgeTask
264    /// # Examples
265    /// ```python
266    /// task = LLMJudgeTask(
267    ///     id="Sentiment Analysis Judge",
268    ///     prompt=prompt_object,
269    ///     expected_value="Positive",
270    ///     operator=ComparisonOperator.EQUALS
271    /// )
272    /// # Arguments
273    /// * `id: The id of the judge task
274    /// * `prompt`: The prompt object to be used for evaluation
275    /// * `expected_value`: The expected value for the judgement
276    /// * `field_path`: Optional field path to extract from the context for evaluation
277    /// * `operator`: The comparison operator to use
278    /// * `depends_on`: Optional list of task IDs this task depends on
279    /// * `max_retries`: Optional maximum number of retries for this task (defaults to 3 if not provided)
280    /// # Returns
281    /// A new LLMJudgeTask object
282    #[new]
283    #[pyo3(signature = (id, prompt, expected_value,  field_path,operator, description=None, depends_on=None, max_retries=None, condition=None))]
284    #[allow(clippy::too_many_arguments)]
285    pub fn new(
286        id: &str,
287        prompt: Prompt,
288        expected_value: &Bound<'_, PyAny>,
289        field_path: Option<String>,
290        operator: ComparisonOperator,
291        description: Option<String>,
292        depends_on: Option<Vec<String>>,
293        max_retries: Option<u32>,
294        condition: Option<bool>,
295    ) -> Result<Self, TypeError> {
296        let expected_value = depythonize(expected_value)?;
297
298        Ok(Self {
299            id: id.to_lowercase(),
300            prompt,
301            expected_value,
302            operator,
303            task_type: EvaluationTaskType::LLMJudge,
304            depends_on: depends_on.unwrap_or_default(),
305            max_retries: max_retries.or(Some(3)),
306            field_path,
307            result: None,
308            description,
309            condition: condition.unwrap_or(false),
310        })
311    }
312
313    pub fn __str__(&self) -> String {
314        // serialize the struct to a string
315        PyHelperFuncs::__str__(self)
316    }
317
318    #[getter]
319    pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
320        let py_value = pythonize(py, &self.expected_value)?;
321        Ok(py_value)
322    }
323}
324
325impl LLMJudgeTask {
326    /// Creates a new LLMJudgeTask with Rust types
327    /// # Arguments
328    /// * `id: The id of the judge task
329    /// * `prompt`: The prompt object to be used for evaluation
330    /// * `expected_value`: The expected value for the judgement
331    /// * `field_path`: Optional field path to extract from the context for evaluation
332    /// * `operator`: The comparison operator to use
333    /// * `depends_on`: Optional list of task IDs this task depends on
334    /// * `max_retries`: Optional maximum number of retries for this task (defaults to 3 if not provided)
335    /// # Returns
336    /// A new LLMJudgeTask object
337    #[allow(clippy::too_many_arguments)]
338    pub fn new_rs(
339        id: &str,
340        prompt: Prompt,
341        expected_value: Value,
342        field_path: Option<String>,
343        operator: ComparisonOperator,
344        depends_on: Option<Vec<String>>,
345        max_retries: Option<u32>,
346        description: Option<String>,
347        condition: Option<bool>,
348    ) -> Self {
349        Self {
350            id: id.to_lowercase(),
351            prompt,
352            expected_value,
353            operator,
354            task_type: EvaluationTaskType::LLMJudge,
355            depends_on: depends_on.unwrap_or_default(),
356            max_retries: max_retries.or(Some(3)),
357            field_path,
358            result: None,
359            description,
360            condition: condition.unwrap_or(false),
361        }
362    }
363}
364
365impl TaskAccessor for LLMJudgeTask {
366    fn field_path(&self) -> Option<&str> {
367        self.field_path.as_deref()
368    }
369
370    fn id(&self) -> &str {
371        &self.id
372    }
373    fn task_type(&self) -> &EvaluationTaskType {
374        &self.task_type
375    }
376
377    fn operator(&self) -> &ComparisonOperator {
378        &self.operator
379    }
380
381    fn expected_value(&self) -> &Value {
382        &self.expected_value
383    }
384    fn depends_on(&self) -> &[String] {
385        &self.depends_on
386    }
387    fn add_result(&mut self, result: AssertionResult) {
388        self.result = Some(result);
389    }
390}
391
/// A task of any supported kind, allowing heterogeneous tasks to be stored
/// together (see `EvaluationTasks`).
///
/// The payloads are boxed — presumably to keep the enum small, since both task
/// structs are fairly large (NOTE(review): confirm intent).
#[derive(Debug, Clone)]
pub enum EvaluationTask {
    Assertion(Box<AssertionTask>),
    LLMJudge(Box<LLMJudgeTask>),
}
397
398impl TaskAccessor for EvaluationTask {
399    fn field_path(&self) -> Option<&str> {
400        match self {
401            EvaluationTask::Assertion(t) => t.field_path(),
402            EvaluationTask::LLMJudge(t) => t.field_path(),
403        }
404    }
405
406    fn id(&self) -> &str {
407        match self {
408            EvaluationTask::Assertion(t) => t.id(),
409            EvaluationTask::LLMJudge(t) => t.id(),
410        }
411    }
412
413    fn task_type(&self) -> &EvaluationTaskType {
414        match self {
415            EvaluationTask::Assertion(t) => t.task_type(),
416            EvaluationTask::LLMJudge(t) => t.task_type(),
417        }
418    }
419
420    fn operator(&self) -> &ComparisonOperator {
421        match self {
422            EvaluationTask::Assertion(t) => t.operator(),
423            EvaluationTask::LLMJudge(t) => t.operator(),
424        }
425    }
426
427    fn expected_value(&self) -> &Value {
428        match self {
429            EvaluationTask::Assertion(t) => t.expected_value(),
430            EvaluationTask::LLMJudge(t) => t.expected_value(),
431        }
432    }
433
434    fn depends_on(&self) -> &[String] {
435        match self {
436            EvaluationTask::Assertion(t) => t.depends_on(),
437            EvaluationTask::LLMJudge(t) => t.depends_on(),
438        }
439    }
440
441    fn add_result(&mut self, result: AssertionResult) {
442        match self {
443            EvaluationTask::Assertion(t) => t.add_result(result),
444            EvaluationTask::LLMJudge(t) => t.add_result(result),
445        }
446    }
447}
448
449pub struct EvaluationTasks(Vec<EvaluationTask>);
450
451impl EvaluationTasks {
452    /// Creates a new empty builder
453    pub fn new() -> Self {
454        Self(Vec::new())
455    }
456
457    /// Generic method that accepts anything implementing Into<EvaluationTask>
458    pub fn add_task(mut self, task: impl Into<EvaluationTask>) -> Self {
459        self.0.push(task.into());
460        self
461    }
462
463    /// Builds and returns the Vec<EvaluationTask>
464    pub fn build(self) -> Vec<EvaluationTask> {
465        self.0
466    }
467}
468
469// Implement From trait for automatic conversion
470impl From<AssertionTask> for EvaluationTask {
471    fn from(task: AssertionTask) -> Self {
472        EvaluationTask::Assertion(Box::new(task))
473    }
474}
475
476impl From<LLMJudgeTask> for EvaluationTask {
477    fn from(task: LLMJudgeTask) -> Self {
478        EvaluationTask::LLMJudge(Box::new(task))
479    }
480}
481
482impl Default for EvaluationTasks {
483    fn default() -> Self {
484        Self::new()
485    }
486}
487
/// Operator an evaluation task uses to check a value, typically against the
/// task's `expected_value`.
///
/// The string form of every variant (see `ComparisonOperator::as_str` and the
/// `FromStr` impl) is exactly its Rust name, e.g. `"GreaterThanOrEqual"`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ComparisonOperator {
    // Equality, ordering, substring/prefix/suffix, matching and length operators
    Equals,
    NotEqual,
    GreaterThan,
    GreaterThanOrEqual,
    LessThan,
    LessThanOrEqual,
    Contains,
    NotContains,
    StartsWith,
    EndsWith,
    Matches,
    HasLengthGreaterThan,
    HasLengthLessThan,
    HasLengthEqual,
    HasLengthGreaterThanOrEqual,
    HasLengthLessThanOrEqual,

    // Type Validation Operators
    IsNumeric,
    IsString,
    IsBoolean,
    IsNull,
    IsArray,
    IsObject,

    // Pattern & Format Validators
    IsEmail,
    IsUrl,
    IsUuid,
    IsIso8601,
    IsJson,
    MatchesRegex,

    // Numeric Range Operators
    InRange,
    NotInRange,
    IsPositive,
    IsNegative,
    IsZero,

    // Collection/Array Operators
    ContainsAll,
    ContainsAny,
    ContainsNone,
    IsEmpty,
    IsNotEmpty,
    HasUniqueItems,

    // String Operators
    IsAlphabetic,
    IsAlphanumeric,
    IsLowerCase,
    IsUpperCase,
    ContainsWord,

    // Comparison with Tolerance
    ApproximatelyEquals,
}
550
551impl Display for ComparisonOperator {
552    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553        write!(f, "{}", self.as_str())
554    }
555}
556
557impl FromStr for ComparisonOperator {
558    type Err = TypeError;
559
560    fn from_str(s: &str) -> Result<Self, Self::Err> {
561        match s {
562            "Equals" => Ok(ComparisonOperator::Equals),
563            "NotEqual" => Ok(ComparisonOperator::NotEqual),
564            "GreaterThan" => Ok(ComparisonOperator::GreaterThan),
565            "GreaterThanOrEqual" => Ok(ComparisonOperator::GreaterThanOrEqual),
566            "LessThan" => Ok(ComparisonOperator::LessThan),
567            "LessThanOrEqual" => Ok(ComparisonOperator::LessThanOrEqual),
568            "Contains" => Ok(ComparisonOperator::Contains),
569            "NotContains" => Ok(ComparisonOperator::NotContains),
570            "StartsWith" => Ok(ComparisonOperator::StartsWith),
571            "EndsWith" => Ok(ComparisonOperator::EndsWith),
572            "Matches" => Ok(ComparisonOperator::Matches),
573            "HasLengthEqual" => Ok(ComparisonOperator::HasLengthEqual),
574            "HasLengthGreaterThan" => Ok(ComparisonOperator::HasLengthGreaterThan),
575            "HasLengthLessThan" => Ok(ComparisonOperator::HasLengthLessThan),
576            "HasLengthGreaterThanOrEqual" => Ok(ComparisonOperator::HasLengthGreaterThanOrEqual),
577            "HasLengthLessThanOrEqual" => Ok(ComparisonOperator::HasLengthLessThanOrEqual),
578
579            // Type Validation
580            "IsNumeric" => Ok(ComparisonOperator::IsNumeric),
581            "IsString" => Ok(ComparisonOperator::IsString),
582            "IsBoolean" => Ok(ComparisonOperator::IsBoolean),
583            "IsNull" => Ok(ComparisonOperator::IsNull),
584            "IsArray" => Ok(ComparisonOperator::IsArray),
585            "IsObject" => Ok(ComparisonOperator::IsObject),
586
587            // Pattern & Format
588            "IsEmail" => Ok(ComparisonOperator::IsEmail),
589            "IsUrl" => Ok(ComparisonOperator::IsUrl),
590            "IsUuid" => Ok(ComparisonOperator::IsUuid),
591            "IsIso8601" => Ok(ComparisonOperator::IsIso8601),
592            "IsJson" => Ok(ComparisonOperator::IsJson),
593            "MatchesRegex" => Ok(ComparisonOperator::MatchesRegex),
594
595            // Numeric Range
596            "InRange" => Ok(ComparisonOperator::InRange),
597            "NotInRange" => Ok(ComparisonOperator::NotInRange),
598            "IsPositive" => Ok(ComparisonOperator::IsPositive),
599            "IsNegative" => Ok(ComparisonOperator::IsNegative),
600            "IsZero" => Ok(ComparisonOperator::IsZero),
601
602            // Collection/Array
603            "ContainsAll" => Ok(ComparisonOperator::ContainsAll),
604            "ContainsAny" => Ok(ComparisonOperator::ContainsAny),
605            "ContainsNone" => Ok(ComparisonOperator::ContainsNone),
606            "IsEmpty" => Ok(ComparisonOperator::IsEmpty),
607            "IsNotEmpty" => Ok(ComparisonOperator::IsNotEmpty),
608            "HasUniqueItems" => Ok(ComparisonOperator::HasUniqueItems),
609
610            // String
611            "IsAlphabetic" => Ok(ComparisonOperator::IsAlphabetic),
612            "IsAlphanumeric" => Ok(ComparisonOperator::IsAlphanumeric),
613            "IsLowerCase" => Ok(ComparisonOperator::IsLowerCase),
614            "IsUpperCase" => Ok(ComparisonOperator::IsUpperCase),
615            "ContainsWord" => Ok(ComparisonOperator::ContainsWord),
616
617            // Tolerance
618            "ApproximatelyEquals" => Ok(ComparisonOperator::ApproximatelyEquals),
619
620            _ => Err(TypeError::InvalidCompressionTypeError),
621        }
622    }
623}
624
625impl ComparisonOperator {
626    pub fn as_str(&self) -> &str {
627        match self {
628            ComparisonOperator::Equals => "Equals",
629            ComparisonOperator::NotEqual => "NotEqual",
630            ComparisonOperator::GreaterThan => "GreaterThan",
631            ComparisonOperator::GreaterThanOrEqual => "GreaterThanOrEqual",
632            ComparisonOperator::LessThan => "LessThan",
633            ComparisonOperator::LessThanOrEqual => "LessThanOrEqual",
634            ComparisonOperator::Contains => "Contains",
635            ComparisonOperator::NotContains => "NotContains",
636            ComparisonOperator::StartsWith => "StartsWith",
637            ComparisonOperator::EndsWith => "EndsWith",
638            ComparisonOperator::Matches => "Matches",
639            ComparisonOperator::HasLengthEqual => "HasLengthEqual",
640            ComparisonOperator::HasLengthGreaterThan => "HasLengthGreaterThan",
641            ComparisonOperator::HasLengthLessThan => "HasLengthLessThan",
642            ComparisonOperator::HasLengthGreaterThanOrEqual => "HasLengthGreaterThanOrEqual",
643            ComparisonOperator::HasLengthLessThanOrEqual => "HasLengthLessThanOrEqual",
644
645            // Type Validation
646            ComparisonOperator::IsNumeric => "IsNumeric",
647            ComparisonOperator::IsString => "IsString",
648            ComparisonOperator::IsBoolean => "IsBoolean",
649            ComparisonOperator::IsNull => "IsNull",
650            ComparisonOperator::IsArray => "IsArray",
651            ComparisonOperator::IsObject => "IsObject",
652
653            // Pattern & Format
654            ComparisonOperator::IsEmail => "IsEmail",
655            ComparisonOperator::IsUrl => "IsUrl",
656            ComparisonOperator::IsUuid => "IsUuid",
657            ComparisonOperator::IsIso8601 => "IsIso8601",
658            ComparisonOperator::IsJson => "IsJson",
659            ComparisonOperator::MatchesRegex => "MatchesRegex",
660
661            // Numeric Range
662            ComparisonOperator::InRange => "InRange",
663            ComparisonOperator::NotInRange => "NotInRange",
664            ComparisonOperator::IsPositive => "IsPositive",
665            ComparisonOperator::IsNegative => "IsNegative",
666            ComparisonOperator::IsZero => "IsZero",
667
668            // Collection/Array
669            ComparisonOperator::ContainsAll => "ContainsAll",
670            ComparisonOperator::ContainsAny => "ContainsAny",
671            ComparisonOperator::ContainsNone => "ContainsNone",
672            ComparisonOperator::IsEmpty => "IsEmpty",
673            ComparisonOperator::IsNotEmpty => "IsNotEmpty",
674            ComparisonOperator::HasUniqueItems => "HasUniqueItems",
675
676            // String
677            ComparisonOperator::IsAlphabetic => "IsAlphabetic",
678            ComparisonOperator::IsAlphanumeric => "IsAlphanumeric",
679            ComparisonOperator::IsLowerCase => "IsLowerCase",
680            ComparisonOperator::IsUpperCase => "IsUpperCase",
681            ComparisonOperator::ContainsWord => "ContainsWord",
682
683            // Tolerance
684            ComparisonOperator::ApproximatelyEquals => "ApproximatelyEquals",
685        }
686    }
687}
688
/// A value extracted from Python for use in assertions.
///
/// Built from Python objects by `assertion_value_from_py` and converted to JSON
/// with `AssertionValue::to_serde_value`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AssertionValue {
    /// Text (a Python `str` in `assertion_value_from_py`).
    String(String),
    /// Floating-point number (a Python `float`).
    Number(f64),
    /// Integer (a Python `int` that fits in `i64`).
    Integer(i64),
    /// Boolean (a Python `bool`).
    Boolean(bool),
    /// List whose elements were converted recursively.
    List(Vec<AssertionValue>),
    /// Absence of a value (Python `None`).
    // Written as an empty tuple variant rather than a unit variant, presumably
    // because pyo3 complex enums reject unit variants — confirm before changing.
    Null(),
}
699
700impl AssertionValue {
701    pub fn to_actual(self, comparison: &ComparisonOperator) -> AssertionValue {
702        match comparison {
703            ComparisonOperator::HasLengthEqual
704            | ComparisonOperator::HasLengthGreaterThan
705            | ComparisonOperator::HasLengthLessThan
706            | ComparisonOperator::HasLengthGreaterThanOrEqual
707            | ComparisonOperator::HasLengthLessThanOrEqual => match self {
708                AssertionValue::List(arr) => AssertionValue::Integer(arr.len() as i64),
709                AssertionValue::String(s) => AssertionValue::Integer(s.chars().count() as i64),
710                _ => self,
711            },
712            _ => self,
713        }
714    }
715
716    pub fn to_serde_value(&self) -> Value {
717        match self {
718            AssertionValue::String(s) => Value::String(s.clone()),
719            AssertionValue::Number(n) => Value::Number(serde_json::Number::from_f64(*n).unwrap()),
720            AssertionValue::Integer(i) => Value::Number(serde_json::Number::from(*i)),
721            AssertionValue::Boolean(b) => Value::Bool(*b),
722            AssertionValue::List(arr) => {
723                let json_arr: Vec<Value> = arr.iter().map(|v| v.to_serde_value()).collect();
724                Value::Array(json_arr)
725            }
726            AssertionValue::Null() => Value::Null,
727        }
728    }
729}
730/// Converts a PyAny value to an AssertionValue
731///
732/// # Errors
733///
734/// Returns `EvaluationError::UnsupportedType` if the Python type cannot be converted
735/// to an `AssertionValue`.
736pub fn assertion_value_from_py(value: &Bound<'_, PyAny>) -> Result<AssertionValue, TypeError> {
737    // Check None first as it's a common case
738    if value.is_none() {
739        return Ok(AssertionValue::Null());
740    }
741
742    // Check bool before int (bool is subclass of int in Python)
743    if value.is_instance_of::<PyBool>() {
744        return Ok(AssertionValue::Boolean(value.extract()?));
745    }
746
747    if value.is_instance_of::<PyString>() {
748        return Ok(AssertionValue::String(value.extract()?));
749    }
750
751    if value.is_instance_of::<PyInt>() {
752        return Ok(AssertionValue::Integer(value.extract()?));
753    }
754
755    if value.is_instance_of::<PyFloat>() {
756        return Ok(AssertionValue::Number(value.extract()?));
757    }
758
759    if value.is_instance_of::<PyList>() {
760        // For list, we need to iterate, so one downcast is fine
761        let list = value.cast::<PyList>()?; // Safe: we just checked
762        let assertion_list = list
763            .iter()
764            .map(|item| assertion_value_from_py(&item))
765            .collect::<Result<Vec<_>, _>>()?;
766        return Ok(AssertionValue::List(assertion_list));
767    }
768
769    // Return error for unsupported types
770    Err(TypeError::UnsupportedType(
771        value.get_type().name()?.to_string(),
772    ))
773}
774
/// Kind of an evaluation task.
///
/// The string form of each variant (see `Display`, `FromStr` and `as_str`)
/// is its Rust name, e.g. `"LLMJudge"`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum EvaluationTaskType {
    /// Assigned by the `AssertionTask` constructor.
    Assertion,
    /// Assigned by the `LLMJudgeTask` constructors.
    LLMJudge,
    Conditional,
    HumanValidation,
}
783
784impl Display for EvaluationTaskType {
785    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
786        let task_type_str = match self {
787            EvaluationTaskType::Assertion => "Assertion",
788            EvaluationTaskType::LLMJudge => "LLMJudge",
789            EvaluationTaskType::Conditional => "Conditional",
790            EvaluationTaskType::HumanValidation => "HumanValidation",
791        };
792        write!(f, "{}", task_type_str)
793    }
794}
795
796impl FromStr for EvaluationTaskType {
797    type Err = TypeError;
798
799    fn from_str(s: &str) -> Result<Self, Self::Err> {
800        match s {
801            "Assertion" => Ok(EvaluationTaskType::Assertion),
802            "LLMJudge" => Ok(EvaluationTaskType::LLMJudge),
803            "Conditional" => Ok(EvaluationTaskType::Conditional),
804            "HumanValidation" => Ok(EvaluationTaskType::HumanValidation),
805            _ => Err(TypeError::InvalidEvalType(s.to_string())),
806        }
807    }
808}
809
810impl EvaluationTaskType {
811    pub fn as_str(&self) -> &str {
812        match self {
813            EvaluationTaskType::Assertion => "Assertion",
814            EvaluationTaskType::LLMJudge => "LLMJudge",
815            EvaluationTaskType::Conditional => "Conditional",
816            EvaluationTaskType::HumanValidation => "HumanValidation",
817        }
818    }
819}