scouter_types/llm/
profile.rs

1use crate::error::{ProfileError, TypeError};
2use crate::llm::alert::LLMAlertConfig;
3use crate::llm::alert::LLMMetric;
4use crate::util::{json_to_pyobject, pyobject_to_json};
5use crate::ProfileRequest;
6use crate::{scouter_version, LLMMetricRecord};
7use crate::{
8    DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
9    ProfileFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
10};
11use core::fmt::Debug;
12use potato_head::prompt::ResponseType;
13use potato_head::Task;
14use potato_head::{Agent, Prompt};
15use potato_head::{Provider, Workflow};
16use pyo3::prelude::*;
17use pyo3::types::PyDict;
18use scouter_semver::VersionType;
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21use std::collections::hash_map::Entry;
22use std::collections::HashMap;
23use std::path::PathBuf;
24use std::sync::Arc;
25use tracing::{debug, error, instrument};
26
/// Configuration for an LLM drift profile.
///
/// Identifies the profile (space/name/version), controls sampling, and
/// carries the alerting configuration used when drift is evaluated.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMDriftConfig {
    /// Sampling rate for drift evaluation.
    /// NOTE(review): presumably 1-in-`sample_rate` records are evaluated —
    /// confirm against the server-side sampler.
    #[pyo3(get, set)]
    pub sample_rate: usize,

    /// Logical namespace the profile belongs to.
    #[pyo3(get, set)]
    pub space: String,

    /// Profile name.
    #[pyo3(get, set)]
    pub name: String,

    /// Profile version string.
    #[pyo3(get, set)]
    pub version: String,

    /// Alerting configuration (schedule + dispatch settings).
    #[pyo3(get, set)]
    pub alert_config: LLMAlertConfig,

    /// Drift type; deserializes to `DriftType::LLM` when absent from stored data.
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}
49
/// Serde default for `LLMDriftConfig::drift_type`: profiles serialized
/// without the field deserialize as `DriftType::LLM`.
fn default_drift_type() -> DriftType {
    DriftType::LLM
}
53
/// Adapts `LLMDriftConfig` to the generic alert-dispatch interface by
/// bundling its identity and dispatch settings into `DriftArgs`.
impl DispatchDriftConfig for LLMDriftConfig {
    fn get_drift_args(&self) -> DriftArgs {
        DriftArgs {
            name: self.name.clone(),
            space: self.space.clone(),
            version: self.version.clone(),
            dispatch_config: self.alert_config.dispatch_config.clone(),
        }
    }
}
64
65#[pymethods]
66#[allow(clippy::too_many_arguments)]
67impl LLMDriftConfig {
68    #[new]
69    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_rate=5, alert_config=LLMAlertConfig::default(), config_path=None))]
70    pub fn new(
71        space: &str,
72        name: &str,
73        version: &str,
74        sample_rate: usize,
75        alert_config: LLMAlertConfig,
76        config_path: Option<PathBuf>,
77    ) -> Result<Self, ProfileError> {
78        if let Some(config_path) = config_path {
79            let config = LLMDriftConfig::load_from_json_file(config_path)?;
80            return Ok(config);
81        }
82
83        Ok(Self {
84            sample_rate,
85            space: space.to_string(),
86            name: name.to_string(),
87            version: version.to_string(),
88            alert_config,
89            drift_type: DriftType::LLM,
90        })
91    }
92
93    #[staticmethod]
94    pub fn load_from_json_file(path: PathBuf) -> Result<LLMDriftConfig, ProfileError> {
95        // deserialize the string to a struct
96
97        let file = std::fs::read_to_string(&path)?;
98
99        Ok(serde_json::from_str(&file)?)
100    }
101
102    pub fn __str__(&self) -> String {
103        // serialize the struct to a string
104        ProfileFuncs::__str__(self)
105    }
106
107    pub fn model_dump_json(&self) -> String {
108        // serialize the struct to a string
109        ProfileFuncs::__json__(self)
110    }
111
112    #[allow(clippy::too_many_arguments)]
113    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
114    pub fn update_config_args(
115        &mut self,
116        space: Option<String>,
117        name: Option<String>,
118        version: Option<String>,
119        alert_config: Option<LLMAlertConfig>,
120    ) -> Result<(), TypeError> {
121        if name.is_some() {
122            self.name = name.ok_or(TypeError::MissingNameError)?;
123        }
124
125        if space.is_some() {
126            self.space = space.ok_or(TypeError::MissingSpaceError)?;
127        }
128
129        if version.is_some() {
130            self.version = version.ok_or(TypeError::MissingVersionError)?;
131        }
132
133        if alert_config.is_some() {
134            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
135        }
136
137        Ok(())
138    }
139}
140
141/// Validates that a prompt contains at least one required parameter.
142///
143/// LLM evaluation prompts must have either "input" or "output" parameters
144/// to access the data being evaluated.
145///
146/// # Arguments
147/// * `prompt` - The prompt to validate
148/// * `id` - Identifier for error reporting
149///
150/// # Returns
151/// * `Ok(())` if validation passes
152/// * `Err(ProfileError::MissingPromptParametersError)` if no required parameters found
153///
154/// # Errors
155/// Returns an error if the prompt lacks both "input" and "output" parameters.
156fn validate_prompt_parameters(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
157    let has_at_least_one_param = !prompt.parameters.is_empty();
158
159    if !has_at_least_one_param {
160        return Err(ProfileError::NeedAtLeastOneBoundParameterError(
161            id.to_string(),
162        ));
163    }
164
165    Ok(())
166}
167
168/// Validates that a prompt has the correct response type for scoring.
169///
170/// LLM evaluation prompts must return scores for drift detection.
171///
172/// # Arguments
173/// * `prompt` - The prompt to validate
174/// * `id` - Identifier for error reporting
175///
176/// # Returns
177/// * `Ok(())` if validation passes
178/// * `Err(ProfileError::InvalidResponseType)` if response type is not Score
179fn validate_prompt_response_type(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
180    if prompt.response_type != ResponseType::Score {
181        return Err(ProfileError::InvalidResponseType(id.to_string()));
182    }
183
184    Ok(())
185}
186
187/// Helper function to safely retrieve a task from the workflow.
188fn get_workflow_task<'a>(
189    workflow: &'a Workflow,
190    task_id: &'a str,
191) -> Result<&'a Arc<std::sync::RwLock<potato_head::Task>>, ProfileError> {
192    workflow
193        .task_list
194        .tasks
195        .get(task_id)
196        .ok_or_else(|| ProfileError::NoTasksFoundError(format!("Task '{task_id}' not found")))
197}
198
199/// Helper function to validate first tasks in workflow execution.
200fn validate_first_tasks(
201    workflow: &Workflow,
202    execution_order: &HashMap<i32, std::collections::HashSet<String>>,
203) -> Result<(), ProfileError> {
204    let first_tasks = execution_order
205        .get(&1)
206        .ok_or_else(|| ProfileError::NoTasksFoundError("No initial tasks found".to_string()))?;
207
208    for task_id in first_tasks {
209        let task = get_workflow_task(workflow, task_id)?;
210        let task_guard = task
211            .read()
212            .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;
213
214        validate_prompt_parameters(&task_guard.prompt, &task_guard.id)?;
215    }
216
217    Ok(())
218}
219
220/// Helper function to validate last tasks in workflow execution.
221fn validate_last_tasks(
222    workflow: &Workflow,
223    execution_order: &HashMap<i32, std::collections::HashSet<String>>,
224    metrics: &[LLMMetric],
225) -> Result<(), ProfileError> {
226    let last_step = execution_order.len() as i32;
227    let last_tasks = execution_order
228        .get(&last_step)
229        .ok_or_else(|| ProfileError::NoTasksFoundError("No final tasks found".to_string()))?;
230
231    for task_id in last_tasks {
232        // assert task_id exists in metrics (all output tasks must have a corresponding metric)
233        if !metrics.iter().any(|m| m.name == *task_id) {
234            return Err(ProfileError::MetricNotFoundForOutputTask(task_id.clone()));
235        }
236
237        let task = get_workflow_task(workflow, task_id)?;
238        let task_guard = task
239            .read()
240            .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;
241
242        validate_prompt_response_type(&task_guard.prompt, &task_guard.id)?;
243    }
244
245    Ok(())
246}
247
248/// Validates workflow execution parameters and response types.
249///
250/// Ensures that:
251/// - First tasks have required prompt parameters
252/// - Last tasks have Score response type
253///
254/// # Arguments
255/// * `workflow` - The workflow to validate
256///
257/// # Returns
258/// * `Ok(())` if validation passes
259/// * `Err(ProfileError)` if validation fails
260///
261/// # Errors
262/// Returns various ProfileError types based on validation failures.
263fn validate_workflow(workflow: &Workflow, metrics: &[LLMMetric]) -> Result<(), ProfileError> {
264    let execution_order = workflow.execution_plan()?;
265
266    // Validate first tasks have required parameters
267    validate_first_tasks(workflow, &execution_order)?;
268
269    // Validate last tasks have correct response type
270    validate_last_tasks(workflow, &execution_order, metrics)?;
271
272    Ok(())
273}
274
/// A drift profile for LLM evaluation.
///
/// Bundles the config, the evaluation workflow, and the metrics whose
/// thresholds the workflow's scores are compared against.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMDriftProfile {
    /// Profile configuration (identity, sampling, alerting).
    #[pyo3(get)]
    pub config: LLMDriftConfig,

    /// Metrics evaluated against the workflow's output tasks.
    #[pyo3(get)]
    pub metrics: Vec<LLMMetric>,

    /// Scouter version the profile was created with.
    #[pyo3(get)]
    pub scouter_version: String,

    // Not exposed to Python directly; accessed via serialization.
    pub workflow: Workflow,

    /// Names of all metrics, in insertion order.
    #[pyo3(get)]
    pub metric_names: Vec<String>,
}
292#[pymethods]
293impl LLMDriftProfile {
294    #[new]
295    #[pyo3(signature = (config, metrics, workflow=None))]
296    /// Create a new LLMDriftProfile
297    /// LLM evaluations are run asynchronously on the scouter server.
298    ///
299    /// # Logic flow:
300    ///  1. If a user provides only a list of metrics, a workflow will be created from the metrics using `from_metrics` method.
301    ///  2. If a user provides a workflow, It will be parsed and validated using `from_workflow` method.
302    ///     - The user must also provide a list of metrics that will be used to evaluate the output of the workflow.
303    ///     - The metric names must correspond to the final task names in the workflow
304    /// In addition, baseline metrics and threshold will be extracted from the LLMMetric.
305    /// # Arguments
306    /// * `config` - LLMDriftConfig - The configuration for the LLM drift profile
307    /// * `metrics` - Option<Bound<'_, PyList>> - Optional list of metrics that will be used to evaluate the LLM
308    /// * `workflow` - Option<Bound<'_, PyAny>> - Optional workflow to use for the LLM drift profile
309    ///
310    /// # Returns
311    /// * `Result<Self, ProfileError>` - The LLMDriftProfile
312    ///
313    /// # Errors
314    /// * `ProfileError::MissingWorkflowError` - If the workflow is
315    #[instrument(skip_all)]
316    pub fn new(
317        config: LLMDriftConfig,
318        metrics: Vec<LLMMetric>,
319        workflow: Option<Bound<'_, PyAny>>,
320    ) -> Result<Self, ProfileError> {
321        match workflow {
322            Some(py_workflow) => {
323                // Extract and validate workflow from Python object
324                let workflow = Self::extract_workflow(&py_workflow).map_err(|e| {
325                    error!("Failed to extract workflow: {}", e);
326                    e
327                })?;
328                validate_workflow(&workflow, &metrics)?;
329                Self::from_workflow(config, workflow, metrics)
330            }
331            None => {
332                // Ensure metrics are provided when no workflow specified
333                if metrics.is_empty() {
334                    return Err(ProfileError::EmptyMetricsList);
335                }
336                Self::from_metrics(config, metrics)
337            }
338        }
339    }
340
341    pub fn __str__(&self) -> String {
342        // serialize the struct to a string
343        ProfileFuncs::__str__(self)
344    }
345
346    pub fn model_dump_json(&self) -> String {
347        // serialize the struct to a string
348        ProfileFuncs::__json__(self)
349    }
350
351    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
352        let json_str = serde_json::to_string(&self)?;
353
354        let json_value: Value = serde_json::from_str(&json_str)?;
355
356        // Create a new Python dictionary
357        let dict = PyDict::new(py);
358
359        // Convert JSON to Python dict
360        json_to_pyobject(py, &json_value, &dict)?;
361
362        // Return the Python dictionary
363        Ok(dict.into())
364    }
365
366    #[pyo3(signature = (path=None))]
367    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
368        Ok(ProfileFuncs::save_to_json(
369            self,
370            path,
371            FileName::LLMDriftProfile.to_str(),
372        )?)
373    }
374
375    #[staticmethod]
376    pub fn model_validate(data: &Bound<'_, PyDict>) -> LLMDriftProfile {
377        let json_value = pyobject_to_json(data).unwrap();
378
379        let string = serde_json::to_string(&json_value).unwrap();
380        serde_json::from_str(&string).expect("Failed to load drift profile")
381    }
382
383    #[staticmethod]
384    pub fn model_validate_json(json_string: String) -> LLMDriftProfile {
385        // deserialize the string to a struct
386        serde_json::from_str(&json_string).expect("Failed to load prompt drift profile")
387    }
388
389    #[staticmethod]
390    pub fn from_file(path: PathBuf) -> Result<LLMDriftProfile, ProfileError> {
391        let file = std::fs::read_to_string(&path)?;
392
393        Ok(serde_json::from_str(&file)?)
394    }
395
396    #[allow(clippy::too_many_arguments)]
397    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
398    pub fn update_config_args(
399        &mut self,
400        space: Option<String>,
401        name: Option<String>,
402        version: Option<String>,
403        alert_config: Option<LLMAlertConfig>,
404    ) -> Result<(), TypeError> {
405        self.config
406            .update_config_args(space, name, version, alert_config)
407    }
408
409    /// Create a profile request from the profile
410    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
411        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
412            None
413        } else {
414            Some(self.config.version.clone())
415        };
416
417        Ok(ProfileRequest {
418            space: self.config.space.clone(),
419            profile: self.model_dump_json(),
420            drift_type: self.config.drift_type.clone(),
421            version_request: VersionRequest {
422                version,
423                version_type: VersionType::Minor,
424                pre_tag: None,
425                build_tag: None,
426            },
427        })
428    }
429}
430
impl LLMDriftProfile {
    /// Creates an LLMDriftProfile from a configuration and a list of metrics.
    ///
    /// Builds a single workflow in which each metric becomes one task,
    /// reusing one agent per provider.
    ///
    /// # Arguments
    /// * `config` - LLMDriftConfig - The configuration for the LLM
    /// * `metrics` - Vec<LLMMetric> - The metrics that will be used to evaluate the LLM
    ///
    /// # Returns
    /// * `Result<Self, ProfileError>` - The LLMDriftProfile
    ///
    /// # Errors
    /// * `ProfileError::MissingPromptError` - If a metric has no prompt attached
    pub fn from_metrics(
        mut config: LLMDriftConfig,
        metrics: Vec<LLMMetric>,
    ) -> Result<Self, ProfileError> {
        // Build a workflow from metrics
        let mut workflow = Workflow::new("llm_drift_workflow");
        let mut agents = HashMap::new();
        let mut metric_names = Vec::new();

        // Create agents. We don't want to duplicate, so we check if the agent already exists.
        // if it doesn't, we create it.
        for metric in &metrics {
            // get prompt (if providing a list of metrics, prompt must be present)
            let prompt = metric
                .prompt
                .as_ref()
                .ok_or_else(|| ProfileError::MissingPromptError(metric.name.clone()))?;

            let provider = Provider::from_string(&prompt.model_settings.provider)?;

            // One agent per provider. NOTE(review): the first metric's model
            // settings win for a given provider — confirm that later metrics
            // sharing a provider but using different settings are acceptable.
            let agent = match agents.entry(provider) {
                Entry::Occupied(entry) => entry.into_mut(),
                Entry::Vacant(entry) => {
                    let agent = Agent::from_model_settings(&prompt.model_settings)?;
                    workflow.add_agent(&agent);
                    entry.insert(agent)
                }
            };

            let task = Task::new(&agent.id, prompt.clone(), &metric.name, None, None);
            validate_prompt_parameters(prompt, &metric.name)?;
            workflow.add_task(task)?;
            metric_names.push(metric.name.clone());
        }

        config.alert_config.set_alert_conditions(&metrics);

        Ok(Self {
            config,
            metrics,
            scouter_version: scouter_version(),
            workflow,
            metric_names,
        })
    }

    /// Creates an LLMDriftProfile from a workflow and a list of metrics.
    /// This is useful when the workflow is already defined and you want to create a profile from it.
    /// This is also for more advanced use cases where the workflow may need to execute many dependent tasks.
    /// Because of this, there are a few requirements:
    /// 1. All beginning tasks in the workflow must have at least one bound parameter (e.g. "input" and/or "output").
    /// 2. All ending tasks in the workflow must have a response type of "Score".
    /// 3. The user must also supply a list of metrics that will be used to evaluate the output of the workflow.
    ///    - The metric names must correspond to the final task names in the workflow.
    ///
    /// # Arguments
    /// * `config` - LLMDriftConfig - The configuration for the LLM
    /// * `workflow` - Workflow - The workflow that will be used to evaluate the LLM
    /// * `metrics` - Vec<LLMMetric> - The metrics that will be used to evaluate the LLM
    ///
    /// # Returns
    /// * `Result<Self, ProfileError>` - The LLMDriftProfile
    pub fn from_workflow(
        mut config: LLMDriftConfig,
        workflow: Workflow,
        metrics: Vec<LLMMetric>,
    ) -> Result<Self, ProfileError> {
        validate_workflow(&workflow, &metrics)?;

        config.alert_config.set_alert_conditions(&metrics);

        let metric_names = metrics.iter().map(|m| m.name.clone()).collect();

        Ok(Self {
            config,
            metrics,
            scouter_version: scouter_version(),
            workflow,
            metric_names,
        })
    }

    /// Extracts a Workflow from a Python object.
    ///
    /// # Arguments
    /// * `py_workflow` - Python object that should expose a `__workflow__` attribute
    ///   containing the serialized workflow JSON
    ///
    /// # Returns
    /// * `Result<Workflow, ProfileError>` - Extracted workflow
    ///
    /// # Errors
    /// * `ProfileError::InvalidWorkflowType` - If object doesn't implement required interface
    #[instrument(skip_all)]
    fn extract_workflow(py_workflow: &Bound<'_, PyAny>) -> Result<Workflow, ProfileError> {
        debug!("Extracting workflow from Python object");

        if !py_workflow.hasattr("__workflow__")? {
            error!("Invalid workflow type provided. Expected object with __workflow__ method.");
            return Err(ProfileError::InvalidWorkflowType);
        }

        let workflow_string = py_workflow
            .getattr("__workflow__")
            .map_err(|e| {
                error!("Failed to call __workflow__ property: {}", e);
                ProfileError::InvalidWorkflowType
            })?
            .extract::<String>()
            .inspect_err(|e| {
                error!(
                    "Failed to extract workflow string from Python object: {}",
                    e
                );
            })?;

        serde_json::from_str(&workflow_string).map_err(|e| {
            error!("Failed to deserialize workflow: {}", e);
            ProfileError::InvalidWorkflowType
        })
    }

    /// Returns the threshold value of the metric with `metric_name`.
    ///
    /// # Errors
    /// * `ProfileError::MetricNotFound` - If no metric with that name exists
    pub fn get_metric_value(&self, metric_name: &str) -> Result<f64, ProfileError> {
        self.metrics
            .iter()
            .find(|m| m.name == metric_name)
            .map(|m| m.value)
            .ok_or_else(|| ProfileError::MetricNotFound(metric_name.to_string()))
    }
}
570
impl ProfileBaseArgs for LLMDriftProfile {
    /// Assembles the common profile arguments shared across drift profile types.
    fn get_base_args(&self) -> ProfileArgs {
        ProfileArgs {
            name: self.config.name.clone(),
            space: self.config.space.clone(),
            version: Some(self.config.version.clone()),
            schedule: self.config.alert_config.schedule.clone(),
            scouter_version: self.scouter_version.clone(),
            drift_type: self.config.drift_type.clone(),
        }
    }

    /// Serializes the full profile to a JSON value.
    /// Panics only if serialization fails, which would indicate a bug in the
    /// Serialize implementation rather than a recoverable condition.
    fn to_value(&self) -> Value {
        serde_json::to_value(self).unwrap()
    }
}
587
/// A collection of LLM metric records produced by drift evaluation.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMDriftMap {
    /// The individual metric records.
    #[pyo3(get)]
    pub records: Vec<LLMMetricRecord>,
}
594
#[pymethods]
impl LLMDriftMap {
    /// Create a new map from a list of metric records.
    #[new]
    pub fn new(records: Vec<LLMMetricRecord>) -> Self {
        Self { records }
    }

    /// Pretty string representation (serialized form).
    pub fn __str__(&self) -> String {
        ProfileFuncs::__str__(self)
    }
}
607
// Tests require the `mock` feature, which provides the LLM test server.
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {

    use super::*;
    use crate::AlertThreshold;
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};
    use potato_head::create_score_prompt;
    use potato_head::prompt::ResponseType;

    use potato_head::{LLMTestServer, Message, PromptContent};

    /// Builds a prompt with bound `${input}`/`${response}`/`${context}`
    /// parameters, suitable as a workflow's first (entry) task.
    pub fn create_parameterized_prompt() -> Prompt {
        let user_content =
            PromptContent::Str("What is ${input} + ${response} + ${context}?".to_string());
        let system_content = PromptContent::Str("You are a helpful assistant.".to_string());
        Prompt::new_rs(
            vec![Message::new_rs(user_content)],
            Some("gpt-4o"),
            Some("openai"),
            vec![Message::new_rs(system_content)],
            None,
            None,
            ResponseType::Null,
        )
        .unwrap()
    }

    // Covers construction defaults and in-place config updates.
    #[test]
    fn test_llm_drift_config() {
        // MISSING space/name resolve to the "__missing__" sentinel.
        let mut drift_config = LLMDriftConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            25,
            LLMAlertConfig::default(),
            None,
        )
        .unwrap();
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let test_slack_dispatch_config = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let new_alert_config = LLMAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(test_slack_dispatch_config.clone()),
            ..Default::default()
        };

        // Only the provided (Some) fields should change.
        drift_config
            .update_config_args(None, Some("test".to_string()), None, Some(new_alert_config))
            .unwrap();

        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(test_slack_dispatch_config)
        );
        assert_eq!(
            drift_config.alert_config.schedule,
            "0 0 * * * *".to_string()
        );
    }

    // Covers `from_metrics`: one task per metric, profile serializes round-trip.
    #[test]
    fn test_llm_drift_profile_metric() {
        // NOTE(review): this runtime is created but never entered; verify
        // whether LLMTestServer actually requires an ambient tokio runtime.
        let _runtime = tokio::runtime::Runtime::new().unwrap();
        let mut mock = LLMTestServer::new();
        mock.start_server().unwrap();
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let metric1 = LLMMetric::new(
            "metric1",
            5.0,
            AlertThreshold::Above,
            None,
            Some(prompt.clone()),
        )
        .unwrap();
        let metric2 = LLMMetric::new(
            "metric2",
            3.0,
            AlertThreshold::Below,
            Some(1.0),
            Some(prompt.clone()),
        )
        .unwrap();

        let llm_metrics = vec![metric1, metric2];

        let alert_config = LLMAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            LLMDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let profile = LLMDriftProfile::from_metrics(drift_config, llm_metrics).unwrap();
        // The serialized profile must be valid JSON.
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        mock.stop_server().unwrap();
    }

    // Covers `from_workflow`: a user-built 2-step workflow whose final tasks
    // match the supplied metric names.
    #[test]
    fn test_llm_drift_profile_workflow() {
        let mut mock = LLMTestServer::new();
        mock.start_server().unwrap();

        let mut workflow = Workflow::new("My eval Workflow");

        let initial_prompt = create_parameterized_prompt();
        let final_prompt1 = create_score_prompt(None);
        let final_prompt2 = create_score_prompt(None);

        let agent1 = Agent::new(Provider::OpenAI, None).unwrap();
        workflow.add_agent(&agent1);

        // First task with parameters
        workflow
            .add_task(Task::new(
                &agent1.id,
                initial_prompt.clone(),
                "task1",
                None,
                None,
            ))
            .unwrap();

        // Final tasks that depend on the first task
        workflow
            .add_task(Task::new(
                &agent1.id,
                final_prompt1.clone(),
                "task2",
                Some(vec!["task1".to_string()]),
                None,
            ))
            .unwrap();

        workflow
            .add_task(Task::new(
                &agent1.id,
                final_prompt2.clone(),
                "task3",
                Some(vec!["task1".to_string()]),
                None,
            ))
            .unwrap();

        // Metric names must match the final task names ("task2"/"task3").
        let metric1 = LLMMetric::new("task2", 3.0, AlertThreshold::Below, Some(1.0), None).unwrap();
        let metric2 = LLMMetric::new("task3", 4.0, AlertThreshold::Above, Some(2.0), None).unwrap();

        let llm_metrics = vec![metric1, metric2];

        let alert_config = LLMAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            LLMDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let profile = LLMDriftProfile::from_workflow(drift_config, workflow, llm_metrics).unwrap();

        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        let plan = profile.workflow.execution_plan().unwrap();

        // plan should have 2 steps
        assert_eq!(plan.len(), 2);
        // first step should have 1 task
        assert_eq!(plan.get(&1).unwrap().len(), 1);
        // last step should have 2 tasks
        assert_eq!(plan.get(&(plan.len() as i32)).unwrap().len(), 2);

        mock.stop_server().unwrap();
    }
}