1use crate::error::TypeError;
2use crate::genai::traits::TaskAccessor;
3use crate::PyHelperFuncs;
4use core::fmt::Debug;
5use potato_head::prompt_types::Prompt;
6use pyo3::prelude::*;
7use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString};
8use pythonize::{depythonize, pythonize};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::fmt::Display;
12use std::str::FromStr;
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
15pub struct AssertionResult {
16 pub passed: bool,
17 pub actual: Value,
18 pub message: String,
19 pub expected: Value,
20}
21
22impl AssertionResult {
23 pub fn new(passed: bool, actual: Value, message: String, expected: Value) -> Self {
24 Self {
25 passed,
26 actual,
27 message,
28 expected,
29 }
30 }
31 pub fn to_metric_value(&self) -> f64 {
32 if self.passed {
33 1.0
34 } else {
35 0.0
36 }
37 }
38}
39
40#[pyclass]
41#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
42pub struct AssertionTask {
43 #[pyo3(get, set)]
44 pub id: String,
45
46 #[pyo3(get, set)]
47 pub field_path: Option<String>,
48
49 #[pyo3(get, set)]
50 pub operator: ComparisonOperator,
51
52 pub expected_value: Value,
53
54 #[pyo3(get, set)]
55 pub description: Option<String>,
56
57 #[pyo3(get, set)]
58 pub depends_on: Vec<String>,
59
60 pub task_type: EvaluationTaskType,
61
62 #[serde(skip_serializing_if = "Option::is_none")]
63 pub result: Option<AssertionResult>,
64
65 pub condition: bool,
66}
67
68#[pymethods]
69impl AssertionTask {
70 #[new]
71 #[pyo3(signature = (id, field_path, expected_value, operator, description=None, depends_on=None, condition=None))]
125 pub fn new(
126 id: String,
127 field_path: Option<String>,
128 expected_value: &Bound<'_, PyAny>,
129 operator: ComparisonOperator,
130 description: Option<String>,
131 depends_on: Option<Vec<String>>,
132 condition: Option<bool>,
133 ) -> Result<Self, TypeError> {
134 let expected_value = depythonize(expected_value)?;
135 let condition = condition.unwrap_or(false);
136
137 Ok(Self {
138 id: id.to_lowercase(),
139 field_path,
140 operator,
141 expected_value,
142 description,
143 task_type: EvaluationTaskType::Assertion,
144 depends_on: depends_on.unwrap_or_default(),
145 result: None,
146 condition,
147 })
148 }
149
150 #[getter]
151 pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
152 let py_value = pythonize(py, &self.expected_value)?;
153 Ok(py_value)
154 }
155}
156
157impl TaskAccessor for AssertionTask {
158 fn field_path(&self) -> Option<&str> {
159 self.field_path.as_deref()
160 }
161
162 fn id(&self) -> &str {
163 &self.id
164 }
165
166 fn operator(&self) -> &ComparisonOperator {
167 &self.operator
168 }
169
170 fn task_type(&self) -> &EvaluationTaskType {
171 &self.task_type
172 }
173
174 fn expected_value(&self) -> &Value {
175 &self.expected_value
176 }
177
178 fn depends_on(&self) -> &[String] {
179 &self.depends_on
180 }
181
182 fn add_result(&mut self, result: AssertionResult) {
183 self.result = Some(result);
184 }
185}
186
/// Convenience queries layered on top of a JSON value.
pub trait ValueExt {
    /// Length of the value, or `None` when the value has no notion of length.
    fn to_length(&self) -> Option<i64>;

    /// The value as an `f64`, or `None` when it is not numeric.
    fn as_numeric(&self) -> Option<f64>;

    /// Whether the value counts as "true" in a boolean context.
    fn is_truthy(&self) -> bool;
}
197
198impl ValueExt for Value {
199 fn to_length(&self) -> Option<i64> {
200 match self {
201 Value::Array(arr) => Some(arr.len() as i64),
202 Value::String(s) => Some(s.chars().count() as i64),
203 Value::Object(obj) => Some(obj.len() as i64),
204 _ => None,
205 }
206 }
207
208 fn as_numeric(&self) -> Option<f64> {
209 match self {
210 Value::Number(n) => n.as_f64(),
211 _ => None,
212 }
213 }
214
215 fn is_truthy(&self) -> bool {
216 match self {
217 Value::Null => false,
218 Value::Bool(b) => *b,
219 Value::Number(n) => n.as_f64() != Some(0.0),
220 Value::String(s) => !s.is_empty(),
221 Value::Array(arr) => !arr.is_empty(),
222 Value::Object(obj) => !obj.is_empty(),
223 }
224 }
225}
226
227#[pyclass]
229#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
230pub struct LLMJudgeTask {
231 #[pyo3(get, set)]
232 pub id: String,
233
234 #[pyo3(get)]
235 pub prompt: Prompt,
236
237 #[pyo3(get)]
238 pub field_path: Option<String>,
239
240 pub expected_value: Value,
241
242 #[pyo3(get)]
243 pub operator: ComparisonOperator,
244
245 pub task_type: EvaluationTaskType,
246
247 #[pyo3(get, set)]
248 pub depends_on: Vec<String>,
249
250 #[pyo3(get, set)]
251 pub max_retries: Option<u32>,
252
253 #[serde(skip_serializing_if = "Option::is_none")]
254 pub result: Option<AssertionResult>,
255
256 pub description: Option<String>,
257
258 pub condition: bool,
259}
260
261#[pymethods]
262impl LLMJudgeTask {
263 #[new]
283 #[pyo3(signature = (id, prompt, expected_value, field_path,operator, description=None, depends_on=None, max_retries=None, condition=None))]
284 #[allow(clippy::too_many_arguments)]
285 pub fn new(
286 id: &str,
287 prompt: Prompt,
288 expected_value: &Bound<'_, PyAny>,
289 field_path: Option<String>,
290 operator: ComparisonOperator,
291 description: Option<String>,
292 depends_on: Option<Vec<String>>,
293 max_retries: Option<u32>,
294 condition: Option<bool>,
295 ) -> Result<Self, TypeError> {
296 let expected_value = depythonize(expected_value)?;
297
298 Ok(Self {
299 id: id.to_lowercase(),
300 prompt,
301 expected_value,
302 operator,
303 task_type: EvaluationTaskType::LLMJudge,
304 depends_on: depends_on.unwrap_or_default(),
305 max_retries: max_retries.or(Some(3)),
306 field_path,
307 result: None,
308 description,
309 condition: condition.unwrap_or(false),
310 })
311 }
312
313 pub fn __str__(&self) -> String {
314 PyHelperFuncs::__str__(self)
316 }
317
318 #[getter]
319 pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
320 let py_value = pythonize(py, &self.expected_value)?;
321 Ok(py_value)
322 }
323}
324
325impl LLMJudgeTask {
326 #[allow(clippy::too_many_arguments)]
338 pub fn new_rs(
339 id: &str,
340 prompt: Prompt,
341 expected_value: Value,
342 field_path: Option<String>,
343 operator: ComparisonOperator,
344 depends_on: Option<Vec<String>>,
345 max_retries: Option<u32>,
346 description: Option<String>,
347 condition: Option<bool>,
348 ) -> Self {
349 Self {
350 id: id.to_lowercase(),
351 prompt,
352 expected_value,
353 operator,
354 task_type: EvaluationTaskType::LLMJudge,
355 depends_on: depends_on.unwrap_or_default(),
356 max_retries: max_retries.or(Some(3)),
357 field_path,
358 result: None,
359 description,
360 condition: condition.unwrap_or(false),
361 }
362 }
363}
364
365impl TaskAccessor for LLMJudgeTask {
366 fn field_path(&self) -> Option<&str> {
367 self.field_path.as_deref()
368 }
369
370 fn id(&self) -> &str {
371 &self.id
372 }
373 fn task_type(&self) -> &EvaluationTaskType {
374 &self.task_type
375 }
376
377 fn operator(&self) -> &ComparisonOperator {
378 &self.operator
379 }
380
381 fn expected_value(&self) -> &Value {
382 &self.expected_value
383 }
384 fn depends_on(&self) -> &[String] {
385 &self.depends_on
386 }
387 fn add_result(&mut self, result: AssertionResult) {
388 self.result = Some(result);
389 }
390}
391
392#[derive(Debug, Clone)]
393pub enum EvaluationTask {
394 Assertion(Box<AssertionTask>),
395 LLMJudge(Box<LLMJudgeTask>),
396}
397
398impl TaskAccessor for EvaluationTask {
399 fn field_path(&self) -> Option<&str> {
400 match self {
401 EvaluationTask::Assertion(t) => t.field_path(),
402 EvaluationTask::LLMJudge(t) => t.field_path(),
403 }
404 }
405
406 fn id(&self) -> &str {
407 match self {
408 EvaluationTask::Assertion(t) => t.id(),
409 EvaluationTask::LLMJudge(t) => t.id(),
410 }
411 }
412
413 fn task_type(&self) -> &EvaluationTaskType {
414 match self {
415 EvaluationTask::Assertion(t) => t.task_type(),
416 EvaluationTask::LLMJudge(t) => t.task_type(),
417 }
418 }
419
420 fn operator(&self) -> &ComparisonOperator {
421 match self {
422 EvaluationTask::Assertion(t) => t.operator(),
423 EvaluationTask::LLMJudge(t) => t.operator(),
424 }
425 }
426
427 fn expected_value(&self) -> &Value {
428 match self {
429 EvaluationTask::Assertion(t) => t.expected_value(),
430 EvaluationTask::LLMJudge(t) => t.expected_value(),
431 }
432 }
433
434 fn depends_on(&self) -> &[String] {
435 match self {
436 EvaluationTask::Assertion(t) => t.depends_on(),
437 EvaluationTask::LLMJudge(t) => t.depends_on(),
438 }
439 }
440
441 fn add_result(&mut self, result: AssertionResult) {
442 match self {
443 EvaluationTask::Assertion(t) => t.add_result(result),
444 EvaluationTask::LLMJudge(t) => t.add_result(result),
445 }
446 }
447}
448
449pub struct EvaluationTasks(Vec<EvaluationTask>);
450
451impl EvaluationTasks {
452 pub fn new() -> Self {
454 Self(Vec::new())
455 }
456
457 pub fn add_task(mut self, task: impl Into<EvaluationTask>) -> Self {
459 self.0.push(task.into());
460 self
461 }
462
463 pub fn build(self) -> Vec<EvaluationTask> {
465 self.0
466 }
467}
468
469impl From<AssertionTask> for EvaluationTask {
471 fn from(task: AssertionTask) -> Self {
472 EvaluationTask::Assertion(Box::new(task))
473 }
474}
475
476impl From<LLMJudgeTask> for EvaluationTask {
477 fn from(task: LLMJudgeTask) -> Self {
478 EvaluationTask::LLMJudge(Box::new(task))
479 }
480}
481
482impl Default for EvaluationTasks {
483 fn default() -> Self {
484 Self::new()
485 }
486}
487
488#[pyclass]
489#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
490pub enum ComparisonOperator {
491 Equals,
493 NotEqual,
494 GreaterThan,
495 GreaterThanOrEqual,
496 LessThan,
497 LessThanOrEqual,
498 Contains,
499 NotContains,
500 StartsWith,
501 EndsWith,
502 Matches,
503 HasLengthGreaterThan,
504 HasLengthLessThan,
505 HasLengthEqual,
506 HasLengthGreaterThanOrEqual,
507 HasLengthLessThanOrEqual,
508
509 IsNumeric,
511 IsString,
512 IsBoolean,
513 IsNull,
514 IsArray,
515 IsObject,
516
517 IsEmail,
519 IsUrl,
520 IsUuid,
521 IsIso8601,
522 IsJson,
523 MatchesRegex,
524
525 InRange,
527 NotInRange,
528 IsPositive,
529 IsNegative,
530 IsZero,
531
532 ContainsAll,
534 ContainsAny,
535 ContainsNone,
536 IsEmpty,
537 IsNotEmpty,
538 HasUniqueItems,
539
540 IsAlphabetic,
542 IsAlphanumeric,
543 IsLowerCase,
544 IsUpperCase,
545 ContainsWord,
546
547 ApproximatelyEquals,
549}
550
551impl Display for ComparisonOperator {
552 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553 write!(f, "{}", self.as_str())
554 }
555}
556
557impl FromStr for ComparisonOperator {
558 type Err = TypeError;
559
560 fn from_str(s: &str) -> Result<Self, Self::Err> {
561 match s {
562 "Equals" => Ok(ComparisonOperator::Equals),
563 "NotEqual" => Ok(ComparisonOperator::NotEqual),
564 "GreaterThan" => Ok(ComparisonOperator::GreaterThan),
565 "GreaterThanOrEqual" => Ok(ComparisonOperator::GreaterThanOrEqual),
566 "LessThan" => Ok(ComparisonOperator::LessThan),
567 "LessThanOrEqual" => Ok(ComparisonOperator::LessThanOrEqual),
568 "Contains" => Ok(ComparisonOperator::Contains),
569 "NotContains" => Ok(ComparisonOperator::NotContains),
570 "StartsWith" => Ok(ComparisonOperator::StartsWith),
571 "EndsWith" => Ok(ComparisonOperator::EndsWith),
572 "Matches" => Ok(ComparisonOperator::Matches),
573 "HasLengthEqual" => Ok(ComparisonOperator::HasLengthEqual),
574 "HasLengthGreaterThan" => Ok(ComparisonOperator::HasLengthGreaterThan),
575 "HasLengthLessThan" => Ok(ComparisonOperator::HasLengthLessThan),
576 "HasLengthGreaterThanOrEqual" => Ok(ComparisonOperator::HasLengthGreaterThanOrEqual),
577 "HasLengthLessThanOrEqual" => Ok(ComparisonOperator::HasLengthLessThanOrEqual),
578
579 "IsNumeric" => Ok(ComparisonOperator::IsNumeric),
581 "IsString" => Ok(ComparisonOperator::IsString),
582 "IsBoolean" => Ok(ComparisonOperator::IsBoolean),
583 "IsNull" => Ok(ComparisonOperator::IsNull),
584 "IsArray" => Ok(ComparisonOperator::IsArray),
585 "IsObject" => Ok(ComparisonOperator::IsObject),
586
587 "IsEmail" => Ok(ComparisonOperator::IsEmail),
589 "IsUrl" => Ok(ComparisonOperator::IsUrl),
590 "IsUuid" => Ok(ComparisonOperator::IsUuid),
591 "IsIso8601" => Ok(ComparisonOperator::IsIso8601),
592 "IsJson" => Ok(ComparisonOperator::IsJson),
593 "MatchesRegex" => Ok(ComparisonOperator::MatchesRegex),
594
595 "InRange" => Ok(ComparisonOperator::InRange),
597 "NotInRange" => Ok(ComparisonOperator::NotInRange),
598 "IsPositive" => Ok(ComparisonOperator::IsPositive),
599 "IsNegative" => Ok(ComparisonOperator::IsNegative),
600 "IsZero" => Ok(ComparisonOperator::IsZero),
601
602 "ContainsAll" => Ok(ComparisonOperator::ContainsAll),
604 "ContainsAny" => Ok(ComparisonOperator::ContainsAny),
605 "ContainsNone" => Ok(ComparisonOperator::ContainsNone),
606 "IsEmpty" => Ok(ComparisonOperator::IsEmpty),
607 "IsNotEmpty" => Ok(ComparisonOperator::IsNotEmpty),
608 "HasUniqueItems" => Ok(ComparisonOperator::HasUniqueItems),
609
610 "IsAlphabetic" => Ok(ComparisonOperator::IsAlphabetic),
612 "IsAlphanumeric" => Ok(ComparisonOperator::IsAlphanumeric),
613 "IsLowerCase" => Ok(ComparisonOperator::IsLowerCase),
614 "IsUpperCase" => Ok(ComparisonOperator::IsUpperCase),
615 "ContainsWord" => Ok(ComparisonOperator::ContainsWord),
616
617 "ApproximatelyEquals" => Ok(ComparisonOperator::ApproximatelyEquals),
619
620 _ => Err(TypeError::InvalidCompressionTypeError),
621 }
622 }
623}
624
625impl ComparisonOperator {
626 pub fn as_str(&self) -> &str {
627 match self {
628 ComparisonOperator::Equals => "Equals",
629 ComparisonOperator::NotEqual => "NotEqual",
630 ComparisonOperator::GreaterThan => "GreaterThan",
631 ComparisonOperator::GreaterThanOrEqual => "GreaterThanOrEqual",
632 ComparisonOperator::LessThan => "LessThan",
633 ComparisonOperator::LessThanOrEqual => "LessThanOrEqual",
634 ComparisonOperator::Contains => "Contains",
635 ComparisonOperator::NotContains => "NotContains",
636 ComparisonOperator::StartsWith => "StartsWith",
637 ComparisonOperator::EndsWith => "EndsWith",
638 ComparisonOperator::Matches => "Matches",
639 ComparisonOperator::HasLengthEqual => "HasLengthEqual",
640 ComparisonOperator::HasLengthGreaterThan => "HasLengthGreaterThan",
641 ComparisonOperator::HasLengthLessThan => "HasLengthLessThan",
642 ComparisonOperator::HasLengthGreaterThanOrEqual => "HasLengthGreaterThanOrEqual",
643 ComparisonOperator::HasLengthLessThanOrEqual => "HasLengthLessThanOrEqual",
644
645 ComparisonOperator::IsNumeric => "IsNumeric",
647 ComparisonOperator::IsString => "IsString",
648 ComparisonOperator::IsBoolean => "IsBoolean",
649 ComparisonOperator::IsNull => "IsNull",
650 ComparisonOperator::IsArray => "IsArray",
651 ComparisonOperator::IsObject => "IsObject",
652
653 ComparisonOperator::IsEmail => "IsEmail",
655 ComparisonOperator::IsUrl => "IsUrl",
656 ComparisonOperator::IsUuid => "IsUuid",
657 ComparisonOperator::IsIso8601 => "IsIso8601",
658 ComparisonOperator::IsJson => "IsJson",
659 ComparisonOperator::MatchesRegex => "MatchesRegex",
660
661 ComparisonOperator::InRange => "InRange",
663 ComparisonOperator::NotInRange => "NotInRange",
664 ComparisonOperator::IsPositive => "IsPositive",
665 ComparisonOperator::IsNegative => "IsNegative",
666 ComparisonOperator::IsZero => "IsZero",
667
668 ComparisonOperator::ContainsAll => "ContainsAll",
670 ComparisonOperator::ContainsAny => "ContainsAny",
671 ComparisonOperator::ContainsNone => "ContainsNone",
672 ComparisonOperator::IsEmpty => "IsEmpty",
673 ComparisonOperator::IsNotEmpty => "IsNotEmpty",
674 ComparisonOperator::HasUniqueItems => "HasUniqueItems",
675
676 ComparisonOperator::IsAlphabetic => "IsAlphabetic",
678 ComparisonOperator::IsAlphanumeric => "IsAlphanumeric",
679 ComparisonOperator::IsLowerCase => "IsLowerCase",
680 ComparisonOperator::IsUpperCase => "IsUpperCase",
681 ComparisonOperator::ContainsWord => "ContainsWord",
682
683 ComparisonOperator::ApproximatelyEquals => "ApproximatelyEquals",
685 }
686 }
687}
688
689#[pyclass]
690#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
691pub enum AssertionValue {
692 String(String),
693 Number(f64),
694 Integer(i64),
695 Boolean(bool),
696 List(Vec<AssertionValue>),
697 Null(),
698}
699
700impl AssertionValue {
701 pub fn to_actual(self, comparison: &ComparisonOperator) -> AssertionValue {
702 match comparison {
703 ComparisonOperator::HasLengthEqual
704 | ComparisonOperator::HasLengthGreaterThan
705 | ComparisonOperator::HasLengthLessThan
706 | ComparisonOperator::HasLengthGreaterThanOrEqual
707 | ComparisonOperator::HasLengthLessThanOrEqual => match self {
708 AssertionValue::List(arr) => AssertionValue::Integer(arr.len() as i64),
709 AssertionValue::String(s) => AssertionValue::Integer(s.chars().count() as i64),
710 _ => self,
711 },
712 _ => self,
713 }
714 }
715
716 pub fn to_serde_value(&self) -> Value {
717 match self {
718 AssertionValue::String(s) => Value::String(s.clone()),
719 AssertionValue::Number(n) => Value::Number(serde_json::Number::from_f64(*n).unwrap()),
720 AssertionValue::Integer(i) => Value::Number(serde_json::Number::from(*i)),
721 AssertionValue::Boolean(b) => Value::Bool(*b),
722 AssertionValue::List(arr) => {
723 let json_arr: Vec<Value> = arr.iter().map(|v| v.to_serde_value()).collect();
724 Value::Array(json_arr)
725 }
726 AssertionValue::Null() => Value::Null,
727 }
728 }
729}
730pub fn assertion_value_from_py(value: &Bound<'_, PyAny>) -> Result<AssertionValue, TypeError> {
737 if value.is_none() {
739 return Ok(AssertionValue::Null());
740 }
741
742 if value.is_instance_of::<PyBool>() {
744 return Ok(AssertionValue::Boolean(value.extract()?));
745 }
746
747 if value.is_instance_of::<PyString>() {
748 return Ok(AssertionValue::String(value.extract()?));
749 }
750
751 if value.is_instance_of::<PyInt>() {
752 return Ok(AssertionValue::Integer(value.extract()?));
753 }
754
755 if value.is_instance_of::<PyFloat>() {
756 return Ok(AssertionValue::Number(value.extract()?));
757 }
758
759 if value.is_instance_of::<PyList>() {
760 let list = value.cast::<PyList>()?; let assertion_list = list
763 .iter()
764 .map(|item| assertion_value_from_py(&item))
765 .collect::<Result<Vec<_>, _>>()?;
766 return Ok(AssertionValue::List(assertion_list));
767 }
768
769 Err(TypeError::UnsupportedType(
771 value.get_type().name()?.to_string(),
772 ))
773}
774
775#[pyclass]
776#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
777pub enum EvaluationTaskType {
778 Assertion,
779 LLMJudge,
780 Conditional,
781 HumanValidation,
782}
783
784impl Display for EvaluationTaskType {
785 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
786 let task_type_str = match self {
787 EvaluationTaskType::Assertion => "Assertion",
788 EvaluationTaskType::LLMJudge => "LLMJudge",
789 EvaluationTaskType::Conditional => "Conditional",
790 EvaluationTaskType::HumanValidation => "HumanValidation",
791 };
792 write!(f, "{}", task_type_str)
793 }
794}
795
796impl FromStr for EvaluationTaskType {
797 type Err = TypeError;
798
799 fn from_str(s: &str) -> Result<Self, Self::Err> {
800 match s {
801 "Assertion" => Ok(EvaluationTaskType::Assertion),
802 "LLMJudge" => Ok(EvaluationTaskType::LLMJudge),
803 "Conditional" => Ok(EvaluationTaskType::Conditional),
804 "HumanValidation" => Ok(EvaluationTaskType::HumanValidation),
805 _ => Err(TypeError::InvalidEvalType(s.to_string())),
806 }
807 }
808}
809
810impl EvaluationTaskType {
811 pub fn as_str(&self) -> &str {
812 match self {
813 EvaluationTaskType::Assertion => "Assertion",
814 EvaluationTaskType::LLMJudge => "LLMJudge",
815 EvaluationTaskType::Conditional => "Conditional",
816 EvaluationTaskType::HumanValidation => "HumanValidation",
817 }
818 }
819}