scouter_types/llm/profile.rs

use crate::error::{ProfileError, TypeError};
use crate::llm::alert::LLMAlertConfig;
use crate::llm::alert::LLMDriftMetric;
use crate::util::{json_to_pyobject, pyobject_to_json};
use crate::ProfileRequest;
use crate::{scouter_version, LLMMetricRecord};
use crate::{
    DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
    PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
};
use core::fmt::Debug;
use potato_head::prompt::ResponseType;
use potato_head::Task;
use potato_head::Workflow;
use potato_head::{Agent, Prompt};
use pyo3::prelude::*;
use pyo3::types::PyDict;
use scouter_semver::VersionType;
use scouter_state::app_state;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tracing::{debug, error, instrument};

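/// Configuration for LLM drift detection.
///
/// A minimal construction sketch (not compiled as a doctest; the space and name
/// values are placeholders and the default alert configuration is assumed):
///
/// ```ignore
/// let config = LLMDriftConfig::new(
///     "my-space",                  // space (placeholder)
///     "my-llm-app",                // name (placeholder)
///     "0.1.0",                     // version
///     5,                           // sample_rate
///     LLMAlertConfig::default(),   // alert_config
///     None,                        // config_path: no JSON file, use the args above
/// )?;
/// ```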
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMDriftConfig {
    #[pyo3(get, set)]
    pub sample_rate: usize,

    #[pyo3(get, set)]
    pub space: String,

    #[pyo3(get, set)]
    pub name: String,

    #[pyo3(get, set)]
    pub version: String,

    #[pyo3(get, set)]
    pub alert_config: LLMAlertConfig,

    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}

fn default_drift_type() -> DriftType {
    DriftType::LLM
}

impl DispatchDriftConfig for LLMDriftConfig {
    fn get_drift_args(&self) -> DriftArgs {
        DriftArgs {
            name: self.name.clone(),
            space: self.space.clone(),
            version: self.version.clone(),
            dispatch_config: self.alert_config.dispatch_config.clone(),
        }
    }
}

#[pymethods]
#[allow(clippy::too_many_arguments)]
impl LLMDriftConfig {
    #[new]
    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_rate=5, alert_config=LLMAlertConfig::default(), config_path=None))]
    pub fn new(
        space: &str,
        name: &str,
        version: &str,
        sample_rate: usize,
        alert_config: LLMAlertConfig,
        config_path: Option<PathBuf>,
    ) -> Result<Self, ProfileError> {
        if let Some(config_path) = config_path {
            let config = LLMDriftConfig::load_from_json_file(config_path)?;
            return Ok(config);
        }

        Ok(Self {
            sample_rate,
            space: space.to_string(),
            name: name.to_string(),
            version: version.to_string(),
            alert_config,
            drift_type: DriftType::LLM,
        })
    }

    #[staticmethod]
    pub fn load_from_json_file(path: PathBuf) -> Result<LLMDriftConfig, ProfileError> {
        // deserialize the string to a struct

        let file = std::fs::read_to_string(&path)?;

        Ok(serde_json::from_str(&file)?)
    }

    pub fn __str__(&self) -> String {
        // serialize the struct to a string
        PyHelperFuncs::__str__(self)
    }

    pub fn model_dump_json(&self) -> String {
        // serialize the struct to a string
        PyHelperFuncs::__json__(self)
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
    pub fn update_config_args(
        &mut self,
        space: Option<String>,
        name: Option<String>,
        version: Option<String>,
        alert_config: Option<LLMAlertConfig>,
    ) -> Result<(), TypeError> {
        if let Some(name) = name {
            self.name = name;
        }

        if let Some(space) = space {
            self.space = space;
        }

        if let Some(version) = version {
            self.version = version;
        }

        if let Some(alert_config) = alert_config {
            self.alert_config = alert_config;
        }

        Ok(())
    }
}

/// Validates that a prompt contains at least one bound parameter.
///
/// LLM evaluation prompts must bind at least one parameter (typically "input"
/// and/or "output") so they can access the data being evaluated.
///
/// # Arguments
/// * `prompt` - The prompt to validate
/// * `id` - Identifier for error reporting
///
/// # Returns
/// * `Ok(())` if validation passes
/// * `Err(ProfileError::NeedAtLeastOneBoundParameterError)` if the prompt has no parameters
///
/// # Errors
/// Returns an error if the prompt has no bound parameters.
fn validate_prompt_parameters(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
    let has_at_least_one_param = !prompt.parameters.is_empty();

    if !has_at_least_one_param {
        return Err(ProfileError::NeedAtLeastOneBoundParameterError(
            id.to_string(),
        ));
    }

    Ok(())
}

/// Validates that a prompt has the correct response type for scoring.
///
/// LLM evaluation prompts must return scores for drift detection.
///
/// # Arguments
/// * `prompt` - The prompt to validate
/// * `id` - Identifier for error reporting
///
/// # Returns
/// * `Ok(())` if validation passes
/// * `Err(ProfileError::InvalidResponseType)` if response type is not Score
fn validate_prompt_response_type(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
    if prompt.response_type != ResponseType::Score {
        return Err(ProfileError::InvalidResponseType(id.to_string()));
    }

    Ok(())
}

/// Helper function to safely retrieve a task from the workflow.
fn get_workflow_task<'a>(
    workflow: &'a Workflow,
    task_id: &'a str,
) -> Result<&'a Arc<std::sync::RwLock<potato_head::Task>>, ProfileError> {
    workflow
        .task_list
        .tasks
        .get(task_id)
        .ok_or_else(|| ProfileError::NoTasksFoundError(format!("Task '{task_id}' not found")))
}

/// Helper function to validate first tasks in workflow execution.
fn validate_first_tasks(
    workflow: &Workflow,
    execution_order: &HashMap<i32, std::collections::HashSet<String>>,
) -> Result<(), ProfileError> {
    let first_tasks = execution_order
        .get(&1)
        .ok_or_else(|| ProfileError::NoTasksFoundError("No initial tasks found".to_string()))?;

    for task_id in first_tasks {
        let task = get_workflow_task(workflow, task_id)?;
        let task_guard = task
            .read()
            .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;

        validate_prompt_parameters(&task_guard.prompt, &task_guard.id)?;
    }

    Ok(())
}

/// Helper function to validate last tasks in workflow execution.
fn validate_last_tasks(
    workflow: &Workflow,
    execution_order: &HashMap<i32, std::collections::HashSet<String>>,
    metrics: &[LLMDriftMetric],
) -> Result<(), ProfileError> {
    let last_step = execution_order.len() as i32;
    let last_tasks = execution_order
        .get(&last_step)
        .ok_or_else(|| ProfileError::NoTasksFoundError("No final tasks found".to_string()))?;

    for task_id in last_tasks {
        // assert task_id exists in metrics (all output tasks must have a corresponding metric)
        if !metrics.iter().any(|m| m.name == *task_id) {
            return Err(ProfileError::MetricNotFoundForOutputTask(task_id.clone()));
        }

        let task = get_workflow_task(workflow, task_id)?;
        let task_guard = task
            .read()
            .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;

        validate_prompt_response_type(&task_guard.prompt, &task_guard.id)?;
    }

    Ok(())
}

/// Validates workflow execution parameters and response types.
///
/// Ensures that:
/// - First tasks have required prompt parameters
/// - Last tasks have Score response type
///
/// # Arguments
/// * `workflow` - The workflow to validate
/// * `metrics` - The metrics that the final workflow tasks must map to
///
/// # Returns
/// * `Ok(())` if validation passes
/// * `Err(ProfileError)` if validation fails
///
/// # Errors
/// Returns various ProfileError types based on validation failures.
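///
/// For reference, `Workflow::execution_plan` maps a 1-based step number to the set
/// of task ids that run in that step; for the two-step workflow built in the tests
/// below, the plan looks roughly like:
///
/// ```text
/// { 1: {"task1"}, 2: {"task2", "task3"} }
/// ```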
fn validate_workflow(workflow: &Workflow, metrics: &[LLMDriftMetric]) -> Result<(), ProfileError> {
    let execution_order = workflow.execution_plan()?;

    // Validate first tasks have required parameters
    validate_first_tasks(workflow, &execution_order)?;

    // Validate last tasks have correct response type
    validate_last_tasks(workflow, &execution_order, metrics)?;

    Ok(())
}

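/// Drift profile for LLM monitoring.
///
/// Bundles the drift configuration, the evaluation metrics, and the workflow that
/// the scouter server executes asynchronously to evaluate LLM records.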
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMDriftProfile {
    #[pyo3(get)]
    pub config: LLMDriftConfig,

    #[pyo3(get)]
    pub metrics: Vec<LLMDriftMetric>,

    #[pyo3(get)]
    pub scouter_version: String,

    pub workflow: Workflow,

    #[pyo3(get)]
    pub metric_names: Vec<String>,
}
#[pymethods]
impl LLMDriftProfile {
    #[new]
    #[pyo3(signature = (config, metrics, workflow=None))]
    /// Create a new LLMDriftProfile.
    /// LLM evaluations are run asynchronously on the scouter server.
    ///
    /// # Logic flow:
    ///  1. If a user provides only a list of metrics, a workflow is created from the metrics via the `from_metrics` method.
    ///  2. If a user provides a workflow, it is parsed and validated via the `from_workflow` method.
    ///     - The user must also provide a list of metrics that will be used to evaluate the output of the workflow.
    ///     - The metric names must correspond to the final task names in the workflow.
    ///
    /// In addition, baseline values and alert thresholds are extracted from each `LLMDriftMetric`.
    ///
    /// # Arguments
    /// * `config` - LLMDriftConfig - The configuration for the LLM drift profile
    /// * `metrics` - Vec<LLMDriftMetric> - The metrics that will be used to evaluate the LLM
    /// * `workflow` - Option<Bound<'_, PyAny>> - Optional workflow to use for the LLM drift profile
    ///
    /// # Returns
    /// * `Result<Self, ProfileError>` - The LLMDriftProfile
    ///
    /// # Errors
    /// * `ProfileError::InvalidWorkflowType` - If a workflow is provided but cannot be extracted from the Python object
    /// * `ProfileError::EmptyMetricsList` - If no workflow is provided and the metrics list is empty
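    ///
    /// A minimal sketch of the metrics-only path (not compiled as a doctest; the
    /// names and threshold are placeholders, and `score_prompt` is assumed to be a
    /// prompt with a `Score` response type, as in the tests at the bottom of this file):
    ///
    /// ```ignore
    /// let config = LLMDriftConfig::new("my-space", "my-llm-app", "0.1.0", 5, LLMAlertConfig::default(), None)?;
    /// let metric = LLMDriftMetric::new("coherence", 4.0, AlertThreshold::Below, None, Some(score_prompt))?;
    ///
    /// // No workflow supplied, so one is built from the metric prompts via `from_metrics`.
    /// let profile = LLMDriftProfile::new(config, vec![metric], None)?;
    /// ```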
    #[instrument(skip_all)]
    pub fn new(
        config: LLMDriftConfig,
        metrics: Vec<LLMDriftMetric>,
        workflow: Option<Bound<'_, PyAny>>,
    ) -> Result<Self, ProfileError> {
        match workflow {
            Some(py_workflow) => {
                // Extract and validate workflow from Python object
                let workflow = Self::extract_workflow(&py_workflow).map_err(|e| {
                    error!("Failed to extract workflow: {}", e);
                    e
                })?;
                validate_workflow(&workflow, &metrics)?;
                Self::from_workflow(config, workflow, metrics)
            }
            None => {
                // Ensure metrics are provided when no workflow specified
                if metrics.is_empty() {
                    return Err(ProfileError::EmptyMetricsList);
                }
                app_state().block_on(async { Self::from_metrics(config, metrics).await })
            }
        }
    }

    pub fn __str__(&self) -> String {
        // serialize the struct to a string
        PyHelperFuncs::__str__(self)
    }

    pub fn model_dump_json(&self) -> String {
        // serialize the struct to a string
        PyHelperFuncs::__json__(self)
    }

    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
        let json_str = serde_json::to_string(&self)?;

        let json_value: Value = serde_json::from_str(&json_str)?;

        // Create a new Python dictionary
        let dict = PyDict::new(py);

        // Convert JSON to Python dict
        json_to_pyobject(py, &json_value, &dict)?;

        // Return the Python dictionary
        Ok(dict.into())
    }

    #[pyo3(signature = (path=None))]
    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
        Ok(PyHelperFuncs::save_to_json(
            self,
            path,
            FileName::LLMDriftProfile.to_str(),
        )?)
    }

    #[staticmethod]
    pub fn model_validate(data: &Bound<'_, PyDict>) -> LLMDriftProfile {
        let json_value = pyobject_to_json(data).unwrap();

        let string = serde_json::to_string(&json_value).unwrap();
        serde_json::from_str(&string).expect("Failed to load drift profile")
    }

    #[staticmethod]
    pub fn model_validate_json(json_string: String) -> LLMDriftProfile {
        // deserialize the string to a struct
        serde_json::from_str(&json_string).expect("Failed to load prompt drift profile")
    }

    #[staticmethod]
    pub fn from_file(path: PathBuf) -> Result<LLMDriftProfile, ProfileError> {
        let file = std::fs::read_to_string(&path)?;

        Ok(serde_json::from_str(&file)?)
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
    pub fn update_config_args(
        &mut self,
        space: Option<String>,
        name: Option<String>,
        version: Option<String>,
        alert_config: Option<LLMAlertConfig>,
    ) -> Result<(), TypeError> {
        self.config
            .update_config_args(space, name, version, alert_config)
    }

    /// Create a profile request from the profile
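    ///
    /// A rough usage sketch (not a doctest; `profile` is an already-constructed
    /// `LLMDriftProfile`):
    ///
    /// ```ignore
    /// let request = profile.create_profile_request()?;
    /// assert_eq!(request.space, profile.config.space);
    /// // When the config still holds DEFAULT_VERSION, `version` is sent as `None`;
    /// // otherwise the configured version is reused.
    /// ```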
    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
            None
        } else {
            Some(self.config.version.clone())
        };

        Ok(ProfileRequest {
            space: self.config.space.clone(),
            profile: self.model_dump_json(),
            drift_type: self.config.drift_type.clone(),
            version_request: VersionRequest {
                version,
                version_type: VersionType::Minor,
                pre_tag: None,
                build_tag: None,
            },
            active: false,
            deactivate_others: false,
        })
    }
}

impl LLMDriftProfile {
    /// Creates an LLMDriftProfile from a configuration and a list of metrics.
    ///
    /// # Arguments
    /// * `config` - LLMDriftConfig - The configuration for the LLM
    /// * `metrics` - Vec<LLMDriftMetric> - The metrics that will be used to evaluate the LLM; each metric must carry a prompt
    ///
    /// # Returns
    /// * `Result<Self, ProfileError>` - The LLMDriftProfile
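    ///
    /// A minimal async sketch (not compiled as a doctest; the metric name and
    /// threshold are placeholders, and `score_prompt` is assumed to be a prompt
    /// with a `Score` response type):
    ///
    /// ```ignore
    /// let config = LLMDriftConfig::new("my-space", "my-llm-app", "0.1.0", 5, LLMAlertConfig::default(), None)?;
    /// let metric = LLMDriftMetric::new("relevance", 4.0, AlertThreshold::Below, None, Some(score_prompt))?;
    /// let profile = LLMDriftProfile::from_metrics(config, vec![metric]).await?;
    /// ```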
    pub async fn from_metrics(
        mut config: LLMDriftConfig,
        metrics: Vec<LLMDriftMetric>,
    ) -> Result<Self, ProfileError> {
        // Build a workflow from metrics
        let mut workflow = Workflow::new("llm_drift_workflow");
        let mut agents = HashMap::new();
        let mut metric_names = Vec::new();

        // Create agents. We don't want to duplicate, so we check if the agent already exists.
        // If it doesn't, we create it.
        for metric in &metrics {
            // get prompt (if providing a list of metrics, prompt must be present)
            let prompt = metric
                .prompt
                .as_ref()
                .ok_or_else(|| ProfileError::MissingPromptError(metric.name.clone()))?;

            let provider = prompt.model_settings.provider();

            let agent = match agents.entry(provider) {
                Entry::Occupied(entry) => entry.into_mut(),
                Entry::Vacant(entry) => {
                    let agent = Agent::from_model_settings(&prompt.model_settings).await?;
                    workflow.add_agent(&agent);
                    entry.insert(agent)
                }
            };

            let task = Task::new(&agent.id, prompt.clone(), &metric.name, None, None);
            validate_prompt_parameters(prompt, &metric.name)?;
            workflow.add_task(task)?;
            metric_names.push(metric.name.clone());
        }

        config.alert_config.set_alert_conditions(&metrics);

        Ok(Self {
            config,
            metrics,
            scouter_version: scouter_version(),
            workflow,
            metric_names,
        })
    }

    /// Creates an LLMDriftProfile from a workflow and a list of metrics.
    /// This is useful when the workflow is already defined and you want to create a profile from it.
    /// This is also for more advanced use cases where the workflow may need to execute many dependent tasks.
    /// Because of this, there are a few requirements:
    /// 1. All beginning tasks in the workflow must bind at least one parameter (typically "input" and/or "output").
    /// 2. All ending tasks in the workflow must have a response type of "Score".
    /// 3. The user must also supply a list of metrics that will be used to evaluate the output of the workflow.
    ///    - The metric names must correspond to the final task names in the workflow.
    ///
    /// # Arguments
    /// * `config` - LLMDriftConfig - The configuration for the LLM
    /// * `workflow` - Workflow - The workflow that will be used to evaluate the LLM
    /// * `metrics` - Vec<LLMDriftMetric> - The metrics that will be used to evaluate the LLM
    ///
    /// # Returns
    /// * `Result<Self, ProfileError>` - The LLMDriftProfile
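    ///
    /// A rough sketch (not a doctest; `workflow` is assumed to already end in a
    /// task named "relevance" whose prompt has a `Score` response type, similar to
    /// the workflow built in the tests below):
    ///
    /// ```ignore
    /// let config = LLMDriftConfig::new("my-space", "my-llm-app", "0.1.0", 5, LLMAlertConfig::default(), None)?;
    /// // The metric name must match the final task name in the workflow.
    /// let metric = LLMDriftMetric::new("relevance", 4.0, AlertThreshold::Below, None, None)?;
    /// let profile = LLMDriftProfile::from_workflow(config, workflow, vec![metric])?;
    /// ```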
    pub fn from_workflow(
        mut config: LLMDriftConfig,
        workflow: Workflow,
        metrics: Vec<LLMDriftMetric>,
    ) -> Result<Self, ProfileError> {
        validate_workflow(&workflow, &metrics)?;

        config.alert_config.set_alert_conditions(&metrics);

        let metric_names = metrics.iter().map(|m| m.name.clone()).collect();

        Ok(Self {
            config,
            metrics,
            scouter_version: scouter_version(),
            workflow,
            metric_names,
        })
    }

    /// Extracts a Workflow from a Python object.
    ///
    /// # Arguments
    /// * `py_workflow` - Python object that exposes a `__workflow__` attribute containing a JSON-serialized workflow
    ///
    /// # Returns
    /// * `Result<Workflow, ProfileError>` - Extracted workflow
    ///
    /// # Errors
    /// * `ProfileError::InvalidWorkflowType` - If the object lacks `__workflow__`, the attribute is not a string, or the string cannot be deserialized into a `Workflow`
    #[instrument(skip_all)]
    fn extract_workflow(py_workflow: &Bound<'_, PyAny>) -> Result<Workflow, ProfileError> {
        debug!("Extracting workflow from Python object");

        if !py_workflow.hasattr("__workflow__")? {
            error!("Invalid workflow type provided. Expected object with __workflow__ method.");
            return Err(ProfileError::InvalidWorkflowType);
        }

        let workflow_string = py_workflow
            .getattr("__workflow__")
            .map_err(|e| {
                error!("Failed to call __workflow__ property: {}", e);
                ProfileError::InvalidWorkflowType
            })?
            .extract::<String>()
            .inspect_err(|e| {
                error!(
                    "Failed to extract workflow string from Python object: {}",
                    e
                );
            })?;

        serde_json::from_str(&workflow_string).map_err(|e| {
            error!("Failed to deserialize workflow: {}", e);
            ProfileError::InvalidWorkflowType
        })
    }

    /// Returns the configured value for `metric_name`, or an error if no metric with that name exists.
    pub fn get_metric_value(&self, metric_name: &str) -> Result<f64, ProfileError> {
        self.metrics
            .iter()
            .find(|m| m.name == metric_name)
            .map(|m| m.value)
            .ok_or_else(|| ProfileError::MetricNotFound(metric_name.to_string()))
    }

    /// Same as the Python `new` constructor, but allows passing a runtime for async operations.
    /// This is used in Opsml so that the Opsml global runtime can be used.
    pub fn new_with_runtime(
        config: LLMDriftConfig,
        metrics: Vec<LLMDriftMetric>,
        workflow: Option<Bound<'_, PyAny>>,
        runtime: Arc<tokio::runtime::Runtime>,
    ) -> Result<Self, ProfileError> {
        match workflow {
            Some(py_workflow) => {
                // Extract and validate workflow from Python object
                let workflow = Self::extract_workflow(&py_workflow).map_err(|e| {
                    error!("Failed to extract workflow: {}", e);
                    e
                })?;
                validate_workflow(&workflow, &metrics)?;
                Self::from_workflow(config, workflow, metrics)
            }
            None => {
                // Ensure metrics are provided when no workflow specified
                if metrics.is_empty() {
                    return Err(ProfileError::EmptyMetricsList);
                }
                runtime.block_on(async { Self::from_metrics(config, metrics).await })
            }
        }
    }
}

impl ProfileBaseArgs for LLMDriftProfile {
    fn get_base_args(&self) -> ProfileArgs {
        ProfileArgs {
            name: self.config.name.clone(),
            space: self.config.space.clone(),
            version: Some(self.config.version.clone()),
            schedule: self.config.alert_config.schedule.clone(),
            scouter_version: self.scouter_version.clone(),
            drift_type: self.config.drift_type.clone(),
        }
    }

    fn to_value(&self) -> Value {
        serde_json::to_value(self).unwrap()
    }
}

#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMDriftMap {
    #[pyo3(get)]
    pub records: Vec<LLMMetricRecord>,
}

#[pymethods]
impl LLMDriftMap {
    #[new]
    pub fn new(records: Vec<LLMMetricRecord>) -> Self {
        Self { records }
    }

    pub fn __str__(&self) -> String {
        // serialize the struct to a string
        PyHelperFuncs::__str__(self)
    }
}

// write test using mock feature
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {

    use super::*;
    use crate::AlertThreshold;
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};
    use potato_head::create_score_prompt;
    use potato_head::prompt::ResponseType;
    use potato_head::Provider;

    use potato_head::{LLMTestServer, Message, PromptContent};

    pub fn create_parameterized_prompt() -> Prompt {
        let user_content =
            PromptContent::Str("What is ${input} + ${response} + ${context}?".to_string());
        let system_content = PromptContent::Str("You are a helpful assistant.".to_string());
        Prompt::new_rs(
            vec![Message::new_rs(user_content)],
            "gpt-4o",
            Provider::OpenAI,
            vec![Message::new_rs(system_content)],
            None,
            None,
            ResponseType::Null,
        )
        .unwrap()
    }

    #[test]
    fn test_llm_drift_config() {
        let mut drift_config = LLMDriftConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            25,
            LLMAlertConfig::default(),
            None,
        )
        .unwrap();
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let test_slack_dispatch_config = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let new_alert_config = LLMAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(test_slack_dispatch_config.clone()),
            ..Default::default()
        };

        drift_config
            .update_config_args(None, Some("test".to_string()), None, Some(new_alert_config))
            .unwrap();

        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(test_slack_dispatch_config)
        );
        assert_eq!(
            drift_config.alert_config.schedule,
            "0 0 * * * *".to_string()
        );
    }

    #[test]
    fn test_llm_drift_profile_metric() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let mut mock = LLMTestServer::new();
        mock.start_server().unwrap();
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let metric1 = LLMDriftMetric::new(
            "metric1",
            5.0,
            AlertThreshold::Above,
            None,
            Some(prompt.clone()),
        )
        .unwrap();
        let metric2 = LLMDriftMetric::new(
            "metric2",
            3.0,
            AlertThreshold::Below,
            Some(1.0),
            Some(prompt.clone()),
        )
        .unwrap();

        let llm_metrics = vec![metric1, metric2];

        let alert_config = LLMAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            LLMDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let profile = runtime
            .block_on(async { LLMDriftProfile::from_metrics(drift_config, llm_metrics).await })
            .unwrap();
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        mock.stop_server().unwrap();
    }

    #[test]
    fn test_llm_drift_profile_workflow() {
        let mut mock = LLMTestServer::new();
        mock.start_server().unwrap();
        let runtime = tokio::runtime::Runtime::new().unwrap();

        let mut workflow = Workflow::new("My eval Workflow");

        let initial_prompt = create_parameterized_prompt();
        let final_prompt1 = create_score_prompt(None);
        let final_prompt2 = create_score_prompt(None);

        let agent1 = runtime
            .block_on(async { Agent::new(Provider::OpenAI, None).await })
            .unwrap();

        workflow.add_agent(&agent1);

        // First task with parameters
        workflow
            .add_task(Task::new(
                &agent1.id,
                initial_prompt.clone(),
                "task1",
                None,
                None,
            ))
            .unwrap();

        // Final tasks that depend on the first task
        workflow
            .add_task(Task::new(
                &agent1.id,
                final_prompt1.clone(),
                "task2",
                Some(vec!["task1".to_string()]),
                None,
            ))
            .unwrap();

        workflow
            .add_task(Task::new(
                &agent1.id,
                final_prompt2.clone(),
                "task3",
                Some(vec!["task1".to_string()]),
                None,
            ))
            .unwrap();

        let metric1 =
            LLMDriftMetric::new("task2", 3.0, AlertThreshold::Below, Some(1.0), None).unwrap();
        let metric2 =
            LLMDriftMetric::new("task3", 4.0, AlertThreshold::Above, Some(2.0), None).unwrap();

        let llm_metrics = vec![metric1, metric2];

        let alert_config = LLMAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            LLMDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let profile = LLMDriftProfile::from_workflow(drift_config, workflow, llm_metrics).unwrap();

        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        let plan = profile.workflow.execution_plan().unwrap();

        // plan should have 2 steps
        assert_eq!(plan.len(), 2);
        // first step should have 1 task
        assert_eq!(plan.get(&1).unwrap().len(), 1);
        // last step should have 2 tasks
        assert_eq!(plan.get(&(plan.len() as i32)).unwrap().len(), 2);

        mock.stop_server().unwrap();
    }
}