1use crate::error::TypeError;
2use crate::genai::traits::TaskAccessor;
3use crate::PyHelperFuncs;
4use core::fmt::Debug;
5use potato_head::prompt_types::Prompt;
6use pyo3::prelude::*;
7use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PySlice, PyString};
8use pyo3::IntoPyObjectExt;
9use pythonize::{depythonize, pythonize};
10use serde::de::DeserializeOwned;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::fmt::Display;
15use std::path::PathBuf;
16use std::str::FromStr;
17use tracing::error;
18
19pub fn deserialize_from_path<T: DeserializeOwned>(path: PathBuf) -> Result<T, TypeError> {
20 let content = std::fs::read_to_string(&path)?;
21
22 let extension = path
23 .extension()
24 .and_then(|ext| ext.to_str())
25 .ok_or_else(|| TypeError::Error(format!("Invalid file path: {:?}", path)))?;
26
27 let item = match extension.to_lowercase().as_str() {
28 "json" => serde_json::from_str(&content)?,
29 "yaml" | "yml" => serde_yaml::from_str(&content)?,
30 _ => {
31 return Err(TypeError::Error(format!(
32 "Unsupported file extension '{}'. Expected .json, .yaml, or .yml",
33 extension
34 )))
35 }
36 };
37
38 Ok(item)
39}
40
/// Serde default for `AssertionTask::task_type` when the field is absent.
fn default_assertion_task_type() -> EvaluationTaskType {
    EvaluationTaskType::Assertion
}
45
/// Serde default for `TraceAssertionTask::task_type` when the field is absent.
fn default_trace_assertion_task_type() -> EvaluationTaskType {
    EvaluationTaskType::TraceAssertion
}
49
/// Serde default for `AgentAssertionTask::task_type` when the field is absent.
fn default_agent_assertion_task_type() -> EvaluationTaskType {
    EvaluationTaskType::AgentAssertion
}
53
/// Outcome of evaluating one assertion.
///
/// `actual` and `expected` are held as JSON values. `passed` and `message` are exposed
/// directly to Python; the two JSON fields are exposed through explicit getters.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssertionResult {
    /// Whether the assertion passed.
    #[pyo3(get)]
    pub passed: bool,
    /// Actual value recorded for this result, as JSON.
    pub actual: Value,

    /// Free-form message attached to the result.
    #[pyo3(get)]
    pub message: String,

    /// Expected value recorded for this result, as JSON.
    pub expected: Value,
}
66
67impl AssertionResult {
68 pub fn new(passed: bool, actual: Value, message: String, expected: Value) -> Self {
69 Self {
70 passed,
71 actual,
72 message,
73 expected,
74 }
75 }
76 pub fn to_metric_value(&self) -> f64 {
77 if self.passed {
78 1.0
79 } else {
80 0.0
81 }
82 }
83}
84
85#[pymethods]
86impl AssertionResult {
87 pub fn __str__(&self) -> String {
88 PyHelperFuncs::__str__(self)
89 }
90
91 #[getter]
92 pub fn get_actual<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
93 let py_value = pythonize(py, &self.actual)?;
94 Ok(py_value)
95 }
96
97 #[getter]
98 pub fn get_expected<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
99 let py_value = pythonize(py, &self.expected)?;
100 Ok(py_value)
101 }
102}
103
/// Collection of [`AssertionResult`]s keyed by string (indexable from Python via `[]`).
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssertionResults {
    /// Results by key.
    #[pyo3(get)]
    pub results: HashMap<String, AssertionResult>,
}
110
111#[pymethods]
112impl AssertionResults {
113 pub fn __str__(&self) -> String {
114 PyHelperFuncs::__str__(self)
115 }
116
117 pub fn __getitem__(&self, key: &str) -> Result<AssertionResult, TypeError> {
118 if let Some(result) = self.results.get(key) {
119 Ok(result.clone())
120 } else {
121 Err(TypeError::KeyNotFound {
122 key: key.to_string(),
123 })
124 }
125 }
126}
127
128#[pyclass]
129#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
130pub struct AssertionTask {
131 #[pyo3(get, set)]
132 pub id: String,
133
134 #[pyo3(get, set)]
135 #[serde(default)]
136 pub context_path: Option<String>,
137
138 #[pyo3(get, set)]
139 #[serde(default)]
140 pub item_context_path: Option<String>,
141
142 #[pyo3(get, set)]
143 pub operator: ComparisonOperator,
144
145 pub expected_value: Value,
146
147 #[pyo3(get, set)]
148 #[serde(default)]
149 pub description: Option<String>,
150
151 #[pyo3(get, set)]
152 #[serde(default)]
153 pub depends_on: Vec<String>,
154
155 #[serde(default = "default_assertion_task_type")]
156 #[pyo3(get)]
157 pub task_type: EvaluationTaskType,
158
159 #[serde(skip_serializing_if = "Option::is_none")]
160 pub result: Option<AssertionResult>,
161
162 #[serde(default)]
163 pub condition: bool,
164}
165
166#[pymethods]
167impl AssertionTask {
168 #[new]
169 #[allow(clippy::too_many_arguments)]
223 #[pyo3(signature = (id, context_path, expected_value, operator, item_context_path=None, description=None, depends_on=None, condition=None))]
224 pub fn new(
225 id: String,
226 context_path: Option<String>,
227 expected_value: &Bound<'_, PyAny>,
228 operator: ComparisonOperator,
229 item_context_path: Option<String>,
230 description: Option<String>,
231 depends_on: Option<Vec<String>>,
232 condition: Option<bool>,
233 ) -> Result<Self, TypeError> {
234 let expected_value = depythonize(expected_value)?;
235 let condition = condition.unwrap_or(false);
236
237 Ok(Self {
238 id: id.to_lowercase(),
239 context_path,
240 item_context_path,
241 operator,
242 expected_value,
243 description,
244 task_type: EvaluationTaskType::Assertion,
245 depends_on: depends_on.unwrap_or_default(),
246 result: None,
247 condition,
248 })
249 }
250
251 #[getter]
252 pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
253 let py_value = pythonize(py, &self.expected_value)?;
254 Ok(py_value)
255 }
256
257 #[staticmethod]
258 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
259 deserialize_from_path(path)
260 }
261}
262
263impl AssertionTask {}
264
265impl TaskAccessor for AssertionTask {
266 fn context_path(&self) -> Option<&str> {
267 self.context_path.as_deref()
268 }
269
270 fn item_context_path(&self) -> Option<&str> {
271 self.item_context_path.as_deref()
272 }
273
274 fn id(&self) -> &str {
275 &self.id
276 }
277
278 fn operator(&self) -> &ComparisonOperator {
279 &self.operator
280 }
281
282 fn task_type(&self) -> &EvaluationTaskType {
283 &self.task_type
284 }
285
286 fn expected_value(&self) -> &Value {
287 &self.expected_value
288 }
289
290 fn depends_on(&self) -> &[String] {
291 &self.depends_on
292 }
293
294 fn add_result(&mut self, result: AssertionResult) {
295 self.result = Some(result);
296 }
297}
298
/// Extra queries on JSON values (implemented for `serde_json::Value` in this file).
pub trait ValueExt {
    /// Length of the value, or `None` when the value has no notion of length.
    fn to_length(&self) -> Option<i64>;

    /// The value as `f64`, or `None` when it is not numeric.
    fn as_numeric(&self) -> Option<f64>;

    /// Truthiness of the value.
    fn is_truthy(&self) -> bool;
}
309
310impl ValueExt for Value {
311 fn to_length(&self) -> Option<i64> {
312 match self {
313 Value::Array(arr) => Some(arr.len() as i64),
314 Value::String(s) => Some(s.chars().count() as i64),
315 Value::Object(obj) => Some(obj.len() as i64),
316 _ => None,
317 }
318 }
319
320 fn as_numeric(&self) -> Option<f64> {
321 match self {
322 Value::Number(n) => n.as_f64(),
323 _ => None,
324 }
325 }
326
327 fn is_truthy(&self) -> bool {
328 match self {
329 Value::Null => false,
330 Value::Bool(b) => *b,
331 Value::Number(n) => n.as_f64() != Some(0.0),
332 Value::String(s) => !s.is_empty(),
333 Value::Array(arr) => !arr.is_empty(),
334 Value::Object(obj) => !obj.is_empty(),
335 }
336 }
337}
338
339#[pyclass]
341#[derive(Debug, Serialize, Clone, PartialEq)]
342pub struct LLMJudgeTask {
343 #[pyo3(get, set)]
344 pub id: String,
345
346 #[pyo3(get)]
347 pub prompt: Prompt,
348
349 #[pyo3(get)]
350 #[serde(default)]
351 pub context_path: Option<String>,
352
353 pub expected_value: Value,
354
355 #[pyo3(get)]
356 pub operator: ComparisonOperator,
357
358 #[pyo3(get)]
359 pub task_type: EvaluationTaskType,
360
361 #[pyo3(get, set)]
362 #[serde(default)]
363 pub depends_on: Vec<String>,
364
365 #[pyo3(get, set)]
366 #[serde(default)]
367 pub max_retries: Option<u32>,
368
369 #[serde(skip_serializing_if = "Option::is_none")]
370 pub result: Option<AssertionResult>,
371
372 #[serde(default)]
373 pub description: Option<String>,
374
375 #[pyo3(get, set)]
376 #[serde(default)]
377 pub condition: bool,
378}
379
/// How a prompt is given in an LLM-judge configuration file.
///
/// `#[serde(untagged)]` tries variants in declaration order: a map with a string `path`
/// field becomes `Path`; anything else is attempted as an inline [`Prompt`].
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum PromptConfig {
    /// Reference to a prompt file, loaded with `Prompt::from_path`.
    Path { path: String },
    /// Prompt embedded directly in the configuration.
    Inline(Box<Prompt>),
}
386
387#[derive(Debug, Deserialize)]
388struct LLMJudgeTaskConfig {
389 pub id: String,
390 pub prompt: PromptConfig,
391 pub expected_value: Value,
392 pub operator: ComparisonOperator,
393 pub context_path: Option<String>,
394 pub description: Option<String>,
395 pub depends_on: Vec<String>,
396 pub max_retries: Option<u32>,
397 pub condition: bool,
398}
399
400impl LLMJudgeTaskConfig {
401 pub fn into_task(self) -> Result<LLMJudgeTask, TypeError> {
402 let prompt = match self.prompt {
403 PromptConfig::Path { path } => {
404 Prompt::from_path(PathBuf::from(path)).inspect_err(|e| {
405 error!("Failed to deserialize Prompt from path: {}", e);
406 })?
407 }
408 PromptConfig::Inline(prompt) => *prompt,
409 };
410
411 Ok(LLMJudgeTask {
412 id: self.id.to_lowercase(),
413 prompt,
414 expected_value: self.expected_value,
415 operator: self.operator,
416 context_path: self.context_path,
417 description: self.description,
418 depends_on: self.depends_on,
419 max_retries: self.max_retries.or(Some(3)),
420 task_type: EvaluationTaskType::LLMJudge,
421 result: None,
422 condition: self.condition,
423 })
424 }
425}
426
427#[derive(Debug, Deserialize)]
428struct LLMJudgeTaskInternal {
429 pub id: String,
430 pub prompt: Prompt,
431 pub context_path: Option<String>,
432 pub expected_value: Value,
433 pub operator: ComparisonOperator,
434 pub task_type: EvaluationTaskType,
435 pub depends_on: Vec<String>,
436 pub max_retries: Option<u32>,
437 pub result: Option<AssertionResult>,
438 pub description: Option<String>,
439 pub condition: bool,
440}
441impl LLMJudgeTaskInternal {
442 pub fn into_task(self) -> LLMJudgeTask {
443 LLMJudgeTask {
444 id: self.id.to_lowercase(),
445 prompt: self.prompt,
446 context_path: self.context_path,
447 expected_value: self.expected_value,
448 operator: self.operator,
449 task_type: self.task_type,
450 depends_on: self.depends_on,
451 max_retries: self.max_retries.or(Some(3)),
452 result: self.result,
453 description: self.description,
454 condition: self.condition,
455 }
456 }
457}
458
/// The two accepted on-disk shapes of an LLM-judge task.
///
/// Untagged: `Full` (the serialized layout) is tried first; if it does not match, the input
/// is tried as a `Generic` configuration. Keep this variant order.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum LLMJudgeFormat {
    Full(Box<LLMJudgeTaskInternal>),
    Generic(LLMJudgeTaskConfig),
}
465
466impl<'de> Deserialize<'de> for LLMJudgeTask {
467 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
468 where
469 D: serde::Deserializer<'de>,
470 {
471 let format = LLMJudgeFormat::deserialize(deserializer)?;
472
473 match format {
474 LLMJudgeFormat::Generic(config) => config.into_task().map_err(serde::de::Error::custom),
475 LLMJudgeFormat::Full(internal) => Ok(internal.into_task()),
476 }
477 }
478}
479
480#[pymethods]
481impl LLMJudgeTask {
482 #[new]
502 #[pyo3(signature = (id, prompt, expected_value, context_path,operator, description=None, depends_on=None, max_retries=None, condition=None))]
503 #[allow(clippy::too_many_arguments)]
504 pub fn new(
505 id: &str,
506 prompt: Prompt,
507 expected_value: &Bound<'_, PyAny>,
508 context_path: Option<String>,
509 operator: ComparisonOperator,
510 description: Option<String>,
511 depends_on: Option<Vec<String>>,
512 max_retries: Option<u32>,
513 condition: Option<bool>,
514 ) -> Result<Self, TypeError> {
515 let expected_value = depythonize(expected_value)?;
516
517 Ok(Self {
518 id: id.to_lowercase(),
519 prompt,
520 expected_value,
521 operator,
522 task_type: EvaluationTaskType::LLMJudge,
523 depends_on: depends_on.unwrap_or_default(),
524 max_retries: max_retries.or(Some(3)),
525 context_path,
526 result: None,
527 description,
528 condition: condition.unwrap_or(false),
529 })
530 }
531
532 pub fn __str__(&self) -> String {
533 PyHelperFuncs::__str__(self)
535 }
536
537 #[getter]
538 pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
539 let py_value = pythonize(py, &self.expected_value)?;
540 Ok(py_value)
541 }
542
543 #[staticmethod]
544 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
545 deserialize_from_path(path)
546 }
547}
548
549impl LLMJudgeTask {
550 #[allow(clippy::too_many_arguments)]
562 pub fn new_rs(
563 id: &str,
564 prompt: Prompt,
565 expected_value: Value,
566 context_path: Option<String>,
567 operator: ComparisonOperator,
568 depends_on: Option<Vec<String>>,
569 max_retries: Option<u32>,
570 description: Option<String>,
571 condition: Option<bool>,
572 ) -> Self {
573 Self {
574 id: id.to_lowercase(),
575 prompt,
576 expected_value,
577 operator,
578 task_type: EvaluationTaskType::LLMJudge,
579 depends_on: depends_on.unwrap_or_default(),
580 max_retries: max_retries.or(Some(3)),
581 context_path,
582 result: None,
583 description,
584 condition: condition.unwrap_or(false),
585 }
586 }
587}
588
589impl TaskAccessor for LLMJudgeTask {
590 fn context_path(&self) -> Option<&str> {
591 self.context_path.as_deref()
592 }
593
594 fn item_context_path(&self) -> Option<&str> {
595 None
596 }
597
598 fn id(&self) -> &str {
599 &self.id
600 }
601 fn task_type(&self) -> &EvaluationTaskType {
602 &self.task_type
603 }
604
605 fn operator(&self) -> &ComparisonOperator {
606 &self.operator
607 }
608
609 fn expected_value(&self) -> &Value {
610 &self.expected_value
611 }
612 fn depends_on(&self) -> &[String] {
613 &self.depends_on
614 }
615 fn add_result(&mut self, result: AssertionResult) {
616 self.result = Some(result);
617 }
618}
619
/// Span status value, used as the payload of `SpanFilter::WithStatus`.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SpanStatus {
    Ok,
    Error,
    Unset,
}
627
/// Newtype around a JSON [`Value`] so it can be converted to and from Python objects
/// (via `pythonize` / `depythonize`) when used inside pyclass enums.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PyValueWrapper(pub Value);
630
631impl<'py> IntoPyObject<'py> for PyValueWrapper {
632 type Target = PyAny;
633 type Output = Bound<'py, Self::Target>;
634 type Error = TypeError;
635
636 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
637 pythonize(py, &self.0).map_err(TypeError::from)
638 }
639}
640
641impl<'a, 'py> FromPyObject<'a, 'py> for PyValueWrapper {
642 type Error = TypeError;
643
644 fn extract(ob: pyo3::Borrowed<'a, 'py, pyo3::PyAny>) -> Result<Self, Self::Error> {
645 let value: Value = depythonize(&ob)?;
646 Ok(PyValueWrapper(value))
647 }
648}
649
/// Span selection criteria, built from Python through the static constructors and combined
/// with `and_` / `or_`. The matching logic lives outside this section of the file.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SpanFilter {
    /// Span name given literally.
    ByName { name: String },

    /// Span name given as a pattern (pattern syntax is defined by the evaluator; TODO confirm).
    ByNamePattern { pattern: String },

    /// Attribute key only.
    WithAttribute { key: String },

    /// Attribute key together with a JSON value.
    WithAttributeValue { key: String, value: PyValueWrapper },

    /// Span status.
    WithStatus { status: SpanStatus },

    /// Optional duration bounds (presumably milliseconds, per the field names).
    WithDuration {
        min_ms: Option<f64>,
        max_ms: Option<f64>,
    },

    /// Ordered list of span names.
    Sequence { names: Vec<String> },

    /// Conjunction of nested filters.
    And { filters: Vec<SpanFilter> },

    /// Disjunction of nested filters.
    Or { filters: Vec<SpanFilter> },
}
684
685#[pymethods]
686impl SpanFilter {
687 #[staticmethod]
688 pub fn by_name(name: String) -> Self {
689 SpanFilter::ByName { name }
690 }
691
692 #[staticmethod]
693 pub fn by_name_pattern(pattern: String) -> Self {
694 SpanFilter::ByNamePattern { pattern }
695 }
696
697 #[staticmethod]
698 pub fn with_attribute(key: String) -> Self {
699 SpanFilter::WithAttribute { key }
700 }
701
702 #[staticmethod]
703 pub fn with_attribute_value(key: String, value: &Bound<'_, PyAny>) -> Result<Self, TypeError> {
704 let value = PyValueWrapper(depythonize(value)?);
705 Ok(SpanFilter::WithAttributeValue { key, value })
706 }
707
708 #[staticmethod]
709 pub fn with_status(status: SpanStatus) -> Self {
710 SpanFilter::WithStatus { status }
711 }
712
713 #[staticmethod]
714 #[pyo3(signature = (min_ms=None, max_ms=None))]
715 pub fn with_duration(min_ms: Option<f64>, max_ms: Option<f64>) -> Self {
716 SpanFilter::WithDuration { min_ms, max_ms }
717 }
718
719 #[staticmethod]
720 pub fn sequence(names: Vec<String>) -> Self {
721 SpanFilter::Sequence { names }
722 }
723
724 pub fn and_(&self, other: SpanFilter) -> Self {
725 match self {
726 SpanFilter::And { filters } => {
727 let mut new_filters = filters.clone();
728 new_filters.push(other);
729 SpanFilter::And {
730 filters: new_filters,
731 }
732 }
733 _ => SpanFilter::And {
734 filters: vec![self.clone(), other],
735 },
736 }
737 }
738
739 pub fn or_(&self, other: SpanFilter) -> Self {
740 match self {
741 SpanFilter::Or { filters } => {
742 let mut new_filters = filters.clone();
743 new_filters.push(other);
744 SpanFilter::Or {
745 filters: new_filters,
746 }
747 }
748 _ => SpanFilter::Or {
749 filters: vec![self.clone(), other],
750 },
751 }
752 }
753}
754
/// Aggregation applied by `TraceAssertion::SpanAggregation` to an attribute across the
/// spans selected by its filter.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AggregationType {
    Count,
    Sum,
    Average,
    Min,
    Max,
    First,
    Last,
}
766
/// Which trace property a `TraceAssertionTask` checks.
///
/// Only the selector data lives here. How each variant is evaluated is implemented
/// elsewhere (TODO confirm exact semantics there). Unit-like variants are written as empty
/// struct variants (`{}`).
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TraceAssertion {
    /// Span names, as an ordered list.
    SpanSequence { span_names: Vec<String> },

    /// Span names, as a set.
    SpanSet { span_names: Vec<String> },

    /// Count of spans selected by `filter`.
    SpanCount { filter: SpanFilter },

    /// Existence of spans selected by `filter`.
    SpanExists { filter: SpanFilter },

    /// Attribute `attribute_key` of spans selected by `filter`.
    SpanAttribute {
        filter: SpanFilter,
        attribute_key: String,
    },

    /// Duration of spans selected by `filter`.
    SpanDuration { filter: SpanFilter },

    /// `aggregation` of attribute `attribute_key` over spans selected by `filter`.
    SpanAggregation {
        filter: SpanFilter,
        attribute_key: String,
        aggregation: AggregationType,
    },

    /// Whole-trace duration.
    TraceDuration {},

    /// Whole-trace span count.
    TraceSpanCount {},

    /// Whole-trace error count.
    TraceErrorCount {},

    /// Whole-trace service count.
    TraceServiceCount {},

    /// Whole-trace maximum depth.
    TraceMaxDepth {},

    /// Trace-level attribute `attribute_key`.
    TraceAttribute { attribute_key: String },
}
817
818impl Display for TraceAssertion {
820 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
821 let s = serde_json::to_string(self).unwrap_or_default();
823 write!(f, "{}", s)
824 }
825}
826
827#[pymethods]
828impl TraceAssertion {
829 #[staticmethod]
830 pub fn span_sequence(span_names: Vec<String>) -> Self {
831 TraceAssertion::SpanSequence { span_names }
832 }
833
834 #[staticmethod]
835 pub fn span_set(span_names: Vec<String>) -> Self {
836 TraceAssertion::SpanSet { span_names }
837 }
838
839 #[staticmethod]
840 pub fn span_count(filter: SpanFilter) -> Self {
841 TraceAssertion::SpanCount { filter }
842 }
843
844 #[staticmethod]
845 pub fn span_exists(filter: SpanFilter) -> Self {
846 TraceAssertion::SpanExists { filter }
847 }
848
849 #[staticmethod]
850 pub fn span_attribute(filter: SpanFilter, attribute_key: String) -> Self {
851 TraceAssertion::SpanAttribute {
852 filter,
853 attribute_key,
854 }
855 }
856
857 #[staticmethod]
858 pub fn span_duration(filter: SpanFilter) -> Self {
859 TraceAssertion::SpanDuration { filter }
860 }
861
862 #[staticmethod]
863 pub fn span_aggregation(
864 filter: SpanFilter,
865 attribute_key: String,
866 aggregation: AggregationType,
867 ) -> Self {
868 TraceAssertion::SpanAggregation {
869 filter,
870 attribute_key,
871 aggregation,
872 }
873 }
874
875 #[staticmethod]
876 pub fn trace_duration() -> Self {
877 TraceAssertion::TraceDuration {}
878 }
879
880 #[staticmethod]
881 pub fn trace_span_count() -> Self {
882 TraceAssertion::TraceSpanCount {}
883 }
884
885 #[staticmethod]
886 pub fn trace_error_count() -> Self {
887 TraceAssertion::TraceErrorCount {}
888 }
889
890 #[staticmethod]
891 pub fn trace_service_count() -> Self {
892 TraceAssertion::TraceServiceCount {}
893 }
894
895 #[staticmethod]
896 pub fn trace_max_depth() -> Self {
897 TraceAssertion::TraceMaxDepth {}
898 }
899
900 #[staticmethod]
901 pub fn trace_attribute(attribute_key: String) -> Self {
902 TraceAssertion::TraceAttribute { attribute_key }
903 }
904
905 pub fn model_dump_json(&self) -> String {
906 serde_json::to_string(self).unwrap_or_default()
907 }
908}
909
/// A task that applies a [`TraceAssertion`] and compares the result using `operator` and
/// `expected_value`. When deserialized, `task_type` defaults to `TraceAssertion`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TraceAssertionTask {
    /// Task identifier (lower-cased by the Python constructor).
    #[pyo3(get, set)]
    pub id: String,

    /// Trace property selector.
    #[pyo3(get, set)]
    pub assertion: TraceAssertion,

    /// Comparison operator.
    #[pyo3(get, set)]
    pub operator: ComparisonOperator,

    /// Expected value, as JSON (exposed to Python through an explicit getter).
    pub expected_value: Value,

    /// Optional human-readable description.
    #[pyo3(get, set)]
    #[serde(default)]
    pub description: Option<String>,

    /// Ids of tasks this task depends on.
    #[pyo3(get, set)]
    #[serde(default)]
    pub depends_on: Vec<String>,

    /// Task kind; always `TraceAssertion` for this type.
    #[serde(default = "default_trace_assertion_task_type")]
    #[pyo3(get)]
    pub task_type: EvaluationTaskType,

    /// Result attached via `TaskAccessor::add_result`; omitted from output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<AssertionResult>,

    /// Condition flag (defaults to `false`).
    #[pyo3(get, set)]
    #[serde(default)]
    pub condition: bool,
}
943
944#[pymethods]
945impl TraceAssertionTask {
946 #[new]
996 #[pyo3(signature = (id, assertion, expected_value, operator, description=None, depends_on=None, condition=None))]
998 pub fn new(
999 id: String,
1000 assertion: TraceAssertion,
1001 expected_value: &Bound<'_, PyAny>,
1002 operator: ComparisonOperator,
1003 description: Option<String>,
1004 depends_on: Option<Vec<String>>,
1005 condition: Option<bool>,
1006 ) -> Result<Self, TypeError> {
1007 let expected_value = depythonize(expected_value)?;
1008
1009 Ok(Self {
1010 id: id.to_lowercase(),
1011 assertion,
1012 operator,
1013 expected_value,
1014 description,
1015 task_type: EvaluationTaskType::TraceAssertion,
1016 depends_on: depends_on.unwrap_or_default(),
1017 result: None,
1018 condition: condition.unwrap_or(false),
1019 })
1020 }
1021
1022 pub fn __str__(&self) -> String {
1023 PyHelperFuncs::__str__(self)
1025 }
1026
1027 #[getter]
1028 pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1029 let py_value = pythonize(py, &self.expected_value)?;
1030 Ok(py_value)
1031 }
1032
1033 #[staticmethod]
1034 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1035 deserialize_from_path(path)
1036 }
1037}
1038
1039impl TaskAccessor for TraceAssertionTask {
1040 fn context_path(&self) -> Option<&str> {
1041 None
1042 }
1043
1044 fn item_context_path(&self) -> Option<&str> {
1045 None
1046 }
1047
1048 fn id(&self) -> &str {
1049 &self.id
1050 }
1051
1052 fn operator(&self) -> &ComparisonOperator {
1053 &self.operator
1054 }
1055
1056 fn task_type(&self) -> &EvaluationTaskType {
1057 &self.task_type
1058 }
1059
1060 fn expected_value(&self) -> &Value {
1061 &self.expected_value
1062 }
1063
1064 fn depends_on(&self) -> &[String] {
1065 &self.depends_on
1066 }
1067
1068 fn add_result(&mut self, result: AssertionResult) {
1069 self.result = Some(result);
1070 }
1071}
1072
/// Record of a single tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Tool name.
    pub name: String,
    /// Arguments passed to the tool, as JSON.
    pub arguments: Value,
    /// Tool result as JSON, if recorded.
    pub result: Option<Value>,
    /// Call identifier, if available.
    pub call_id: Option<String>,
}
1080
/// Token counts; each one is optional because a source may not report it.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TokenUsage {
    /// Input token count, if known.
    #[pyo3(get, set)]
    pub input_tokens: Option<i64>,

    /// Output token count, if known.
    #[pyo3(get, set)]
    pub output_tokens: Option<i64>,

    /// Total token count, if known (stored as given, not derived from the other two).
    #[pyo3(get, set)]
    pub total_tokens: Option<i64>,
}
1093
1094#[pymethods]
1095impl TokenUsage {
1096 #[new]
1097 #[pyo3(signature = (input_tokens=None, output_tokens=None, total_tokens=None))]
1098 pub fn new(
1099 input_tokens: Option<i64>,
1100 output_tokens: Option<i64>,
1101 total_tokens: Option<i64>,
1102 ) -> Self {
1103 Self {
1104 input_tokens,
1105 output_tokens,
1106 total_tokens,
1107 }
1108 }
1109
1110 pub fn __str__(&self) -> String {
1111 serde_json::to_string_pretty(self).unwrap_or_default()
1112 }
1113}
1114
/// Which property of an agent run an `AgentAssertionTask` checks.
///
/// Only the selector data lives here. Evaluation is implemented elsewhere (TODO confirm
/// exact semantics there). Unit-like variants are written as empty struct variants (`{}`).
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AgentAssertion {
    /// Tool `name` (called).
    ToolCalled { name: String },

    /// Tool `name` (not called).
    ToolNotCalled { name: String },

    /// Tool `name` together with JSON `arguments`.
    ToolCalledWithArgs {
        name: String,
        arguments: PyValueWrapper,
    },

    /// Ordered list of tool names.
    ToolCallSequence { names: Vec<String> },

    /// Tool-call count, optionally restricted to tool `name`.
    ToolCallCount { name: Option<String> },

    /// Argument `argument_key` of tool `name`.
    ToolArgument { name: String, argument_key: String },

    /// Result of tool `name`.
    ToolResult { name: String },

    /// Response content.
    ResponseContent {},

    /// Response model.
    ResponseModel {},

    /// Response finish reason.
    ResponseFinishReason {},

    /// Response input token count.
    ResponseInputTokens {},

    /// Response output token count.
    ResponseOutputTokens {},

    /// Response total token count.
    ResponseTotalTokens {},

    /// Arbitrary response field addressed by `path`.
    ResponseField { path: String },
}
1163
1164impl Display for AgentAssertion {
1165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1166 let s = serde_json::to_string(self).unwrap_or_default();
1167 write!(f, "{}", s)
1168 }
1169}
1170
1171#[pymethods]
1172impl AgentAssertion {
1173 #[staticmethod]
1174 pub fn tool_called(name: &str) -> Self {
1175 AgentAssertion::ToolCalled {
1176 name: name.to_string(),
1177 }
1178 }
1179
1180 #[staticmethod]
1181 pub fn tool_not_called(name: &str) -> Self {
1182 AgentAssertion::ToolNotCalled {
1183 name: name.to_string(),
1184 }
1185 }
1186
1187 #[staticmethod]
1188 pub fn tool_called_with_args(
1189 name: &str,
1190 arguments: &Bound<'_, PyAny>,
1191 ) -> Result<Self, TypeError> {
1192 let arguments: Value = depythonize(arguments)?;
1193 Ok(AgentAssertion::ToolCalledWithArgs {
1194 name: name.to_string(),
1195 arguments: PyValueWrapper(arguments),
1196 })
1197 }
1198
1199 #[staticmethod]
1200 pub fn tool_call_sequence(names: Vec<String>) -> Self {
1201 AgentAssertion::ToolCallSequence { names }
1202 }
1203
1204 #[staticmethod]
1205 #[pyo3(signature = (name=None))]
1206 pub fn tool_call_count(name: Option<String>) -> Self {
1207 AgentAssertion::ToolCallCount { name }
1208 }
1209
1210 #[staticmethod]
1211 pub fn tool_argument(name: &str, argument_key: &str) -> Self {
1212 AgentAssertion::ToolArgument {
1213 name: name.to_string(),
1214 argument_key: argument_key.to_string(),
1215 }
1216 }
1217
1218 #[staticmethod]
1219 pub fn tool_result(name: &str) -> Self {
1220 AgentAssertion::ToolResult {
1221 name: name.to_string(),
1222 }
1223 }
1224
1225 #[staticmethod]
1226 pub fn response_content() -> Self {
1227 AgentAssertion::ResponseContent {}
1228 }
1229
1230 #[staticmethod]
1231 pub fn response_model() -> Self {
1232 AgentAssertion::ResponseModel {}
1233 }
1234
1235 #[staticmethod]
1236 pub fn response_finish_reason() -> Self {
1237 AgentAssertion::ResponseFinishReason {}
1238 }
1239
1240 #[staticmethod]
1241 pub fn response_input_tokens() -> Self {
1242 AgentAssertion::ResponseInputTokens {}
1243 }
1244
1245 #[staticmethod]
1246 pub fn response_output_tokens() -> Self {
1247 AgentAssertion::ResponseOutputTokens {}
1248 }
1249
1250 #[staticmethod]
1251 pub fn response_total_tokens() -> Self {
1252 AgentAssertion::ResponseTotalTokens {}
1253 }
1254
1255 #[staticmethod]
1256 pub fn response_field(path: &str) -> Self {
1257 AgentAssertion::ResponseField {
1258 path: path.to_string(),
1259 }
1260 }
1261
1262 pub fn __str__(&self) -> String {
1263 serde_json::to_string_pretty(self).unwrap_or_default()
1264 }
1265}
1266
/// A task that applies an [`AgentAssertion`] and compares the result using `operator` and
/// `expected_value`. When deserialized, `task_type` defaults to `AgentAssertion`.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AgentAssertionTask {
    /// Task identifier (lower-cased by the Python constructor).
    #[pyo3(get, set)]
    pub id: String,

    /// Agent property selector.
    #[pyo3(get, set)]
    pub assertion: AgentAssertion,

    /// Comparison operator.
    #[pyo3(get, set)]
    pub operator: ComparisonOperator,

    /// Expected value, as JSON (exposed to Python through an explicit getter).
    pub expected_value: Value,

    /// Optional human-readable description.
    #[pyo3(get, set)]
    #[serde(default)]
    pub description: Option<String>,

    /// Ids of tasks this task depends on.
    #[pyo3(get, set)]
    #[serde(default)]
    pub depends_on: Vec<String>,

    /// Task kind; always `AgentAssertion` for this type.
    #[serde(default = "default_agent_assertion_task_type")]
    #[pyo3(get)]
    pub task_type: EvaluationTaskType,

    /// Result attached via `TaskAccessor::add_result`; omitted from output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<AssertionResult>,

    /// Condition flag (defaults to `false`).
    #[pyo3(get, set)]
    #[serde(default)]
    pub condition: bool,
}
1300
1301#[pymethods]
1302impl AgentAssertionTask {
1303 #[new]
1304 #[pyo3(signature = (id, assertion, expected_value, operator, description=None, depends_on=None, condition=None))]
1305 pub fn new(
1306 id: String,
1307 assertion: AgentAssertion,
1308 expected_value: &Bound<'_, PyAny>,
1309 operator: ComparisonOperator,
1310 description: Option<String>,
1311 depends_on: Option<Vec<String>>,
1312 condition: Option<bool>,
1313 ) -> Result<Self, TypeError> {
1314 let expected_value = depythonize(expected_value)?;
1315
1316 Ok(Self {
1317 id: id.to_lowercase(),
1318 assertion,
1319 operator,
1320 expected_value,
1321 description,
1322 task_type: EvaluationTaskType::AgentAssertion,
1323 depends_on: depends_on.unwrap_or_default(),
1324 result: None,
1325 condition: condition.unwrap_or(false),
1326 })
1327 }
1328
1329 pub fn __str__(&self) -> String {
1330 PyHelperFuncs::__str__(self)
1331 }
1332
1333 pub fn model_dump_json(&self) -> String {
1334 serde_json::to_string(self).unwrap_or_default()
1335 }
1336
1337 #[getter]
1338 pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1339 let py_value = pythonize(py, &self.expected_value)?;
1340 Ok(py_value)
1341 }
1342
1343 #[getter]
1344 pub fn get_result<'py>(&self, py: Python<'py>) -> Result<Option<Bound<'py, PyAny>>, TypeError> {
1345 match &self.result {
1346 Some(result) => {
1347 let py_value = pythonize(py, result)?;
1348 Ok(Some(py_value))
1349 }
1350 None => Ok(None),
1351 }
1352 }
1353
1354 #[staticmethod]
1355 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1356 deserialize_from_path(path)
1357 }
1358}
1359
1360impl TaskAccessor for AgentAssertionTask {
1361 fn context_path(&self) -> Option<&str> {
1362 None
1363 }
1364
1365 fn item_context_path(&self) -> Option<&str> {
1366 None
1367 }
1368
1369 fn id(&self) -> &str {
1370 &self.id
1371 }
1372
1373 fn operator(&self) -> &ComparisonOperator {
1374 &self.operator
1375 }
1376
1377 fn task_type(&self) -> &EvaluationTaskType {
1378 &self.task_type
1379 }
1380
1381 fn expected_value(&self) -> &Value {
1382 &self.expected_value
1383 }
1384
1385 fn depends_on(&self) -> &[String] {
1386 &self.depends_on
1387 }
1388
1389 fn add_result(&mut self, result: AssertionResult) {
1390 self.result = Some(result);
1391 }
1392}
1393
/// Any evaluation task, boxed so the enum stays small regardless of variant size.
#[derive(Debug, Clone)]
pub enum EvaluationTask {
    Assertion(Box<AssertionTask>),
    LLMJudge(Box<LLMJudgeTask>),
    TraceAssertion(Box<TraceAssertionTask>),
    AgentAssertion(Box<AgentAssertionTask>),
}
1401
1402impl TaskAccessor for EvaluationTask {
1403 fn context_path(&self) -> Option<&str> {
1404 match self {
1405 EvaluationTask::Assertion(t) => t.context_path(),
1406 EvaluationTask::LLMJudge(t) => t.context_path(),
1407 EvaluationTask::TraceAssertion(t) => t.context_path(),
1408 EvaluationTask::AgentAssertion(t) => t.context_path(),
1409 }
1410 }
1411
1412 fn item_context_path(&self) -> Option<&str> {
1413 match self {
1414 EvaluationTask::Assertion(t) => t.item_context_path(),
1415 EvaluationTask::LLMJudge(t) => t.item_context_path(),
1416 EvaluationTask::TraceAssertion(t) => t.item_context_path(),
1417 EvaluationTask::AgentAssertion(t) => t.item_context_path(),
1418 }
1419 }
1420
1421 fn id(&self) -> &str {
1422 match self {
1423 EvaluationTask::Assertion(t) => t.id(),
1424 EvaluationTask::LLMJudge(t) => t.id(),
1425 EvaluationTask::TraceAssertion(t) => t.id(),
1426 EvaluationTask::AgentAssertion(t) => t.id(),
1427 }
1428 }
1429
1430 fn task_type(&self) -> &EvaluationTaskType {
1431 match self {
1432 EvaluationTask::Assertion(t) => t.task_type(),
1433 EvaluationTask::LLMJudge(t) => t.task_type(),
1434 EvaluationTask::TraceAssertion(t) => t.task_type(),
1435 EvaluationTask::AgentAssertion(t) => t.task_type(),
1436 }
1437 }
1438
1439 fn operator(&self) -> &ComparisonOperator {
1440 match self {
1441 EvaluationTask::Assertion(t) => t.operator(),
1442 EvaluationTask::LLMJudge(t) => t.operator(),
1443 EvaluationTask::TraceAssertion(t) => t.operator(),
1444 EvaluationTask::AgentAssertion(t) => t.operator(),
1445 }
1446 }
1447
1448 fn expected_value(&self) -> &Value {
1449 match self {
1450 EvaluationTask::Assertion(t) => t.expected_value(),
1451 EvaluationTask::LLMJudge(t) => t.expected_value(),
1452 EvaluationTask::TraceAssertion(t) => t.expected_value(),
1453 EvaluationTask::AgentAssertion(t) => t.expected_value(),
1454 }
1455 }
1456
1457 fn depends_on(&self) -> &[String] {
1458 match self {
1459 EvaluationTask::Assertion(t) => t.depends_on(),
1460 EvaluationTask::LLMJudge(t) => t.depends_on(),
1461 EvaluationTask::TraceAssertion(t) => t.depends_on(),
1462 EvaluationTask::AgentAssertion(t) => t.depends_on(),
1463 }
1464 }
1465
1466 fn add_result(&mut self, result: AssertionResult) {
1467 match self {
1468 EvaluationTask::Assertion(t) => t.add_result(result),
1469 EvaluationTask::LLMJudge(t) => t.add_result(result),
1470 EvaluationTask::TraceAssertion(t) => t.add_result(result),
1471 EvaluationTask::AgentAssertion(t) => t.add_result(result),
1472 }
1473 }
1474}
1475
1476pub struct EvaluationTasks(Vec<EvaluationTask>);
1477
1478impl EvaluationTasks {
1479 pub fn new() -> Self {
1481 Self(Vec::new())
1482 }
1483
1484 pub fn add_task(mut self, task: impl Into<EvaluationTask>) -> Self {
1486 self.0.push(task.into());
1487 self
1488 }
1489
1490 pub fn build(self) -> Vec<EvaluationTask> {
1492 self.0
1493 }
1494}
1495
1496impl From<AssertionTask> for EvaluationTask {
1497 fn from(task: AssertionTask) -> Self {
1498 EvaluationTask::Assertion(Box::new(task))
1499 }
1500}
1501
1502impl From<LLMJudgeTask> for EvaluationTask {
1503 fn from(task: LLMJudgeTask) -> Self {
1504 EvaluationTask::LLMJudge(Box::new(task))
1505 }
1506}
1507
1508impl From<TraceAssertionTask> for EvaluationTask {
1509 fn from(task: TraceAssertionTask) -> Self {
1510 EvaluationTask::TraceAssertion(Box::new(task))
1511 }
1512}
1513
1514impl From<AgentAssertionTask> for EvaluationTask {
1515 fn from(task: AgentAssertionTask) -> Self {
1516 EvaluationTask::AgentAssertion(Box::new(task))
1517 }
1518}
1519
1520impl Default for EvaluationTasks {
1521 fn default() -> Self {
1522 Self::new()
1523 }
1524}
1525
/// Operator an assertion uses to compare an actual value with an expected one.
///
/// The textual form of every variant is its exact Rust name (e.g.
/// `"GreaterThanOrEqual"`), both when parsing via `FromStr` and when rendering
/// via `as_str` / `Display`.
///
/// NOTE: `eq_int` lets Python compare variants with integers by discriminant, so
/// the declaration order below is observable from Python. Append new variants at
/// the end instead of reordering.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ComparisonOperator {
    // Comparisons of the actual value against the expected value.
    Equals,
    NotEqual,
    GreaterThan,
    GreaterThanOrEqual,
    LessThan,
    LessThanOrEqual,
    Contains,
    NotContains,
    StartsWith,
    EndsWith,
    Matches,
    // For the HasLength* operators, `AssertionValue::to_actual` first replaces a
    // list by its element count and a string by its char count.
    HasLengthGreaterThan,
    HasLengthLessThan,
    HasLengthEqual,
    HasLengthGreaterThanOrEqual,
    HasLengthLessThanOrEqual,

    // Value-type predicates.
    IsNumeric,
    IsString,
    IsBoolean,
    IsNull,
    IsArray,
    IsObject,

    // String-format predicates.
    IsEmail,
    IsUrl,
    IsUuid,
    IsIso8601,
    IsJson,
    MatchesRegex,

    // Numeric range and sign predicates.
    InRange,
    NotInRange,
    IsPositive,
    IsNegative,
    IsZero,

    // Collection predicates.
    SequenceMatches,
    ContainsAll,
    ContainsAny,
    ContainsNone,
    IsEmpty,
    IsNotEmpty,
    HasUniqueItems,

    // String-content predicates.
    IsAlphabetic,
    IsAlphanumeric,
    IsLowerCase,
    IsUpperCase,
    ContainsWord,

    // Approximate equality (tolerance semantics live in the evaluator — confirm there).
    ApproximatelyEquals,
}
1589
1590impl Display for ComparisonOperator {
1591 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1592 write!(f, "{}", self.as_str())
1593 }
1594}
1595
1596impl FromStr for ComparisonOperator {
1597 type Err = TypeError;
1598
1599 fn from_str(s: &str) -> Result<Self, Self::Err> {
1600 match s {
1601 "Equals" => Ok(ComparisonOperator::Equals),
1602 "NotEqual" => Ok(ComparisonOperator::NotEqual),
1603 "GreaterThan" => Ok(ComparisonOperator::GreaterThan),
1604 "GreaterThanOrEqual" => Ok(ComparisonOperator::GreaterThanOrEqual),
1605 "LessThan" => Ok(ComparisonOperator::LessThan),
1606 "LessThanOrEqual" => Ok(ComparisonOperator::LessThanOrEqual),
1607 "Contains" => Ok(ComparisonOperator::Contains),
1608 "NotContains" => Ok(ComparisonOperator::NotContains),
1609 "StartsWith" => Ok(ComparisonOperator::StartsWith),
1610 "EndsWith" => Ok(ComparisonOperator::EndsWith),
1611 "Matches" => Ok(ComparisonOperator::Matches),
1612 "HasLengthEqual" => Ok(ComparisonOperator::HasLengthEqual),
1613 "HasLengthGreaterThan" => Ok(ComparisonOperator::HasLengthGreaterThan),
1614 "HasLengthLessThan" => Ok(ComparisonOperator::HasLengthLessThan),
1615 "HasLengthGreaterThanOrEqual" => Ok(ComparisonOperator::HasLengthGreaterThanOrEqual),
1616 "HasLengthLessThanOrEqual" => Ok(ComparisonOperator::HasLengthLessThanOrEqual),
1617
1618 "IsNumeric" => Ok(ComparisonOperator::IsNumeric),
1620 "IsString" => Ok(ComparisonOperator::IsString),
1621 "IsBoolean" => Ok(ComparisonOperator::IsBoolean),
1622 "IsNull" => Ok(ComparisonOperator::IsNull),
1623 "IsArray" => Ok(ComparisonOperator::IsArray),
1624 "IsObject" => Ok(ComparisonOperator::IsObject),
1625
1626 "IsEmail" => Ok(ComparisonOperator::IsEmail),
1628 "IsUrl" => Ok(ComparisonOperator::IsUrl),
1629 "IsUuid" => Ok(ComparisonOperator::IsUuid),
1630 "IsIso8601" => Ok(ComparisonOperator::IsIso8601),
1631 "IsJson" => Ok(ComparisonOperator::IsJson),
1632 "MatchesRegex" => Ok(ComparisonOperator::MatchesRegex),
1633
1634 "InRange" => Ok(ComparisonOperator::InRange),
1636 "NotInRange" => Ok(ComparisonOperator::NotInRange),
1637 "IsPositive" => Ok(ComparisonOperator::IsPositive),
1638 "IsNegative" => Ok(ComparisonOperator::IsNegative),
1639 "IsZero" => Ok(ComparisonOperator::IsZero),
1640
1641 "ContainsAll" => Ok(ComparisonOperator::ContainsAll),
1643 "ContainsAny" => Ok(ComparisonOperator::ContainsAny),
1644 "ContainsNone" => Ok(ComparisonOperator::ContainsNone),
1645 "IsEmpty" => Ok(ComparisonOperator::IsEmpty),
1646 "IsNotEmpty" => Ok(ComparisonOperator::IsNotEmpty),
1647 "HasUniqueItems" => Ok(ComparisonOperator::HasUniqueItems),
1648 "SequenceMatches" => Ok(ComparisonOperator::SequenceMatches),
1649
1650 "IsAlphabetic" => Ok(ComparisonOperator::IsAlphabetic),
1652 "IsAlphanumeric" => Ok(ComparisonOperator::IsAlphanumeric),
1653 "IsLowerCase" => Ok(ComparisonOperator::IsLowerCase),
1654 "IsUpperCase" => Ok(ComparisonOperator::IsUpperCase),
1655 "ContainsWord" => Ok(ComparisonOperator::ContainsWord),
1656
1657 "ApproximatelyEquals" => Ok(ComparisonOperator::ApproximatelyEquals),
1659
1660 _ => Err(TypeError::InvalidCompressionTypeError),
1661 }
1662 }
1663}
1664
1665impl ComparisonOperator {
1666 pub fn as_str(&self) -> &str {
1667 match self {
1668 ComparisonOperator::Equals => "Equals",
1669 ComparisonOperator::NotEqual => "NotEqual",
1670 ComparisonOperator::GreaterThan => "GreaterThan",
1671 ComparisonOperator::GreaterThanOrEqual => "GreaterThanOrEqual",
1672 ComparisonOperator::LessThan => "LessThan",
1673 ComparisonOperator::LessThanOrEqual => "LessThanOrEqual",
1674 ComparisonOperator::Contains => "Contains",
1675 ComparisonOperator::NotContains => "NotContains",
1676 ComparisonOperator::StartsWith => "StartsWith",
1677 ComparisonOperator::EndsWith => "EndsWith",
1678 ComparisonOperator::Matches => "Matches",
1679 ComparisonOperator::HasLengthEqual => "HasLengthEqual",
1680 ComparisonOperator::HasLengthGreaterThan => "HasLengthGreaterThan",
1681 ComparisonOperator::HasLengthLessThan => "HasLengthLessThan",
1682 ComparisonOperator::HasLengthGreaterThanOrEqual => "HasLengthGreaterThanOrEqual",
1683 ComparisonOperator::HasLengthLessThanOrEqual => "HasLengthLessThanOrEqual",
1684
1685 ComparisonOperator::IsNumeric => "IsNumeric",
1687 ComparisonOperator::IsString => "IsString",
1688 ComparisonOperator::IsBoolean => "IsBoolean",
1689 ComparisonOperator::IsNull => "IsNull",
1690 ComparisonOperator::IsArray => "IsArray",
1691 ComparisonOperator::IsObject => "IsObject",
1692
1693 ComparisonOperator::IsEmail => "IsEmail",
1695 ComparisonOperator::IsUrl => "IsUrl",
1696 ComparisonOperator::IsUuid => "IsUuid",
1697 ComparisonOperator::IsIso8601 => "IsIso8601",
1698 ComparisonOperator::IsJson => "IsJson",
1699 ComparisonOperator::MatchesRegex => "MatchesRegex",
1700
1701 ComparisonOperator::InRange => "InRange",
1703 ComparisonOperator::NotInRange => "NotInRange",
1704 ComparisonOperator::IsPositive => "IsPositive",
1705 ComparisonOperator::IsNegative => "IsNegative",
1706 ComparisonOperator::IsZero => "IsZero",
1707
1708 ComparisonOperator::ContainsAll => "ContainsAll",
1710 ComparisonOperator::ContainsAny => "ContainsAny",
1711 ComparisonOperator::ContainsNone => "ContainsNone",
1712 ComparisonOperator::IsEmpty => "IsEmpty",
1713 ComparisonOperator::IsNotEmpty => "IsNotEmpty",
1714 ComparisonOperator::HasUniqueItems => "HasUniqueItems",
1715 ComparisonOperator::SequenceMatches => "SequenceMatches",
1716
1717 ComparisonOperator::IsAlphabetic => "IsAlphabetic",
1719 ComparisonOperator::IsAlphanumeric => "IsAlphanumeric",
1720 ComparisonOperator::IsLowerCase => "IsLowerCase",
1721 ComparisonOperator::IsUpperCase => "IsUpperCase",
1722 ComparisonOperator::ContainsWord => "ContainsWord",
1723
1724 ComparisonOperator::ApproximatelyEquals => "ApproximatelyEquals",
1726 }
1727 }
1728}
1729
/// A plain value used in assertions: a scalar, a (possibly nested) list, or null.
///
/// `Null()` is an empty tuple variant rather than a unit variant. This is
/// presumably because pyo3 complex enums cannot mix unit and tuple variants;
/// verify before changing it.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AssertionValue {
    /// A string value.
    String(String),
    /// A floating-point number (Python `float`).
    Number(f64),
    /// A signed 64-bit integer (Python `int`).
    Integer(i64),
    /// A boolean.
    Boolean(bool),
    /// A list of values, possibly nested.
    List(Vec<AssertionValue>),
    /// Null (Python `None`).
    Null(),
}
1740
1741impl AssertionValue {
1742 pub fn to_actual(self, comparison: &ComparisonOperator) -> AssertionValue {
1743 match comparison {
1744 ComparisonOperator::HasLengthEqual
1745 | ComparisonOperator::HasLengthGreaterThan
1746 | ComparisonOperator::HasLengthLessThan
1747 | ComparisonOperator::HasLengthGreaterThanOrEqual
1748 | ComparisonOperator::HasLengthLessThanOrEqual => match self {
1749 AssertionValue::List(arr) => AssertionValue::Integer(arr.len() as i64),
1750 AssertionValue::String(s) => AssertionValue::Integer(s.chars().count() as i64),
1751 _ => self,
1752 },
1753 _ => self,
1754 }
1755 }
1756
1757 pub fn to_serde_value(&self) -> Value {
1758 match self {
1759 AssertionValue::String(s) => Value::String(s.clone()),
1760 AssertionValue::Number(n) => Value::Number(serde_json::Number::from_f64(*n).unwrap()),
1761 AssertionValue::Integer(i) => Value::Number(serde_json::Number::from(*i)),
1762 AssertionValue::Boolean(b) => Value::Bool(*b),
1763 AssertionValue::List(arr) => {
1764 let json_arr: Vec<Value> = arr.iter().map(|v| v.to_serde_value()).collect();
1765 Value::Array(json_arr)
1766 }
1767 AssertionValue::Null() => Value::Null,
1768 }
1769 }
1770}
1771pub fn assertion_value_from_py(value: &Bound<'_, PyAny>) -> Result<AssertionValue, TypeError> {
1778 if value.is_none() {
1780 return Ok(AssertionValue::Null());
1781 }
1782
1783 if value.is_instance_of::<PyBool>() {
1785 return Ok(AssertionValue::Boolean(value.extract()?));
1786 }
1787
1788 if value.is_instance_of::<PyString>() {
1789 return Ok(AssertionValue::String(value.extract()?));
1790 }
1791
1792 if value.is_instance_of::<PyInt>() {
1793 return Ok(AssertionValue::Integer(value.extract()?));
1794 }
1795
1796 if value.is_instance_of::<PyFloat>() {
1797 return Ok(AssertionValue::Number(value.extract()?));
1798 }
1799
1800 if value.is_instance_of::<PyList>() {
1801 let list = value.cast::<PyList>()?; let assertion_list = list
1804 .iter()
1805 .map(|item| assertion_value_from_py(&item))
1806 .collect::<Result<Vec<_>, _>>()?;
1807 return Ok(AssertionValue::List(assertion_list));
1808 }
1809
1810 Err(TypeError::UnsupportedType(
1812 value.get_type().name()?.to_string(),
1813 ))
1814}
1815
/// Kind of evaluation task.
///
/// When a `TasksFile` is deserialized, tasks are built only for `Assertion`,
/// `LLMJudge`, `TraceAssertion`, and `AgentAssertion`. `Conditional` and
/// `HumanValidation` are rejected there.
///
/// NOTE: `eq_int` lets Python compare variants with integers by discriminant, so
/// the declaration order is observable from Python. Append new variants at the
/// end instead of reordering.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum EvaluationTaskType {
    Assertion,
    LLMJudge,
    Conditional,
    HumanValidation,
    TraceAssertion,
    AgentAssertion,
}
1826
1827impl Display for EvaluationTaskType {
1828 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1829 let task_type_str = match self {
1830 EvaluationTaskType::Assertion => "Assertion",
1831 EvaluationTaskType::LLMJudge => "LLMJudge",
1832 EvaluationTaskType::Conditional => "Conditional",
1833 EvaluationTaskType::HumanValidation => "HumanValidation",
1834 EvaluationTaskType::TraceAssertion => "TraceAssertion",
1835 EvaluationTaskType::AgentAssertion => "AgentAssertion",
1836 };
1837 write!(f, "{}", task_type_str)
1838 }
1839}
1840
1841impl FromStr for EvaluationTaskType {
1842 type Err = TypeError;
1843
1844 fn from_str(s: &str) -> Result<Self, Self::Err> {
1845 match s {
1846 "Assertion" => Ok(EvaluationTaskType::Assertion),
1847 "LLMJudge" => Ok(EvaluationTaskType::LLMJudge),
1848 "Conditional" => Ok(EvaluationTaskType::Conditional),
1849 "HumanValidation" => Ok(EvaluationTaskType::HumanValidation),
1850 "TraceAssertion" => Ok(EvaluationTaskType::TraceAssertion),
1851 "AgentAssertion" => Ok(EvaluationTaskType::AgentAssertion),
1852 _ => Err(TypeError::InvalidEvalType(s.to_string())),
1853 }
1854 }
1855}
1856
1857impl EvaluationTaskType {
1858 pub fn as_str(&self) -> &str {
1859 match self {
1860 EvaluationTaskType::Assertion => "Assertion",
1861 EvaluationTaskType::LLMJudge => "LLMJudge",
1862 EvaluationTaskType::Conditional => "Conditional",
1863 EvaluationTaskType::HumanValidation => "HumanValidation",
1864 EvaluationTaskType::TraceAssertion => "TraceAssertion",
1865 EvaluationTaskType::AgentAssertion => "AgentAssertion",
1866 }
1867 }
1868}
1869
/// Collection of evaluation tasks loaded from a JSON or YAML file.
///
/// From Python it behaves as a sequence (`len`, integer indexing, slicing) and
/// can also be iterated.
#[pyclass]
#[derive(Debug, Serialize)]
pub struct TasksFile {
    /// Parsed tasks, in the order they appear in the file.
    pub tasks: Vec<TaskConfig>,

    /// Cursor for the Python iterator protocol (`__next__`).
    // NOTE(review): the struct derives `Serialize`, so this cursor is serialized
    // too and presumably appears in `__str__` output (`PyHelperFuncs::__str__`
    // — verify). Consider `#[serde(skip)]` if that is unintended.
    #[serde(default)]
    index: usize,
}
1878
1879#[pymethods]
1880impl TasksFile {
1881 #[staticmethod]
1882 pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
1883 let tasks_file: TasksFile = deserialize_from_path(path)?;
1884 Ok(tasks_file)
1885 }
1886
1887 pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
1888 slf
1889 }
1890
1891 pub fn __next__<'py>(
1892 mut slf: PyRefMut<'py, Self>,
1893 ) -> Result<Option<Bound<'py, PyAny>>, TypeError> {
1894 let py = slf.py();
1895 if slf.index < slf.tasks.len() {
1896 let task = slf.tasks[slf.index].clone().into_bound_py_any(py)?;
1897 slf.index += 1;
1898 Ok(Some(task))
1899 } else {
1900 Ok(None)
1901 }
1902 }
1903
1904 fn __getitem__<'py>(
1905 &self,
1906 py: Python<'py>,
1907 index: &Bound<'py, PyAny>,
1908 ) -> Result<Bound<'py, PyAny>, TypeError> {
1909 if let Ok(i) = index.extract::<isize>() {
1910 let len = self.tasks.len() as isize;
1911 let actual_index = if i < 0 { len + i } else { i };
1912
1913 if actual_index < 0 || actual_index >= len {
1914 return Err(TypeError::IndexOutOfBounds {
1915 index: i,
1916 length: self.tasks.len(),
1917 });
1918 }
1919
1920 Ok(self.tasks[actual_index as usize]
1921 .clone()
1922 .into_bound_py_any(py)?)
1923 } else if let Ok(slice) = index.cast::<PySlice>() {
1924 let indices = slice.indices(self.tasks.len() as isize)?;
1925 let result = PyList::empty(py);
1926
1927 let mut i = indices.start;
1928 while (indices.step > 0 && i < indices.stop) || (indices.step < 0 && i > indices.stop) {
1929 result.append(self.tasks[i as usize].clone().into_bound_py_any(py)?)?;
1930 i += indices.step;
1931 }
1932
1933 Ok(result.into_bound_py_any(py)?)
1934 } else {
1935 Err(TypeError::IndexOrSliceExpected)
1936 }
1937 }
1938
1939 fn __len__(&self) -> usize {
1940 self.tasks.len()
1941 }
1942
1943 fn __str__(&self) -> String {
1944 PyHelperFuncs::__str__(self)
1945 }
1946}
1947
/// One task loaded from a [`TasksFile`], tagged by its concrete task type.
///
/// There is a variant only for the task types a `TasksFile` can load.
// NOTE(review): `#[serde(rename = "LLMJudge")]` equals the variant's own name,
// so it looks like a no-op; confirm before removing. `LLMJudgeTask` is boxed,
// presumably because it is much larger than the other task structs.
#[derive(Debug, Serialize, Clone)]
pub enum TaskConfig {
    Assertion(AssertionTask),
    #[serde(rename = "LLMJudge")]
    LLMJudge(Box<LLMJudgeTask>),
    TraceAssertion(TraceAssertionTask),
    AgentAssertion(AgentAssertionTask),
}
1956
1957impl TaskConfig {
1958 fn into_bound_py_any<'py>(self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1959 match self {
1960 TaskConfig::Assertion(task) => Ok(task.into_bound_py_any(py)?),
1961 TaskConfig::LLMJudge(task) => Ok(task.into_bound_py_any(py)?),
1962 TaskConfig::TraceAssertion(task) => Ok(task.into_bound_py_any(py)?),
1963 TaskConfig::AgentAssertion(task) => Ok(task.into_bound_py_any(py)?),
1964 }
1965 }
1966}
1967
1968impl<'de> Deserialize<'de> for TasksFile {
1969 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1970 where
1971 D: serde::Deserializer<'de>,
1972 {
1973 #[derive(Deserialize)]
1974 #[serde(untagged)]
1975 enum TasksFileRaw {
1976 Direct(Vec<TaskConfigRaw>),
1977 Wrapped { tasks: Vec<TaskConfigRaw> },
1978 }
1979
1980 #[derive(Deserialize)]
1981 struct TaskConfigRaw {
1982 task_type: EvaluationTaskType,
1983 #[serde(flatten)]
1984 data: Value,
1985 }
1986
1987 let raw = TasksFileRaw::deserialize(deserializer)?;
1988 let raw_tasks = match raw {
1989 TasksFileRaw::Direct(tasks) => tasks,
1990 TasksFileRaw::Wrapped { tasks } => tasks,
1991 };
1992
1993 let mut tasks = Vec::new();
1994
1995 for task_raw in raw_tasks {
1996 let task_config = match task_raw.task_type {
1997 EvaluationTaskType::Assertion => {
1998 let mut task: AssertionTask =
1999 serde_json::from_value(task_raw.data).map_err(|e| {
2000 error!("Failed to deserialize AssertionTask: {}", e);
2001 serde::de::Error::custom(e.to_string())
2002 })?;
2003 task.task_type = EvaluationTaskType::Assertion;
2004 TaskConfig::Assertion(task)
2005 }
2006 EvaluationTaskType::LLMJudge => {
2007 let mut task: LLMJudgeTask =
2008 serde_json::from_value(task_raw.data).map_err(|e| {
2009 error!("Failed to deserialize LLMJudgeTask: {}", e);
2010 serde::de::Error::custom(e.to_string())
2011 })?;
2012 task.task_type = EvaluationTaskType::LLMJudge;
2013 TaskConfig::LLMJudge(Box::new(task))
2014 }
2015 EvaluationTaskType::TraceAssertion => {
2016 let mut task: TraceAssertionTask = serde_json::from_value(task_raw.data)
2017 .map_err(|e| {
2018 error!("Failed to deserialize TraceAssertionTask: {}", e);
2019 serde::de::Error::custom(e.to_string())
2020 })?;
2021 task.task_type = EvaluationTaskType::TraceAssertion;
2022 TaskConfig::TraceAssertion(task)
2023 }
2024 EvaluationTaskType::AgentAssertion => {
2025 let mut task: AgentAssertionTask = serde_json::from_value(task_raw.data)
2026 .map_err(|e| {
2027 error!("Failed to deserialize AgentAssertionTask: {}", e);
2028 serde::de::Error::custom(e.to_string())
2029 })?;
2030 task.task_type = EvaluationTaskType::AgentAssertion;
2031 TaskConfig::AgentAssertion(task)
2032 }
2033 _ => {
2034 return Err(serde::de::Error::custom(format!(
2035 "Unknown task_type: {}",
2036 task_raw.task_type
2037 )))
2038 }
2039 };
2040 tasks.push(task_config);
2041 }
2042
2043 Ok(TasksFile { tasks, index: 0 })
2044 }
2045}