1use crate::error::TypeError;
2use crate::genai::traits::TaskAccessor;
3use crate::PyHelperFuncs;
4use core::fmt::Debug;
5use potato_head::prompt_types::Prompt;
6use potato_head::Provider;
7use pyo3::prelude::*;
8use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PySlice, PyString};
9use pyo3::IntoPyObjectExt;
10use pythonize::{depythonize, pythonize};
11use serde::de::DeserializeOwned;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::fmt::Display;
16use std::path::PathBuf;
17use std::str::FromStr;
18use tracing::error;
19
20pub fn deserialize_from_path<T: DeserializeOwned>(path: PathBuf) -> Result<T, TypeError> {
21 let content = std::fs::read_to_string(&path)?;
22
23 let extension = path
24 .extension()
25 .and_then(|ext| ext.to_str())
26 .ok_or_else(|| TypeError::Error(format!("Invalid file path: {:?}", path)))?;
27
28 let item = match extension.to_lowercase().as_str() {
29 "json" => serde_json::from_str(&content)?,
30 "yaml" | "yml" => serde_yaml::from_str(&content)?,
31 _ => {
32 return Err(TypeError::Error(format!(
33 "Unsupported file extension '{}'. Expected .json, .yaml, or .yml",
34 extension
35 )))
36 }
37 };
38
39 Ok(item)
40}
41
42fn default_assertion_task_type() -> EvaluationTaskType {
44 EvaluationTaskType::Assertion
45}
46
47fn default_trace_assertion_task_type() -> EvaluationTaskType {
48 EvaluationTaskType::TraceAssertion
49}
50
51fn default_agent_assertion_task_type() -> EvaluationTaskType {
52 EvaluationTaskType::AgentAssertion
53}
54
55#[pyclass]
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57pub struct AssertionResult {
58 #[pyo3(get)]
59 pub passed: bool,
60 pub actual: Value,
61
62 #[pyo3(get)]
63 pub message: String,
64
65 pub expected: Value,
66}
67
68impl AssertionResult {
69 pub fn new(passed: bool, actual: Value, message: String, expected: Value) -> Self {
70 Self {
71 passed,
72 actual,
73 message,
74 expected,
75 }
76 }
77 pub fn to_metric_value(&self) -> f64 {
78 if self.passed {
79 1.0
80 } else {
81 0.0
82 }
83 }
84}
85
86#[pymethods]
87impl AssertionResult {
88 pub fn __str__(&self) -> String {
89 PyHelperFuncs::__str__(self)
90 }
91
92 #[getter]
93 pub fn get_actual<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
94 let py_value = pythonize(py, &self.actual)?;
95 Ok(py_value)
96 }
97
98 #[getter]
99 pub fn get_expected<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
100 let py_value = pythonize(py, &self.expected)?;
101 Ok(py_value)
102 }
103}
104
105#[pyclass]
106#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
107pub struct AssertionResults {
108 #[pyo3(get)]
109 pub results: HashMap<String, AssertionResult>,
110}
111
112#[pymethods]
113impl AssertionResults {
114 pub fn __str__(&self) -> String {
115 PyHelperFuncs::__str__(self)
116 }
117
118 pub fn __getitem__(&self, key: &str) -> Result<AssertionResult, TypeError> {
119 if let Some(result) = self.results.get(key) {
120 Ok(result.clone())
121 } else {
122 Err(TypeError::KeyNotFound {
123 key: key.to_string(),
124 })
125 }
126 }
127}
128
129#[pyclass]
130#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
131pub struct AssertionTask {
132 #[pyo3(get, set)]
133 pub id: String,
134
135 #[pyo3(get, set)]
136 #[serde(default)]
137 pub context_path: Option<String>,
138
139 #[pyo3(get, set)]
140 #[serde(default)]
141 pub item_context_path: Option<String>,
142
143 #[pyo3(get, set)]
144 pub operator: ComparisonOperator,
145
146 pub expected_value: Value,
147
148 #[pyo3(get, set)]
149 #[serde(default)]
150 pub description: Option<String>,
151
152 #[pyo3(get, set)]
153 #[serde(default)]
154 pub depends_on: Vec<String>,
155
156 #[serde(default = "default_assertion_task_type")]
157 #[pyo3(get)]
158 pub task_type: EvaluationTaskType,
159
160 #[serde(skip_serializing_if = "Option::is_none")]
161 pub result: Option<AssertionResult>,
162
163 #[serde(default)]
164 pub condition: bool,
165}
166
167#[pymethods]
168impl AssertionTask {
169 #[new]
170 #[allow(clippy::too_many_arguments)]
224 #[pyo3(signature = (id, context_path, expected_value, operator, item_context_path=None, description=None, depends_on=None, condition=None))]
225 pub fn new(
226 id: String,
227 context_path: Option<String>,
228 expected_value: &Bound<'_, PyAny>,
229 operator: ComparisonOperator,
230 item_context_path: Option<String>,
231 description: Option<String>,
232 depends_on: Option<Vec<String>>,
233 condition: Option<bool>,
234 ) -> Result<Self, TypeError> {
235 let expected_value = depythonize(expected_value)?;
236 let condition = condition.unwrap_or(false);
237
238 Ok(Self {
239 id: id.to_lowercase(),
240 context_path,
241 item_context_path,
242 operator,
243 expected_value,
244 description,
245 task_type: EvaluationTaskType::Assertion,
246 depends_on: depends_on.unwrap_or_default(),
247 result: None,
248 condition,
249 })
250 }
251
252 #[getter]
253 pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
254 let py_value = pythonize(py, &self.expected_value)?;
255 Ok(py_value)
256 }
257
258 #[staticmethod]
259 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
260 deserialize_from_path(path)
261 }
262}
263
264impl AssertionTask {}
265
266impl TaskAccessor for AssertionTask {
267 fn context_path(&self) -> Option<&str> {
268 self.context_path.as_deref()
269 }
270
271 fn item_context_path(&self) -> Option<&str> {
272 self.item_context_path.as_deref()
273 }
274
275 fn id(&self) -> &str {
276 &self.id
277 }
278
279 fn operator(&self) -> &ComparisonOperator {
280 &self.operator
281 }
282
283 fn task_type(&self) -> &EvaluationTaskType {
284 &self.task_type
285 }
286
287 fn expected_value(&self) -> &Value {
288 &self.expected_value
289 }
290
291 fn depends_on(&self) -> &[String] {
292 &self.depends_on
293 }
294
295 fn add_result(&mut self, result: AssertionResult) {
296 self.result = Some(result);
297 }
298}
299
300pub trait ValueExt {
301 fn to_length(&self) -> Option<i64>;
303
304 fn as_numeric(&self) -> Option<f64>;
306
307 fn is_truthy(&self) -> bool;
309}
310
311impl ValueExt for Value {
312 fn to_length(&self) -> Option<i64> {
313 match self {
314 Value::Array(arr) => Some(arr.len() as i64),
315 Value::String(s) => Some(s.chars().count() as i64),
316 Value::Object(obj) => Some(obj.len() as i64),
317 _ => None,
318 }
319 }
320
321 fn as_numeric(&self) -> Option<f64> {
322 match self {
323 Value::Number(n) => n.as_f64(),
324 _ => None,
325 }
326 }
327
328 fn is_truthy(&self) -> bool {
329 match self {
330 Value::Null => false,
331 Value::Bool(b) => *b,
332 Value::Number(n) => n.as_f64() != Some(0.0),
333 Value::String(s) => !s.is_empty(),
334 Value::Array(arr) => !arr.is_empty(),
335 Value::Object(obj) => !obj.is_empty(),
336 }
337 }
338}
339
340#[pyclass]
342#[derive(Debug, Serialize, Clone, PartialEq)]
343pub struct LLMJudgeTask {
344 #[pyo3(get, set)]
345 pub id: String,
346
347 #[pyo3(get)]
348 pub prompt: Prompt,
349
350 #[pyo3(get)]
351 #[serde(default)]
352 pub context_path: Option<String>,
353
354 pub expected_value: Value,
355
356 #[pyo3(get)]
357 pub operator: ComparisonOperator,
358
359 #[pyo3(get)]
360 pub task_type: EvaluationTaskType,
361
362 #[pyo3(get, set)]
363 #[serde(default)]
364 pub depends_on: Vec<String>,
365
366 #[pyo3(get, set)]
367 #[serde(default)]
368 pub max_retries: Option<u32>,
369
370 #[serde(skip_serializing_if = "Option::is_none")]
371 pub result: Option<AssertionResult>,
372
373 #[serde(default)]
374 pub description: Option<String>,
375
376 #[pyo3(get, set)]
377 #[serde(default)]
378 pub condition: bool,
379}
380
381#[derive(Debug, Deserialize)]
382#[serde(untagged)]
383pub enum PromptConfig {
384 Path { path: String },
385 Inline(Box<Prompt>),
386}
387
388#[derive(Debug, Deserialize)]
389struct LLMJudgeTaskConfig {
390 pub id: String,
391 pub prompt: PromptConfig,
392 pub expected_value: Value,
393 pub operator: ComparisonOperator,
394 pub context_path: Option<String>,
395 pub description: Option<String>,
396 pub depends_on: Vec<String>,
397 pub max_retries: Option<u32>,
398 pub condition: bool,
399}
400
401impl LLMJudgeTaskConfig {
402 pub fn into_task(self) -> Result<LLMJudgeTask, TypeError> {
403 let prompt = match self.prompt {
404 PromptConfig::Path { path } => {
405 Prompt::from_path(PathBuf::from(path)).inspect_err(|e| {
406 error!("Failed to deserialize Prompt from path: {}", e);
407 })?
408 }
409 PromptConfig::Inline(prompt) => *prompt,
410 };
411
412 Ok(LLMJudgeTask {
413 id: self.id.to_lowercase(),
414 prompt,
415 expected_value: self.expected_value,
416 operator: self.operator,
417 context_path: self.context_path,
418 description: self.description,
419 depends_on: self.depends_on,
420 max_retries: self.max_retries.or(Some(3)),
421 task_type: EvaluationTaskType::LLMJudge,
422 result: None,
423 condition: self.condition,
424 })
425 }
426}
427
428#[derive(Debug, Deserialize)]
429struct LLMJudgeTaskInternal {
430 pub id: String,
431 pub prompt: Prompt,
432 pub context_path: Option<String>,
433 pub expected_value: Value,
434 pub operator: ComparisonOperator,
435 pub task_type: EvaluationTaskType,
436 pub depends_on: Vec<String>,
437 pub max_retries: Option<u32>,
438 pub result: Option<AssertionResult>,
439 pub description: Option<String>,
440 pub condition: bool,
441}
442impl LLMJudgeTaskInternal {
443 pub fn into_task(self) -> LLMJudgeTask {
444 LLMJudgeTask {
445 id: self.id.to_lowercase(),
446 prompt: self.prompt,
447 context_path: self.context_path,
448 expected_value: self.expected_value,
449 operator: self.operator,
450 task_type: self.task_type,
451 depends_on: self.depends_on,
452 max_retries: self.max_retries.or(Some(3)),
453 result: self.result,
454 description: self.description,
455 condition: self.condition,
456 }
457 }
458}
459
460#[derive(Debug, Deserialize)]
461#[serde(untagged)]
462enum LLMJudgeFormat {
463 Full(Box<LLMJudgeTaskInternal>),
464 Generic(LLMJudgeTaskConfig),
465}
466
467impl<'de> Deserialize<'de> for LLMJudgeTask {
468 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
469 where
470 D: serde::Deserializer<'de>,
471 {
472 let format = LLMJudgeFormat::deserialize(deserializer)?;
473
474 match format {
475 LLMJudgeFormat::Generic(config) => config.into_task().map_err(serde::de::Error::custom),
476 LLMJudgeFormat::Full(internal) => Ok(internal.into_task()),
477 }
478 }
479}
480
481#[pymethods]
482impl LLMJudgeTask {
483 #[new]
503 #[pyo3(signature = (id, prompt, expected_value, context_path,operator, description=None, depends_on=None, max_retries=None, condition=None))]
504 #[allow(clippy::too_many_arguments)]
505 pub fn new(
506 id: &str,
507 prompt: Prompt,
508 expected_value: &Bound<'_, PyAny>,
509 context_path: Option<String>,
510 operator: ComparisonOperator,
511 description: Option<String>,
512 depends_on: Option<Vec<String>>,
513 max_retries: Option<u32>,
514 condition: Option<bool>,
515 ) -> Result<Self, TypeError> {
516 let expected_value = depythonize(expected_value)?;
517
518 Ok(Self {
519 id: id.to_lowercase(),
520 prompt,
521 expected_value,
522 operator,
523 task_type: EvaluationTaskType::LLMJudge,
524 depends_on: depends_on.unwrap_or_default(),
525 max_retries: max_retries.or(Some(3)),
526 context_path,
527 result: None,
528 description,
529 condition: condition.unwrap_or(false),
530 })
531 }
532
533 pub fn __str__(&self) -> String {
534 PyHelperFuncs::__str__(self)
536 }
537
538 #[getter]
539 pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
540 let py_value = pythonize(py, &self.expected_value)?;
541 Ok(py_value)
542 }
543
544 #[staticmethod]
545 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
546 deserialize_from_path(path)
547 }
548}
549
550impl LLMJudgeTask {
551 #[allow(clippy::too_many_arguments)]
563 pub fn new_rs(
564 id: &str,
565 prompt: Prompt,
566 expected_value: Value,
567 context_path: Option<String>,
568 operator: ComparisonOperator,
569 depends_on: Option<Vec<String>>,
570 max_retries: Option<u32>,
571 description: Option<String>,
572 condition: Option<bool>,
573 ) -> Self {
574 Self {
575 id: id.to_lowercase(),
576 prompt,
577 expected_value,
578 operator,
579 task_type: EvaluationTaskType::LLMJudge,
580 depends_on: depends_on.unwrap_or_default(),
581 max_retries: max_retries.or(Some(3)),
582 context_path,
583 result: None,
584 description,
585 condition: condition.unwrap_or(false),
586 }
587 }
588}
589
590impl TaskAccessor for LLMJudgeTask {
591 fn context_path(&self) -> Option<&str> {
592 self.context_path.as_deref()
593 }
594
595 fn item_context_path(&self) -> Option<&str> {
596 None
597 }
598
599 fn id(&self) -> &str {
600 &self.id
601 }
602 fn task_type(&self) -> &EvaluationTaskType {
603 &self.task_type
604 }
605
606 fn operator(&self) -> &ComparisonOperator {
607 &self.operator
608 }
609
610 fn expected_value(&self) -> &Value {
611 &self.expected_value
612 }
613 fn depends_on(&self) -> &[String] {
614 &self.depends_on
615 }
616 fn add_result(&mut self, result: AssertionResult) {
617 self.result = Some(result);
618 }
619}
620
/// Span status value usable in span filters (see `SpanFilter::WithStatus`).
///
/// `eq, eq_int` lets Python compare instances with each other and with their integer
/// discriminants, so the variant order below is part of the Python-visible API.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SpanStatus {
    Ok,
    Error,
    Unset,
}
628
629#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
630pub struct PyValueWrapper(pub Value);
631
632impl<'py> IntoPyObject<'py> for PyValueWrapper {
633 type Target = PyAny;
634 type Output = Bound<'py, Self::Target>;
635 type Error = TypeError;
636
637 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
638 pythonize(py, &self.0).map_err(TypeError::from)
639 }
640}
641
642impl<'a, 'py> FromPyObject<'a, 'py> for PyValueWrapper {
643 type Error = TypeError;
644
645 fn extract(ob: pyo3::Borrowed<'a, 'py, pyo3::PyAny>) -> Result<Self, Self::Error> {
646 let value: Value = depythonize(&ob)?;
647 Ok(PyValueWrapper(value))
648 }
649}
650
/// Predicate for selecting spans in trace assertions.
///
/// Matching is implemented by the evaluator outside this file. The per-variant notes
/// describe the fields; intended semantics should be confirmed there.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SpanFilter {
    /// Select spans by exact `name` (presumably an equality match — confirm).
    ByName { name: String },

    /// Select spans whose name matches `pattern` (pattern syntax defined by the evaluator).
    ByNamePattern { pattern: String },

    /// Select spans carrying attribute `key`.
    WithAttribute { key: String },

    /// Select spans whose attribute `key` has `value` (arbitrary JSON via `PyValueWrapper`).
    WithAttributeValue { key: String, value: PyValueWrapper },

    /// Select spans with the given `status`.
    WithStatus { status: SpanStatus },

    /// Select spans by duration bounds in milliseconds; either bound may be omitted.
    WithDuration {
        min_ms: Option<f64>,
        max_ms: Option<f64>,
    },

    /// Ordered list of span names (presumably must occur in this order — confirm).
    Sequence { names: Vec<String> },

    /// Conjunction of filters; built by `SpanFilter::and_`, which flattens nested `And`s.
    And { filters: Vec<SpanFilter> },

    /// Disjunction of filters; built by `SpanFilter::or_`, which flattens nested `Or`s.
    Or { filters: Vec<SpanFilter> },
}
685
686#[pymethods]
687impl SpanFilter {
688 #[staticmethod]
689 pub fn by_name(name: String) -> Self {
690 SpanFilter::ByName { name }
691 }
692
693 #[staticmethod]
694 pub fn by_name_pattern(pattern: String) -> Self {
695 SpanFilter::ByNamePattern { pattern }
696 }
697
698 #[staticmethod]
699 pub fn with_attribute(key: String) -> Self {
700 SpanFilter::WithAttribute { key }
701 }
702
703 #[staticmethod]
704 pub fn with_attribute_value(key: String, value: &Bound<'_, PyAny>) -> Result<Self, TypeError> {
705 let value = PyValueWrapper(depythonize(value)?);
706 Ok(SpanFilter::WithAttributeValue { key, value })
707 }
708
709 #[staticmethod]
710 pub fn with_status(status: SpanStatus) -> Self {
711 SpanFilter::WithStatus { status }
712 }
713
714 #[staticmethod]
715 #[pyo3(signature = (min_ms=None, max_ms=None))]
716 pub fn with_duration(min_ms: Option<f64>, max_ms: Option<f64>) -> Self {
717 SpanFilter::WithDuration { min_ms, max_ms }
718 }
719
720 #[staticmethod]
721 pub fn sequence(names: Vec<String>) -> Self {
722 SpanFilter::Sequence { names }
723 }
724
725 pub fn and_(&self, other: SpanFilter) -> Self {
726 match self {
727 SpanFilter::And { filters } => {
728 let mut new_filters = filters.clone();
729 new_filters.push(other);
730 SpanFilter::And {
731 filters: new_filters,
732 }
733 }
734 _ => SpanFilter::And {
735 filters: vec![self.clone(), other],
736 },
737 }
738 }
739
740 pub fn or_(&self, other: SpanFilter) -> Self {
741 match self {
742 SpanFilter::Or { filters } => {
743 let mut new_filters = filters.clone();
744 new_filters.push(other);
745 SpanFilter::Or {
746 filters: new_filters,
747 }
748 }
749 _ => SpanFilter::Or {
750 filters: vec![self.clone(), other],
751 },
752 }
753 }
754}
755
/// How attribute values from matched spans are reduced in
/// `TraceAssertion::SpanAggregation` (the reduction is implemented by the evaluator).
///
/// `eq_int` exposes integer discriminants to Python, so variant order is part of the API.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AggregationType {
    Count,
    Sum,
    Average,
    Min,
    Max,
    First,
    Last,
}

/// How results over multiple matched items are combined in
/// `TraceAssertion::AttributeFilter`.
///
/// NOTE(review): presumably `Any` passes if at least one item passes and `All` only if
/// every item does — confirm in the evaluator.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MultiResponseMode {
    Any,
    All,
}
777
778#[pyclass]
780#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
781pub enum AttributeFilterTask {
782 Assertion(AssertionTask),
784 AgentAssertion(AgentAssertionTask),
786}
787
788#[pymethods]
789impl AttributeFilterTask {
790 #[staticmethod]
791 pub fn assertion(task: AssertionTask) -> Self {
792 AttributeFilterTask::Assertion(task)
793 }
794
795 #[staticmethod]
796 pub fn agent_assertion(task: AgentAssertionTask) -> Self {
797 AttributeFilterTask::AgentAssertion(task)
798 }
799}
800
/// Quantity or check extracted from a trace, compared by a [`TraceAssertionTask`]
/// against its expected value.
///
/// Evaluation lives outside this file. The notes below describe what each variant
/// carries; confirm exact semantics in the evaluator. Field-less variants use `{}` form,
/// presumably because PyO3 complex enums need struct/tuple variants here — confirm
/// before changing.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(clippy::large_enum_variant)]
pub enum TraceAssertion {
    /// Ordered list of span names.
    SpanSequence { span_names: Vec<String> },

    /// Set of span names (order presumably irrelevant).
    SpanSet { span_names: Vec<String> },

    /// Number of spans matching `filter`.
    SpanCount { filter: SpanFilter },

    /// Whether any span matches `filter`.
    SpanExists { filter: SpanFilter },

    /// Attribute `attribute_key` of span(s) matching `filter`.
    SpanAttribute {
        filter: SpanFilter,
        attribute_key: String,
    },

    /// Duration of span(s) matching `filter`.
    SpanDuration { filter: SpanFilter },

    /// `aggregation` applied to attribute `attribute_key` over spans matching `filter`.
    SpanAggregation {
        filter: SpanFilter,
        attribute_key: String,
        aggregation: AggregationType,
    },

    /// Duration of the whole trace.
    TraceDuration {},

    /// Total number of spans in the trace.
    TraceSpanCount {},

    /// Number of error spans in the trace.
    TraceErrorCount {},

    /// Number of distinct services in the trace.
    TraceServiceCount {},

    /// Maximum span nesting depth.
    TraceMaxDepth {},

    /// Trace-level attribute `attribute_key`.
    TraceAttribute { attribute_key: String },

    /// Run `task` against values found under attribute `key`, combining per-item results
    /// according to `mode`.
    AttributeFilter {
        key: String,
        task: AttributeFilterTask,
        mode: MultiResponseMode,
    },
}
859
860impl Display for TraceAssertion {
862 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
863 let s = serde_json::to_string(self).unwrap_or_default();
865 write!(f, "{}", s)
866 }
867}
868
869#[pymethods]
870impl TraceAssertion {
871 #[staticmethod]
872 pub fn span_sequence(span_names: Vec<String>) -> Self {
873 TraceAssertion::SpanSequence { span_names }
874 }
875
876 #[staticmethod]
877 pub fn span_set(span_names: Vec<String>) -> Self {
878 TraceAssertion::SpanSet { span_names }
879 }
880
881 #[staticmethod]
882 pub fn span_count(filter: SpanFilter) -> Self {
883 TraceAssertion::SpanCount { filter }
884 }
885
886 #[staticmethod]
887 pub fn span_exists(filter: SpanFilter) -> Self {
888 TraceAssertion::SpanExists { filter }
889 }
890
891 #[staticmethod]
892 pub fn span_attribute(filter: SpanFilter, attribute_key: String) -> Self {
893 TraceAssertion::SpanAttribute {
894 filter,
895 attribute_key,
896 }
897 }
898
899 #[staticmethod]
900 pub fn span_duration(filter: SpanFilter) -> Self {
901 TraceAssertion::SpanDuration { filter }
902 }
903
904 #[staticmethod]
905 pub fn span_aggregation(
906 filter: SpanFilter,
907 attribute_key: String,
908 aggregation: AggregationType,
909 ) -> Self {
910 TraceAssertion::SpanAggregation {
911 filter,
912 attribute_key,
913 aggregation,
914 }
915 }
916
917 #[staticmethod]
918 pub fn trace_duration() -> Self {
919 TraceAssertion::TraceDuration {}
920 }
921
922 #[staticmethod]
923 pub fn trace_span_count() -> Self {
924 TraceAssertion::TraceSpanCount {}
925 }
926
927 #[staticmethod]
928 pub fn trace_error_count() -> Self {
929 TraceAssertion::TraceErrorCount {}
930 }
931
932 #[staticmethod]
933 pub fn trace_service_count() -> Self {
934 TraceAssertion::TraceServiceCount {}
935 }
936
937 #[staticmethod]
938 pub fn trace_max_depth() -> Self {
939 TraceAssertion::TraceMaxDepth {}
940 }
941
942 #[staticmethod]
943 pub fn trace_attribute(attribute_key: String) -> Self {
944 TraceAssertion::TraceAttribute { attribute_key }
945 }
946
947 #[staticmethod]
948 pub fn attribute_filter(
949 key: String,
950 task: AttributeFilterTask,
951 mode: MultiResponseMode,
952 ) -> Self {
953 TraceAssertion::AttributeFilter { key, task, mode }
954 }
955
956 pub fn model_dump_json(&self) -> String {
957 serde_json::to_string(self).unwrap_or_default()
958 }
959}
960
/// Task comparing the quantity described by `assertion` (computed from a trace by the
/// evaluator) against `expected_value` using `operator`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TraceAssertionTask {
    /// Task identifier (lowercased by the Python constructor).
    #[pyo3(get, set)]
    pub id: String,

    /// What to measure or check on the trace.
    #[pyo3(get, set)]
    pub assertion: TraceAssertion,

    /// Comparison applied to the measured value.
    #[pyo3(get, set)]
    pub operator: ComparisonOperator,

    /// Expected JSON value; exposed to Python through the `expected_value` getter.
    pub expected_value: Value,

    /// Free-form description.
    #[pyo3(get, set)]
    #[serde(default)]
    pub description: Option<String>,

    /// Ids of tasks this task depends on.
    #[pyo3(get, set)]
    #[serde(default)]
    pub depends_on: Vec<String>,

    /// `TraceAssertion` when built by the constructor; serde falls back to it when absent.
    #[serde(default = "default_trace_assertion_task_type")]
    #[pyo3(get)]
    pub task_type: EvaluationTaskType,

    /// Outcome stored by `TaskAccessor::add_result`; omitted from output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<AssertionResult>,

    /// Condition flag (defaults to `false`).
    #[pyo3(get, set)]
    #[serde(default)]
    pub condition: bool,
}
994
995#[pymethods]
996impl TraceAssertionTask {
997 #[new]
1047 #[pyo3(signature = (id, assertion, expected_value, operator, description=None, depends_on=None, condition=None))]
1049 pub fn new(
1050 id: String,
1051 assertion: TraceAssertion,
1052 expected_value: &Bound<'_, PyAny>,
1053 operator: ComparisonOperator,
1054 description: Option<String>,
1055 depends_on: Option<Vec<String>>,
1056 condition: Option<bool>,
1057 ) -> Result<Self, TypeError> {
1058 let expected_value = depythonize(expected_value)?;
1059
1060 Ok(Self {
1061 id: id.to_lowercase(),
1062 assertion,
1063 operator,
1064 expected_value,
1065 description,
1066 task_type: EvaluationTaskType::TraceAssertion,
1067 depends_on: depends_on.unwrap_or_default(),
1068 result: None,
1069 condition: condition.unwrap_or(false),
1070 })
1071 }
1072
1073 pub fn __str__(&self) -> String {
1074 PyHelperFuncs::__str__(self)
1076 }
1077
1078 #[getter]
1079 pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1080 let py_value = pythonize(py, &self.expected_value)?;
1081 Ok(py_value)
1082 }
1083
1084 #[staticmethod]
1085 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1086 deserialize_from_path(path)
1087 }
1088}
1089
1090impl TaskAccessor for TraceAssertionTask {
1091 fn context_path(&self) -> Option<&str> {
1092 None
1093 }
1094
1095 fn item_context_path(&self) -> Option<&str> {
1096 None
1097 }
1098
1099 fn id(&self) -> &str {
1100 &self.id
1101 }
1102
1103 fn operator(&self) -> &ComparisonOperator {
1104 &self.operator
1105 }
1106
1107 fn task_type(&self) -> &EvaluationTaskType {
1108 &self.task_type
1109 }
1110
1111 fn expected_value(&self) -> &Value {
1112 &self.expected_value
1113 }
1114
1115 fn depends_on(&self) -> &[String] {
1116 &self.depends_on
1117 }
1118
1119 fn add_result(&mut self, result: AssertionResult) {
1120 self.result = Some(result);
1121 }
1122}
1123
/// A single tool invocation: tool `name`, JSON `arguments`, and the optional `result`
/// and `call_id` (missing `Option` fields deserialize as `None`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    pub name: String,
    pub arguments: Value,
    pub result: Option<Value>,
    pub call_id: Option<String>,
}
1131
1132#[pyclass]
1133#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
1134pub struct TokenUsage {
1135 #[pyo3(get, set)]
1136 pub input_tokens: Option<i64>,
1137
1138 #[pyo3(get, set)]
1139 pub output_tokens: Option<i64>,
1140
1141 #[pyo3(get, set)]
1142 pub total_tokens: Option<i64>,
1143}
1144
1145#[pymethods]
1146impl TokenUsage {
1147 #[new]
1148 #[pyo3(signature = (input_tokens=None, output_tokens=None, total_tokens=None))]
1149 pub fn new(
1150 input_tokens: Option<i64>,
1151 output_tokens: Option<i64>,
1152 total_tokens: Option<i64>,
1153 ) -> Self {
1154 Self {
1155 input_tokens,
1156 output_tokens,
1157 total_tokens,
1158 }
1159 }
1160
1161 pub fn __str__(&self) -> String {
1162 serde_json::to_string_pretty(self).unwrap_or_default()
1163 }
1164}
1165
/// Check or quantity extracted from an agent/LLM response, compared by an
/// [`AgentAssertionTask`] against its expected value.
///
/// Evaluation lives outside this file. The notes below describe what each variant
/// carries; confirm exact semantics in the evaluator.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AgentAssertion {
    /// Tool `name` was called.
    ToolCalled { name: String },

    /// Tool `name` was not called.
    ToolNotCalled { name: String },

    /// Tool `name` was called with `arguments` (arbitrary JSON).
    ToolCalledWithArgs {
        name: String,
        arguments: PyValueWrapper,
    },

    /// Tools called in the order given by `names`.
    ToolCallSequence { names: Vec<String> },

    /// Number of tool calls, optionally restricted to tool `name`.
    ToolCallCount { name: Option<String> },

    /// Argument `argument_key` passed to tool `name`.
    ToolArgument { name: String, argument_key: String },

    /// Result returned by tool `name`.
    ToolResult { name: String },

    /// Response content.
    ResponseContent {},

    /// Model reported by the response.
    ResponseModel {},

    /// Finish reason of the response.
    ResponseFinishReason {},

    /// Input token count of the response.
    ResponseInputTokens {},

    /// Output token count of the response.
    ResponseOutputTokens {},

    /// Total token count of the response.
    ResponseTotalTokens {},

    /// Arbitrary response field addressed by `path` (path syntax defined by the evaluator).
    ResponseField { path: String },
}
1214
1215impl Display for AgentAssertion {
1216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1217 let s = serde_json::to_string(self).unwrap_or_default();
1218 write!(f, "{}", s)
1219 }
1220}
1221
1222#[pymethods]
1223impl AgentAssertion {
1224 #[staticmethod]
1225 pub fn tool_called(name: &str) -> Self {
1226 AgentAssertion::ToolCalled {
1227 name: name.to_string(),
1228 }
1229 }
1230
1231 #[staticmethod]
1232 pub fn tool_not_called(name: &str) -> Self {
1233 AgentAssertion::ToolNotCalled {
1234 name: name.to_string(),
1235 }
1236 }
1237
1238 #[staticmethod]
1239 pub fn tool_called_with_args(
1240 name: &str,
1241 arguments: &Bound<'_, PyAny>,
1242 ) -> Result<Self, TypeError> {
1243 let arguments: Value = depythonize(arguments)?;
1244 Ok(AgentAssertion::ToolCalledWithArgs {
1245 name: name.to_string(),
1246 arguments: PyValueWrapper(arguments),
1247 })
1248 }
1249
1250 #[staticmethod]
1251 pub fn tool_call_sequence(names: Vec<String>) -> Self {
1252 AgentAssertion::ToolCallSequence { names }
1253 }
1254
1255 #[staticmethod]
1256 #[pyo3(signature = (name=None))]
1257 pub fn tool_call_count(name: Option<String>) -> Self {
1258 AgentAssertion::ToolCallCount { name }
1259 }
1260
1261 #[staticmethod]
1262 pub fn tool_argument(name: &str, argument_key: &str) -> Self {
1263 AgentAssertion::ToolArgument {
1264 name: name.to_string(),
1265 argument_key: argument_key.to_string(),
1266 }
1267 }
1268
1269 #[staticmethod]
1270 pub fn tool_result(name: &str) -> Self {
1271 AgentAssertion::ToolResult {
1272 name: name.to_string(),
1273 }
1274 }
1275
1276 #[staticmethod]
1277 pub fn response_content() -> Self {
1278 AgentAssertion::ResponseContent {}
1279 }
1280
1281 #[staticmethod]
1282 pub fn response_model() -> Self {
1283 AgentAssertion::ResponseModel {}
1284 }
1285
1286 #[staticmethod]
1287 pub fn response_finish_reason() -> Self {
1288 AgentAssertion::ResponseFinishReason {}
1289 }
1290
1291 #[staticmethod]
1292 pub fn response_input_tokens() -> Self {
1293 AgentAssertion::ResponseInputTokens {}
1294 }
1295
1296 #[staticmethod]
1297 pub fn response_output_tokens() -> Self {
1298 AgentAssertion::ResponseOutputTokens {}
1299 }
1300
1301 #[staticmethod]
1302 pub fn response_total_tokens() -> Self {
1303 AgentAssertion::ResponseTotalTokens {}
1304 }
1305
1306 #[staticmethod]
1307 pub fn response_field(path: &str) -> Self {
1308 AgentAssertion::ResponseField {
1309 path: path.to_string(),
1310 }
1311 }
1312
1313 pub fn __str__(&self) -> String {
1314 serde_json::to_string_pretty(self).unwrap_or_default()
1315 }
1316}
1317
/// Task comparing the quantity described by `assertion` (extracted from an agent
/// response by the evaluator) against `expected_value` using `operator`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AgentAssertionTask {
    /// Task identifier (lowercased by the Python constructor).
    #[pyo3(get, set)]
    pub id: String,

    /// What to check on the agent response.
    #[pyo3(get, set)]
    pub assertion: AgentAssertion,

    /// Comparison applied to the extracted value.
    #[pyo3(get, set)]
    pub operator: ComparisonOperator,

    /// Expected JSON value; exposed to Python through the `expected_value` getter.
    pub expected_value: Value,

    /// Free-form description.
    #[pyo3(get, set)]
    #[serde(default)]
    pub description: Option<String>,

    /// Ids of tasks this task depends on.
    #[pyo3(get, set)]
    #[serde(default)]
    pub depends_on: Vec<String>,

    /// `AgentAssertion` when built by the constructor; serde falls back to it when absent.
    #[serde(default = "default_agent_assertion_task_type")]
    #[pyo3(get)]
    pub task_type: EvaluationTaskType,

    /// Outcome stored by `TaskAccessor::add_result`; omitted from output while `None`.
    /// Python reads it as a plain object through the `result` getter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<AssertionResult>,

    /// Condition flag (defaults to `false`).
    #[pyo3(get, set)]
    #[serde(default)]
    pub condition: bool,

    /// Optional provider (presumably selects how the response is interpreted — confirm
    /// in the evaluator).
    #[pyo3(get, set)]
    #[serde(default)]
    pub provider: Option<Provider>,
}
1355
1356#[pymethods]
1357impl AgentAssertionTask {
1358 #[new]
1359 #[pyo3(signature = (id, assertion, expected_value, operator, description=None, depends_on=None, condition=None, provider=None))]
1360 #[allow(clippy::too_many_arguments)]
1361 pub fn new(
1362 id: String,
1363 assertion: AgentAssertion,
1364 expected_value: &Bound<'_, PyAny>,
1365 operator: ComparisonOperator,
1366 description: Option<String>,
1367 depends_on: Option<Vec<String>>,
1368 condition: Option<bool>,
1369 provider: Option<Provider>,
1370 ) -> Result<Self, TypeError> {
1371 let expected_value = depythonize(expected_value)?;
1372
1373 Ok(Self {
1374 id: id.to_lowercase(),
1375 assertion,
1376 operator,
1377 expected_value,
1378 description,
1379 task_type: EvaluationTaskType::AgentAssertion,
1380 depends_on: depends_on.unwrap_or_default(),
1381 result: None,
1382 condition: condition.unwrap_or(false),
1383 provider,
1384 })
1385 }
1386
1387 pub fn __str__(&self) -> String {
1388 PyHelperFuncs::__str__(self)
1389 }
1390
1391 pub fn model_dump_json(&self) -> String {
1392 serde_json::to_string(self).unwrap_or_default()
1393 }
1394
1395 #[getter]
1396 pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1397 let py_value = pythonize(py, &self.expected_value)?;
1398 Ok(py_value)
1399 }
1400
1401 #[getter]
1402 pub fn get_result<'py>(&self, py: Python<'py>) -> Result<Option<Bound<'py, PyAny>>, TypeError> {
1403 match &self.result {
1404 Some(result) => {
1405 let py_value = pythonize(py, result)?;
1406 Ok(Some(py_value))
1407 }
1408 None => Ok(None),
1409 }
1410 }
1411
1412 #[staticmethod]
1413 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1414 deserialize_from_path(path)
1415 }
1416}
1417
1418impl TaskAccessor for AgentAssertionTask {
1419 fn context_path(&self) -> Option<&str> {
1420 None
1421 }
1422
1423 fn item_context_path(&self) -> Option<&str> {
1424 None
1425 }
1426
1427 fn id(&self) -> &str {
1428 &self.id
1429 }
1430
1431 fn operator(&self) -> &ComparisonOperator {
1432 &self.operator
1433 }
1434
1435 fn task_type(&self) -> &EvaluationTaskType {
1436 &self.task_type
1437 }
1438
1439 fn expected_value(&self) -> &Value {
1440 &self.expected_value
1441 }
1442
1443 fn depends_on(&self) -> &[String] {
1444 &self.depends_on
1445 }
1446
1447 fn add_result(&mut self, result: AssertionResult) {
1448 self.result = Some(result);
1449 }
1450}
1451
/// Any one of the supported task kinds, for storing heterogeneous tasks together.
///
/// Each variant boxes its task (presumably to keep the enum small — cf.
/// `clippy::large_enum_variant`).
#[derive(Debug, Clone)]
pub enum EvaluationTask {
    Assertion(Box<AssertionTask>),
    LLMJudge(Box<LLMJudgeTask>),
    TraceAssertion(Box<TraceAssertionTask>),
    AgentAssertion(Box<AgentAssertionTask>),
}
1459
1460impl TaskAccessor for EvaluationTask {
1461 fn context_path(&self) -> Option<&str> {
1462 match self {
1463 EvaluationTask::Assertion(t) => t.context_path(),
1464 EvaluationTask::LLMJudge(t) => t.context_path(),
1465 EvaluationTask::TraceAssertion(t) => t.context_path(),
1466 EvaluationTask::AgentAssertion(t) => t.context_path(),
1467 }
1468 }
1469
1470 fn item_context_path(&self) -> Option<&str> {
1471 match self {
1472 EvaluationTask::Assertion(t) => t.item_context_path(),
1473 EvaluationTask::LLMJudge(t) => t.item_context_path(),
1474 EvaluationTask::TraceAssertion(t) => t.item_context_path(),
1475 EvaluationTask::AgentAssertion(t) => t.item_context_path(),
1476 }
1477 }
1478
1479 fn id(&self) -> &str {
1480 match self {
1481 EvaluationTask::Assertion(t) => t.id(),
1482 EvaluationTask::LLMJudge(t) => t.id(),
1483 EvaluationTask::TraceAssertion(t) => t.id(),
1484 EvaluationTask::AgentAssertion(t) => t.id(),
1485 }
1486 }
1487
1488 fn task_type(&self) -> &EvaluationTaskType {
1489 match self {
1490 EvaluationTask::Assertion(t) => t.task_type(),
1491 EvaluationTask::LLMJudge(t) => t.task_type(),
1492 EvaluationTask::TraceAssertion(t) => t.task_type(),
1493 EvaluationTask::AgentAssertion(t) => t.task_type(),
1494 }
1495 }
1496
1497 fn operator(&self) -> &ComparisonOperator {
1498 match self {
1499 EvaluationTask::Assertion(t) => t.operator(),
1500 EvaluationTask::LLMJudge(t) => t.operator(),
1501 EvaluationTask::TraceAssertion(t) => t.operator(),
1502 EvaluationTask::AgentAssertion(t) => t.operator(),
1503 }
1504 }
1505
1506 fn expected_value(&self) -> &Value {
1507 match self {
1508 EvaluationTask::Assertion(t) => t.expected_value(),
1509 EvaluationTask::LLMJudge(t) => t.expected_value(),
1510 EvaluationTask::TraceAssertion(t) => t.expected_value(),
1511 EvaluationTask::AgentAssertion(t) => t.expected_value(),
1512 }
1513 }
1514
1515 fn depends_on(&self) -> &[String] {
1516 match self {
1517 EvaluationTask::Assertion(t) => t.depends_on(),
1518 EvaluationTask::LLMJudge(t) => t.depends_on(),
1519 EvaluationTask::TraceAssertion(t) => t.depends_on(),
1520 EvaluationTask::AgentAssertion(t) => t.depends_on(),
1521 }
1522 }
1523
1524 fn add_result(&mut self, result: AssertionResult) {
1525 match self {
1526 EvaluationTask::Assertion(t) => t.add_result(result),
1527 EvaluationTask::LLMJudge(t) => t.add_result(result),
1528 EvaluationTask::TraceAssertion(t) => t.add_result(result),
1529 EvaluationTask::AgentAssertion(t) => t.add_result(result),
1530 }
1531 }
1532}
1533
1534pub struct EvaluationTasks(Vec<EvaluationTask>);
1535
1536impl EvaluationTasks {
1537 pub fn new() -> Self {
1539 Self(Vec::new())
1540 }
1541
1542 pub fn add_task(mut self, task: impl Into<EvaluationTask>) -> Self {
1544 self.0.push(task.into());
1545 self
1546 }
1547
1548 pub fn build(self) -> Vec<EvaluationTask> {
1550 self.0
1551 }
1552}
1553
1554impl From<AssertionTask> for EvaluationTask {
1555 fn from(task: AssertionTask) -> Self {
1556 EvaluationTask::Assertion(Box::new(task))
1557 }
1558}
1559
1560impl From<LLMJudgeTask> for EvaluationTask {
1561 fn from(task: LLMJudgeTask) -> Self {
1562 EvaluationTask::LLMJudge(Box::new(task))
1563 }
1564}
1565
1566impl From<TraceAssertionTask> for EvaluationTask {
1567 fn from(task: TraceAssertionTask) -> Self {
1568 EvaluationTask::TraceAssertion(Box::new(task))
1569 }
1570}
1571
1572impl From<AgentAssertionTask> for EvaluationTask {
1573 fn from(task: AgentAssertionTask) -> Self {
1574 EvaluationTask::AgentAssertion(Box::new(task))
1575 }
1576}
1577
1578impl Default for EvaluationTasks {
1579 fn default() -> Self {
1580 Self::new()
1581 }
1582}
1583
/// Operator a task uses to compare an actual value against its expected value.
///
/// `#[pyclass(eq, eq_int)]` exposes it to Python with equality comparison, including
/// against the members' integer values. The string form used by `as_str`, `Display`
/// and `FromStr` is the exact variant name.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ComparisonOperator {
    // General comparisons.
    Equals,
    NotEqual,
    GreaterThan,
    GreaterThanOrEqual,
    LessThan,
    LessThanOrEqual,
    Contains,
    NotContains,
    StartsWith,
    EndsWith,
    Matches,
    // Length comparisons: `AssertionValue::to_actual` replaces a list or string
    // actual value with its length for these operators.
    HasLengthGreaterThan,
    HasLengthLessThan,
    HasLengthEqual,
    HasLengthGreaterThanOrEqual,
    HasLengthLessThanOrEqual,

    // Type-check operators.
    IsNumeric,
    IsString,
    IsBoolean,
    IsNull,
    IsArray,
    IsObject,

    // Format / pattern operators.
    IsEmail,
    IsUrl,
    IsUuid,
    IsIso8601,
    IsJson,
    MatchesRegex,

    // Numeric range / sign operators.
    InRange,
    NotInRange,
    IsPositive,
    IsNegative,
    IsZero,

    // Collection operators.
    SequenceMatches,
    ContainsAll,
    ContainsAny,
    ContainsNone,
    IsEmpty,
    IsNotEmpty,
    HasUniqueItems,

    // String-content operators.
    IsAlphabetic,
    IsAlphanumeric,
    IsLowerCase,
    IsUpperCase,
    ContainsWord,

    // Approximate numeric equality (tolerance handling lives in the evaluator).
    ApproximatelyEquals,
}
1647
1648impl Display for ComparisonOperator {
1649 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1650 write!(f, "{}", self.as_str())
1651 }
1652}
1653
1654impl FromStr for ComparisonOperator {
1655 type Err = TypeError;
1656
1657 fn from_str(s: &str) -> Result<Self, Self::Err> {
1658 match s {
1659 "Equals" => Ok(ComparisonOperator::Equals),
1660 "NotEqual" => Ok(ComparisonOperator::NotEqual),
1661 "GreaterThan" => Ok(ComparisonOperator::GreaterThan),
1662 "GreaterThanOrEqual" => Ok(ComparisonOperator::GreaterThanOrEqual),
1663 "LessThan" => Ok(ComparisonOperator::LessThan),
1664 "LessThanOrEqual" => Ok(ComparisonOperator::LessThanOrEqual),
1665 "Contains" => Ok(ComparisonOperator::Contains),
1666 "NotContains" => Ok(ComparisonOperator::NotContains),
1667 "StartsWith" => Ok(ComparisonOperator::StartsWith),
1668 "EndsWith" => Ok(ComparisonOperator::EndsWith),
1669 "Matches" => Ok(ComparisonOperator::Matches),
1670 "HasLengthEqual" => Ok(ComparisonOperator::HasLengthEqual),
1671 "HasLengthGreaterThan" => Ok(ComparisonOperator::HasLengthGreaterThan),
1672 "HasLengthLessThan" => Ok(ComparisonOperator::HasLengthLessThan),
1673 "HasLengthGreaterThanOrEqual" => Ok(ComparisonOperator::HasLengthGreaterThanOrEqual),
1674 "HasLengthLessThanOrEqual" => Ok(ComparisonOperator::HasLengthLessThanOrEqual),
1675
1676 "IsNumeric" => Ok(ComparisonOperator::IsNumeric),
1678 "IsString" => Ok(ComparisonOperator::IsString),
1679 "IsBoolean" => Ok(ComparisonOperator::IsBoolean),
1680 "IsNull" => Ok(ComparisonOperator::IsNull),
1681 "IsArray" => Ok(ComparisonOperator::IsArray),
1682 "IsObject" => Ok(ComparisonOperator::IsObject),
1683
1684 "IsEmail" => Ok(ComparisonOperator::IsEmail),
1686 "IsUrl" => Ok(ComparisonOperator::IsUrl),
1687 "IsUuid" => Ok(ComparisonOperator::IsUuid),
1688 "IsIso8601" => Ok(ComparisonOperator::IsIso8601),
1689 "IsJson" => Ok(ComparisonOperator::IsJson),
1690 "MatchesRegex" => Ok(ComparisonOperator::MatchesRegex),
1691
1692 "InRange" => Ok(ComparisonOperator::InRange),
1694 "NotInRange" => Ok(ComparisonOperator::NotInRange),
1695 "IsPositive" => Ok(ComparisonOperator::IsPositive),
1696 "IsNegative" => Ok(ComparisonOperator::IsNegative),
1697 "IsZero" => Ok(ComparisonOperator::IsZero),
1698
1699 "ContainsAll" => Ok(ComparisonOperator::ContainsAll),
1701 "ContainsAny" => Ok(ComparisonOperator::ContainsAny),
1702 "ContainsNone" => Ok(ComparisonOperator::ContainsNone),
1703 "IsEmpty" => Ok(ComparisonOperator::IsEmpty),
1704 "IsNotEmpty" => Ok(ComparisonOperator::IsNotEmpty),
1705 "HasUniqueItems" => Ok(ComparisonOperator::HasUniqueItems),
1706 "SequenceMatches" => Ok(ComparisonOperator::SequenceMatches),
1707
1708 "IsAlphabetic" => Ok(ComparisonOperator::IsAlphabetic),
1710 "IsAlphanumeric" => Ok(ComparisonOperator::IsAlphanumeric),
1711 "IsLowerCase" => Ok(ComparisonOperator::IsLowerCase),
1712 "IsUpperCase" => Ok(ComparisonOperator::IsUpperCase),
1713 "ContainsWord" => Ok(ComparisonOperator::ContainsWord),
1714
1715 "ApproximatelyEquals" => Ok(ComparisonOperator::ApproximatelyEquals),
1717
1718 _ => Err(TypeError::InvalidCompressionTypeError),
1719 }
1720 }
1721}
1722
1723impl ComparisonOperator {
1724 pub fn as_str(&self) -> &str {
1725 match self {
1726 ComparisonOperator::Equals => "Equals",
1727 ComparisonOperator::NotEqual => "NotEqual",
1728 ComparisonOperator::GreaterThan => "GreaterThan",
1729 ComparisonOperator::GreaterThanOrEqual => "GreaterThanOrEqual",
1730 ComparisonOperator::LessThan => "LessThan",
1731 ComparisonOperator::LessThanOrEqual => "LessThanOrEqual",
1732 ComparisonOperator::Contains => "Contains",
1733 ComparisonOperator::NotContains => "NotContains",
1734 ComparisonOperator::StartsWith => "StartsWith",
1735 ComparisonOperator::EndsWith => "EndsWith",
1736 ComparisonOperator::Matches => "Matches",
1737 ComparisonOperator::HasLengthEqual => "HasLengthEqual",
1738 ComparisonOperator::HasLengthGreaterThan => "HasLengthGreaterThan",
1739 ComparisonOperator::HasLengthLessThan => "HasLengthLessThan",
1740 ComparisonOperator::HasLengthGreaterThanOrEqual => "HasLengthGreaterThanOrEqual",
1741 ComparisonOperator::HasLengthLessThanOrEqual => "HasLengthLessThanOrEqual",
1742
1743 ComparisonOperator::IsNumeric => "IsNumeric",
1745 ComparisonOperator::IsString => "IsString",
1746 ComparisonOperator::IsBoolean => "IsBoolean",
1747 ComparisonOperator::IsNull => "IsNull",
1748 ComparisonOperator::IsArray => "IsArray",
1749 ComparisonOperator::IsObject => "IsObject",
1750
1751 ComparisonOperator::IsEmail => "IsEmail",
1753 ComparisonOperator::IsUrl => "IsUrl",
1754 ComparisonOperator::IsUuid => "IsUuid",
1755 ComparisonOperator::IsIso8601 => "IsIso8601",
1756 ComparisonOperator::IsJson => "IsJson",
1757 ComparisonOperator::MatchesRegex => "MatchesRegex",
1758
1759 ComparisonOperator::InRange => "InRange",
1761 ComparisonOperator::NotInRange => "NotInRange",
1762 ComparisonOperator::IsPositive => "IsPositive",
1763 ComparisonOperator::IsNegative => "IsNegative",
1764 ComparisonOperator::IsZero => "IsZero",
1765
1766 ComparisonOperator::ContainsAll => "ContainsAll",
1768 ComparisonOperator::ContainsAny => "ContainsAny",
1769 ComparisonOperator::ContainsNone => "ContainsNone",
1770 ComparisonOperator::IsEmpty => "IsEmpty",
1771 ComparisonOperator::IsNotEmpty => "IsNotEmpty",
1772 ComparisonOperator::HasUniqueItems => "HasUniqueItems",
1773 ComparisonOperator::SequenceMatches => "SequenceMatches",
1774
1775 ComparisonOperator::IsAlphabetic => "IsAlphabetic",
1777 ComparisonOperator::IsAlphanumeric => "IsAlphanumeric",
1778 ComparisonOperator::IsLowerCase => "IsLowerCase",
1779 ComparisonOperator::IsUpperCase => "IsUpperCase",
1780 ComparisonOperator::ContainsWord => "ContainsWord",
1781
1782 ComparisonOperator::ApproximatelyEquals => "ApproximatelyEquals",
1784 }
1785 }
1786}
1787
/// A value used in assertions, convertible from Python (`assertion_value_from_py`)
/// and to JSON (`AssertionValue::to_serde_value`).
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AssertionValue {
    String(String),
    /// Floating-point number (from a Python `float`).
    Number(f64),
    /// Integer (from a Python `int` that fits in `i64`).
    Integer(i64),
    Boolean(bool),
    /// Ordered list of nested values.
    List(Vec<AssertionValue>),
    /// Null / Python `None`.
    // NOTE(review): written as an empty tuple variant, presumably because pyo3
    // complex enums do not accept unit variants; confirm before changing.
    Null(),
}
1798
1799impl AssertionValue {
1800 pub fn to_actual(self, comparison: &ComparisonOperator) -> AssertionValue {
1801 match comparison {
1802 ComparisonOperator::HasLengthEqual
1803 | ComparisonOperator::HasLengthGreaterThan
1804 | ComparisonOperator::HasLengthLessThan
1805 | ComparisonOperator::HasLengthGreaterThanOrEqual
1806 | ComparisonOperator::HasLengthLessThanOrEqual => match self {
1807 AssertionValue::List(arr) => AssertionValue::Integer(arr.len() as i64),
1808 AssertionValue::String(s) => AssertionValue::Integer(s.chars().count() as i64),
1809 _ => self,
1810 },
1811 _ => self,
1812 }
1813 }
1814
1815 pub fn to_serde_value(&self) -> Value {
1816 match self {
1817 AssertionValue::String(s) => Value::String(s.clone()),
1818 AssertionValue::Number(n) => Value::Number(serde_json::Number::from_f64(*n).unwrap()),
1819 AssertionValue::Integer(i) => Value::Number(serde_json::Number::from(*i)),
1820 AssertionValue::Boolean(b) => Value::Bool(*b),
1821 AssertionValue::List(arr) => {
1822 let json_arr: Vec<Value> = arr.iter().map(|v| v.to_serde_value()).collect();
1823 Value::Array(json_arr)
1824 }
1825 AssertionValue::Null() => Value::Null,
1826 }
1827 }
1828}
1829pub fn assertion_value_from_py(value: &Bound<'_, PyAny>) -> Result<AssertionValue, TypeError> {
1836 if value.is_none() {
1838 return Ok(AssertionValue::Null());
1839 }
1840
1841 if value.is_instance_of::<PyBool>() {
1843 return Ok(AssertionValue::Boolean(value.extract()?));
1844 }
1845
1846 if value.is_instance_of::<PyString>() {
1847 return Ok(AssertionValue::String(value.extract()?));
1848 }
1849
1850 if value.is_instance_of::<PyInt>() {
1851 return Ok(AssertionValue::Integer(value.extract()?));
1852 }
1853
1854 if value.is_instance_of::<PyFloat>() {
1855 return Ok(AssertionValue::Number(value.extract()?));
1856 }
1857
1858 if value.is_instance_of::<PyList>() {
1859 let list = value.cast::<PyList>()?; let assertion_list = list
1862 .iter()
1863 .map(|item| assertion_value_from_py(&item))
1864 .collect::<Result<Vec<_>, _>>()?;
1865 return Ok(AssertionValue::List(assertion_list));
1866 }
1867
1868 Err(TypeError::UnsupportedType(
1870 value.get_type().name()?.to_string(),
1871 ))
1872}
1873
/// Kind of evaluation task.
///
/// `Conditional` and `HumanValidation` have no corresponding `EvaluationTask` or
/// `TaskConfig` variant, and loading a `TasksFile` rejects them.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum EvaluationTaskType {
    Assertion,
    LLMJudge,
    Conditional,
    HumanValidation,
    TraceAssertion,
    AgentAssertion,
}
1884
1885impl Display for EvaluationTaskType {
1886 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1887 let task_type_str = match self {
1888 EvaluationTaskType::Assertion => "Assertion",
1889 EvaluationTaskType::LLMJudge => "LLMJudge",
1890 EvaluationTaskType::Conditional => "Conditional",
1891 EvaluationTaskType::HumanValidation => "HumanValidation",
1892 EvaluationTaskType::TraceAssertion => "TraceAssertion",
1893 EvaluationTaskType::AgentAssertion => "AgentAssertion",
1894 };
1895 write!(f, "{}", task_type_str)
1896 }
1897}
1898
1899impl FromStr for EvaluationTaskType {
1900 type Err = TypeError;
1901
1902 fn from_str(s: &str) -> Result<Self, Self::Err> {
1903 match s {
1904 "Assertion" => Ok(EvaluationTaskType::Assertion),
1905 "LLMJudge" => Ok(EvaluationTaskType::LLMJudge),
1906 "Conditional" => Ok(EvaluationTaskType::Conditional),
1907 "HumanValidation" => Ok(EvaluationTaskType::HumanValidation),
1908 "TraceAssertion" => Ok(EvaluationTaskType::TraceAssertion),
1909 "AgentAssertion" => Ok(EvaluationTaskType::AgentAssertion),
1910 _ => Err(TypeError::InvalidEvalType(s.to_string())),
1911 }
1912 }
1913}
1914
1915impl EvaluationTaskType {
1916 pub fn as_str(&self) -> &str {
1917 match self {
1918 EvaluationTaskType::Assertion => "Assertion",
1919 EvaluationTaskType::LLMJudge => "LLMJudge",
1920 EvaluationTaskType::Conditional => "Conditional",
1921 EvaluationTaskType::HumanValidation => "HumanValidation",
1922 EvaluationTaskType::TraceAssertion => "TraceAssertion",
1923 EvaluationTaskType::AgentAssertion => "AgentAssertion",
1924 }
1925 }
1926}
1927
/// Task configurations loaded from a `.json`, `.yaml` or `.yml` file.
///
/// Exposed to Python as an iterable, indexable, sliceable sequence of tasks.
#[pyclass]
#[derive(Debug, Serialize)]
pub struct TasksFile {
    /// Parsed task configurations, in the order they appear in the file.
    pub tasks: Vec<TaskConfig>,

    // Iteration cursor advanced by `__next__`. `Deserialize` is implemented by hand
    // and always starts it at 0, so `#[serde(default)]` has no effect here. The field
    // is still included when the struct is serialized.
    // NOTE(review): consider `#[serde(skip)]` if the cursor should not appear in
    // serialized (and possibly `__str__`) output.
    #[serde(default)]
    index: usize,
}
1936
1937#[pymethods]
1938impl TasksFile {
1939 #[staticmethod]
1940 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1941 let tasks_file: TasksFile = deserialize_from_path(path)?;
1942 Ok(tasks_file)
1943 }
1944
1945 pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
1946 slf
1947 }
1948
1949 pub fn __next__<'py>(
1950 mut slf: PyRefMut<'py, Self>,
1951 ) -> Result<Option<Bound<'py, PyAny>>, TypeError> {
1952 let py = slf.py();
1953 if slf.index < slf.tasks.len() {
1954 let task = slf.tasks[slf.index].clone().into_bound_py_any(py)?;
1955 slf.index += 1;
1956 Ok(Some(task))
1957 } else {
1958 Ok(None)
1959 }
1960 }
1961
1962 fn __getitem__<'py>(
1963 &self,
1964 py: Python<'py>,
1965 index: &Bound<'py, PyAny>,
1966 ) -> Result<Bound<'py, PyAny>, TypeError> {
1967 if let Ok(i) = index.extract::<isize>() {
1968 let len = self.tasks.len() as isize;
1969 let actual_index = if i < 0 { len + i } else { i };
1970
1971 if actual_index < 0 || actual_index >= len {
1972 return Err(TypeError::IndexOutOfBounds {
1973 index: i,
1974 length: self.tasks.len(),
1975 });
1976 }
1977
1978 Ok(self.tasks[actual_index as usize]
1979 .clone()
1980 .into_bound_py_any(py)?)
1981 } else if let Ok(slice) = index.cast::<PySlice>() {
1982 let indices = slice.indices(self.tasks.len() as isize)?;
1983 let result = PyList::empty(py);
1984
1985 let mut i = indices.start;
1986 while (indices.step > 0 && i < indices.stop) || (indices.step < 0 && i > indices.stop) {
1987 result.append(self.tasks[i as usize].clone().into_bound_py_any(py)?)?;
1988 i += indices.step;
1989 }
1990
1991 Ok(result.into_bound_py_any(py)?)
1992 } else {
1993 Err(TypeError::IndexOrSliceExpected)
1994 }
1995 }
1996
1997 fn __len__(&self) -> usize {
1998 self.tasks.len()
1999 }
2000
2001 fn __str__(&self) -> String {
2002 PyHelperFuncs::__str__(self)
2003 }
2004}
2005
/// One task entry of a `TasksFile`, holding the concrete task selected by its
/// `task_type`.
#[derive(Debug, Serialize, Clone)]
// NOTE(review): presumably silences clippy because the unboxed task structs differ
// considerably in size (only `LLMJudge` is boxed); re-check whether it is still needed.
#[allow(clippy::large_enum_variant)]
pub enum TaskConfig {
    Assertion(AssertionTask),
    // The rename equals the variant name, so it does not change the serialized tag.
    #[serde(rename = "LLMJudge")]
    LLMJudge(Box<LLMJudgeTask>),
    TraceAssertion(TraceAssertionTask),
    AgentAssertion(AgentAssertionTask),
}
2015
2016impl TaskConfig {
2017 fn into_bound_py_any<'py>(self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
2018 match self {
2019 TaskConfig::Assertion(task) => Ok(task.into_bound_py_any(py)?),
2020 TaskConfig::LLMJudge(task) => Ok(task.into_bound_py_any(py)?),
2021 TaskConfig::TraceAssertion(task) => Ok(task.into_bound_py_any(py)?),
2022 TaskConfig::AgentAssertion(task) => Ok(task.into_bound_py_any(py)?),
2023 }
2024 }
2025}
2026
2027impl<'de> Deserialize<'de> for TasksFile {
2028 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
2029 where
2030 D: serde::Deserializer<'de>,
2031 {
2032 #[derive(Deserialize)]
2033 #[serde(untagged)]
2034 enum TasksFileRaw {
2035 Direct(Vec<TaskConfigRaw>),
2036 Wrapped { tasks: Vec<TaskConfigRaw> },
2037 }
2038
2039 #[derive(Deserialize)]
2040 struct TaskConfigRaw {
2041 task_type: EvaluationTaskType,
2042 #[serde(flatten)]
2043 data: Value,
2044 }
2045
2046 let raw = TasksFileRaw::deserialize(deserializer)?;
2047 let raw_tasks = match raw {
2048 TasksFileRaw::Direct(tasks) => tasks,
2049 TasksFileRaw::Wrapped { tasks } => tasks,
2050 };
2051
2052 let mut tasks = Vec::new();
2053
2054 for task_raw in raw_tasks {
2055 let task_config = match task_raw.task_type {
2056 EvaluationTaskType::Assertion => {
2057 let mut task: AssertionTask =
2058 serde_json::from_value(task_raw.data).map_err(|e| {
2059 error!("Failed to deserialize AssertionTask: {}", e);
2060 serde::de::Error::custom(e.to_string())
2061 })?;
2062 task.task_type = EvaluationTaskType::Assertion;
2063 TaskConfig::Assertion(task)
2064 }
2065 EvaluationTaskType::LLMJudge => {
2066 let mut task: LLMJudgeTask =
2067 serde_json::from_value(task_raw.data).map_err(|e| {
2068 error!("Failed to deserialize LLMJudgeTask: {}", e);
2069 serde::de::Error::custom(e.to_string())
2070 })?;
2071 task.task_type = EvaluationTaskType::LLMJudge;
2072 TaskConfig::LLMJudge(Box::new(task))
2073 }
2074 EvaluationTaskType::TraceAssertion => {
2075 let mut task: TraceAssertionTask = serde_json::from_value(task_raw.data)
2076 .map_err(|e| {
2077 error!("Failed to deserialize TraceAssertionTask: {}", e);
2078 serde::de::Error::custom(e.to_string())
2079 })?;
2080 task.task_type = EvaluationTaskType::TraceAssertion;
2081 TaskConfig::TraceAssertion(task)
2082 }
2083 EvaluationTaskType::AgentAssertion => {
2084 let mut task: AgentAssertionTask = serde_json::from_value(task_raw.data)
2085 .map_err(|e| {
2086 error!("Failed to deserialize AgentAssertionTask: {}", e);
2087 serde::de::Error::custom(e.to_string())
2088 })?;
2089 task.task_type = EvaluationTaskType::AgentAssertion;
2090 TaskConfig::AgentAssertion(task)
2091 }
2092 _ => {
2093 return Err(serde::de::Error::custom(format!(
2094 "Unknown task_type: {}",
2095 task_raw.task_type
2096 )))
2097 }
2098 };
2099 tasks.push(task_config);
2100 }
2101
2102 Ok(TasksFile { tasks, index: 0 })
2103 }
2104}