//! scouter_dataframe/parquet/genai/workflow.rs
//! Parquet frame support for GenAI evaluation workflow records.

1use crate::error::DataFrameError;
2use crate::parquet::traits::ParquetFrame;
3use crate::parquet::types::BinnedTableName;
4use crate::sql::helper::get_binned_genai_workflow_values_query;
5use crate::storage::ObjectStore;
6use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
7use arrow_array::array::{Float64Array, Int32Array, StringArray, TimestampNanosecondArray};
8use arrow_array::RecordBatch;
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use datafusion::dataframe::DataFrame;
12use datafusion::prelude::SessionContext;
13use scouter_settings::ObjectStorageSettings;
14use scouter_types::{GenAIEvalWorkflowResult, ServerRecords, StorageType, ToDriftRecords};
15use std::sync::Arc;
16
/// Parquet-backed frame for GenAI evaluation workflow results.
///
/// Holds the fixed Arrow schema used to encode workflow records and the
/// object store that supplies storage settings and the DataFusion session.
pub struct GenAIWorkflowDataFrame {
    // Arrow schema shared by every RecordBatch built via build_batch.
    schema: Arc<Schema>,
    // Backing object store; also the source of the session context.
    pub object_store: ObjectStore,
}
21
22#[async_trait]
23impl ParquetFrame for GenAIWorkflowDataFrame {
24    fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
25        GenAIWorkflowDataFrame::new(storage_settings)
26    }
27
28    async fn get_dataframe(&self, records: ServerRecords) -> Result<DataFrame, DataFrameError> {
29        // Assuming ServerRecords has a method to_genai_workflow_records()
30        let records = records.to_genai_workflow_records()?;
31        let batch = self.build_batch(records)?;
32
33        let ctx = self.object_store.get_session()?;
34        let df = ctx.read_batches(vec![batch])?;
35
36        Ok(df)
37    }
38
39    fn storage_root(&self) -> String {
40        self.object_store.storage_settings.canonicalized_path()
41    }
42
43    fn storage_type(&self) -> StorageType {
44        self.object_store.storage_settings.storage_type.clone()
45    }
46
47    fn get_session_context(&self) -> Result<SessionContext, DataFrameError> {
48        Ok(self.object_store.get_session()?)
49    }
50
51    fn get_binned_sql(
52        &self,
53        bin: &f64,
54        start_time: &DateTime<Utc>,
55        end_time: &DateTime<Utc>,
56        entity_id: &i32,
57    ) -> String {
58        // You'll need to implement this helper for workflows specifically
59        get_binned_genai_workflow_values_query(bin, start_time, end_time, entity_id)
60    }
61
62    fn table_name(&self) -> String {
63        // Ensure this variant exists in your BinnedTableName enum
64        BinnedTableName::GenAIWorkflow.to_string()
65    }
66}
67
68impl GenAIWorkflowDataFrame {
69    pub fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
70        let schema = Arc::new(Schema::new(vec![
71            Field::new("id", DataType::Int64, false),
72            Field::new(
73                "created_at",
74                DataType::Timestamp(TimeUnit::Nanosecond, None),
75                false,
76            ),
77            Field::new("record_uid", DataType::Utf8, false),
78            Field::new("entity_id", DataType::Int32, false),
79            Field::new("total_tasks", DataType::Int32, false),
80            Field::new("passed_tasks", DataType::Int32, false),
81            Field::new("failed_tasks", DataType::Int32, false),
82            Field::new("pass_rate", DataType::Float64, false),
83            Field::new("duration_ms", DataType::Int64, false),
84            Field::new("metric", DataType::Utf8, false),
85            Field::new("execution_plan", DataType::Utf8, false),
86        ]));
87
88        let object_store = ObjectStore::new(storage_settings)?;
89
90        Ok(GenAIWorkflowDataFrame {
91            schema,
92            object_store,
93        })
94    }
95
96    fn build_batch(
97        &self,
98        records: Vec<GenAIEvalWorkflowResult>,
99    ) -> Result<RecordBatch, DataFrameError> {
100        // id
101        let id_array = arrow_array::Int64Array::from_iter_values(records.iter().map(|r| r.id));
102        // created_at
103        let created_at_array = TimestampNanosecondArray::from_iter_values(
104            records
105                .iter()
106                .map(|r| r.created_at.timestamp_nanos_opt().unwrap_or_default()),
107        );
108
109        // 2. record_uid
110        let uid_array =
111            StringArray::from_iter_values(records.iter().map(|r| r.record_uid.as_str()));
112
113        // 3. entity_id
114        let entity_id_array = Int32Array::from_iter_values(records.iter().map(|r| r.entity_id));
115
116        // 4. total_tasks
117        let total_tasks_array = Int32Array::from_iter_values(records.iter().map(|r| r.total_tasks));
118
119        // 5. passed_tasks
120        let passed_tasks_array =
121            Int32Array::from_iter_values(records.iter().map(|r| r.passed_tasks));
122
123        // 6. failed_tasks
124        let failed_tasks_array =
125            Int32Array::from_iter_values(records.iter().map(|r| r.failed_tasks));
126
127        // 7. pass_rate
128        let pass_rate_array = Float64Array::from_iter_values(records.iter().map(|r| r.pass_rate));
129
130        // 8. duration_ms
131        let duration_ms_array =
132            arrow_array::Int64Array::from_iter_values(records.iter().map(|r| r.duration_ms));
133
134        let metric_array = StringArray::from_iter_values(records.iter().map(|_| "workflow"));
135
136        let execution_plan_array = StringArray::from_iter_values(
137            records
138                .iter()
139                .map(|r| serde_json::to_string(&r.execution_plan).unwrap_or_default()),
140        );
141
142        let batch = RecordBatch::try_new(
143            self.schema.clone(),
144            vec![
145                Arc::new(id_array),
146                Arc::new(created_at_array),
147                Arc::new(uid_array),
148                Arc::new(entity_id_array),
149                Arc::new(total_tasks_array),
150                Arc::new(passed_tasks_array),
151                Arc::new(failed_tasks_array),
152                Arc::new(pass_rate_array),
153                Arc::new(duration_ms_array),
154                Arc::new(metric_array),
155                Arc::new(execution_plan_array),
156            ],
157        )
158        .map_err(DataFrameError::ArrowError)?;
159
160        Ok(batch)
161    }
162}