Skip to main content

scouter_dataframe/parquet/genai/
task.rs

1use crate::error::DataFrameError;
2use crate::parquet::traits::ParquetFrame;
3use crate::parquet::types::BinnedTableName;
4use crate::sql::helper::get_binned_genai_task_values_query;
5use crate::storage::ObjectStore;
6use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
7use arrow_array::array::{
8    BooleanArray, DictionaryArray, Float64Array, Int32Array, StringArray, TimestampNanosecondArray,
9    UInt32Array, UInt8Array,
10};
11use arrow_array::RecordBatch;
12use async_trait::async_trait;
13use chrono::{DateTime, Utc};
14use datafusion::dataframe::DataFrame;
15use datafusion::prelude::SessionContext;
16use scouter_settings::ObjectStorageSettings;
17use scouter_types::{GenAIEvalTaskResult, ServerRecords, StorageType, ToDriftRecords};
18use std::sync::Arc;
19
20pub struct GenAITaskDataFrame {
21    schema: Arc<Schema>,
22    pub object_store: ObjectStore,
23}
24
25#[async_trait]
26impl ParquetFrame for GenAITaskDataFrame {
27    fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
28        GenAITaskDataFrame::new(storage_settings)
29    }
30
31    async fn get_dataframe(&self, records: ServerRecords) -> Result<DataFrame, DataFrameError> {
32        let records = records.to_genai_task_records()?;
33        let batch = self.build_batch(records)?;
34
35        let ctx = self.object_store.get_session()?;
36
37        let df = ctx.read_batches(vec![batch])?;
38
39        Ok(df)
40    }
41
42    fn storage_root(&self) -> String {
43        self.object_store.storage_settings.canonicalized_path()
44    }
45
46    fn storage_type(&self) -> StorageType {
47        self.object_store.storage_settings.storage_type.clone()
48    }
49
50    fn get_session_context(&self) -> Result<SessionContext, DataFrameError> {
51        Ok(self.object_store.get_session()?)
52    }
53
54    fn get_binned_sql(
55        &self,
56        bin: &f64,
57        start_time: &DateTime<Utc>,
58        end_time: &DateTime<Utc>,
59        entity_id: &i32,
60    ) -> String {
61        get_binned_genai_task_values_query(bin, start_time, end_time, entity_id)
62    }
63
64    fn table_name(&self) -> String {
65        BinnedTableName::GenAITask.to_string()
66    }
67}
68
69impl GenAITaskDataFrame {
70    pub fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
71        let schema = Arc::new(Schema::new(vec![
72            Field::new("entity_id", DataType::Int32, false),
73            Field::new("entity_uid", DataType::Utf8, false),
74            Field::new("record_uid", DataType::Utf8, false),
75            Field::new(
76                "task_id",
77                DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
78                false,
79            ),
80            Field::new(
81                "task_type",
82                DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
83                false,
84            ),
85            Field::new(
86                "operator",
87                DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
88                false,
89            ),
90            Field::new(
91                "created_at",
92                DataType::Timestamp(TimeUnit::Nanosecond, None),
93                false,
94            ),
95            Field::new(
96                "start_time",
97                DataType::Timestamp(TimeUnit::Nanosecond, None),
98                false,
99            ),
100            Field::new(
101                "end_time",
102                DataType::Timestamp(TimeUnit::Nanosecond, None),
103                false,
104            ),
105            Field::new("stage", DataType::Int32, false),
106            Field::new("passed", DataType::Boolean, false),
107            Field::new("condition", DataType::Boolean, false),
108            Field::new("value", DataType::Float64, false),
109            Field::new("assertion", DataType::Utf8, true),
110            Field::new("expected", DataType::Utf8, false),
111            Field::new("actual", DataType::Utf8, false),
112            Field::new("message", DataType::Utf8, false),
113        ]));
114
115        let object_store = ObjectStore::new(storage_settings)?;
116
117        Ok(GenAITaskDataFrame {
118            schema,
119            object_store,
120        })
121    }
122
123    /// Builds an Arrow RecordBatch from GenAIEvalTaskResults
124    /// # Arguments
125    /// * `records` - A vector of references to GenAIEvalTaskResults
126    /// # Returns
127    /// * A RecordBatch containing the data from the records
128    /// # Errors
129    /// * DataFrameError if there is an issue creating the RecordBatch
130    fn build_batch(
131        &self,
132        records: Vec<GenAIEvalTaskResult>,
133    ) -> Result<RecordBatch, DataFrameError> {
134        // Entity and record identifiers
135        let entity_id_array = Int32Array::from_iter_values(records.iter().map(|r| r.entity_id));
136        let entity_uid_array =
137            StringArray::from_iter_values(records.iter().map(|r| r.entity_uid.as_str()));
138        let record_uid_array =
139            StringArray::from_iter_values(records.iter().map(|r| r.record_uid.as_str()));
140
141        let task_id_values =
142            StringArray::from_iter_values(records.iter().map(|r| r.task_id.as_str()));
143        let task_id_keys = UInt32Array::from_iter_values(0..records.len() as u32);
144        let task_id_array = DictionaryArray::new(task_id_keys, Arc::new(task_id_values));
145
146        let task_type_values =
147            StringArray::from_iter_values(records.iter().map(|r| r.task_type.as_str()));
148        let task_type_keys = UInt8Array::from_iter_values(0..records.len() as u8);
149        let task_type_array = DictionaryArray::new(task_type_keys, Arc::new(task_type_values));
150
151        let operator_values =
152            StringArray::from_iter_values(records.iter().map(|r| r.operator.as_str()));
153        let operator_keys = UInt8Array::from_iter_values(0..records.len() as u8);
154        let operator_array = DictionaryArray::new(operator_keys, Arc::new(operator_values));
155
156        let created_at_array = TimestampNanosecondArray::from_iter_values(
157            records
158                .iter()
159                .map(|r| r.created_at.timestamp_nanos_opt().unwrap_or_default()),
160        );
161        let start_time_array = TimestampNanosecondArray::from_iter_values(
162            records
163                .iter()
164                .map(|r| r.start_time.timestamp_nanos_opt().unwrap_or_default()),
165        );
166        let end_time_array = TimestampNanosecondArray::from_iter_values(
167            records
168                .iter()
169                .map(|r| r.end_time.timestamp_nanos_opt().unwrap_or_default()),
170        );
171
172        let stage_array = Int32Array::from_iter_values(records.iter().map(|r| r.stage));
173        let passed_array = BooleanArray::from_iter(records.iter().map(|r| r.passed));
174        let condition_array = BooleanArray::from_iter(records.iter().map(|r| r.condition));
175        let value_array = Float64Array::from_iter_values(records.iter().map(|r| r.value));
176
177        let assertion_array = StringArray::from_iter_values(records.iter().map(|r| r.assertion()));
178        let expected_array =
179            StringArray::from_iter_values(records.iter().map(|r| r.expected.to_string()));
180        let actual_array =
181            StringArray::from_iter_values(records.iter().map(|r| r.actual.to_string()));
182        let message_array =
183            StringArray::from_iter_values(records.iter().map(|r| r.message.as_str()));
184
185        let batch = RecordBatch::try_new(
186            self.schema.clone(),
187            vec![
188                Arc::new(entity_id_array),
189                Arc::new(entity_uid_array),
190                Arc::new(record_uid_array),
191                Arc::new(task_id_array),
192                Arc::new(task_type_array),
193                Arc::new(operator_array),
194                Arc::new(created_at_array),
195                Arc::new(start_time_array),
196                Arc::new(end_time_array),
197                Arc::new(stage_array),
198                Arc::new(passed_array),
199                Arc::new(condition_array),
200                Arc::new(value_array),
201                Arc::new(assertion_array),
202                Arc::new(expected_array),
203                Arc::new(actual_array),
204                Arc::new(message_array),
205            ],
206        )
207        .map_err(DataFrameError::ArrowError)?;
208
209        Ok(batch)
210    }
211}