//! Parquet frame implementation for GenAI evaluation task results.
//!
//! File: scouter_dataframe/parquet/genai/task.rs
1use crate::error::DataFrameError;
2use crate::parquet::traits::ParquetFrame;
3use crate::parquet::types::BinnedTableName;
4use crate::sql::helper::get_binned_genai_task_values_query;
5use crate::storage::ObjectStore;
6use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
7use arrow_array::array::{Float64Array, Int32Array, StringArray, TimestampNanosecondArray};
8use arrow_array::{BooleanArray, RecordBatch};
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use datafusion::dataframe::DataFrame;
12use datafusion::prelude::SessionContext;
13use scouter_settings::ObjectStorageSettings;
14use scouter_types::{GenAIEvalTaskResult, ServerRecords, StorageType, ToDriftRecords};
15use std::sync::Arc;
16
/// Parquet-backed frame for GenAI evaluation task results.
///
/// Pairs the fixed Arrow schema used to encode `GenAIEvalTaskResult` rows
/// (built in `GenAITaskDataFrame::new`) with the object store that backs
/// reads and writes of the parquet data.
pub struct GenAITaskDataFrame {
    // Arrow schema shared by every RecordBatch produced by `build_batch`.
    schema: Arc<Schema>,
    // Backing object store; also supplies the DataFusion session context.
    pub object_store: ObjectStore,
}
21
22#[async_trait]
23impl ParquetFrame for GenAITaskDataFrame {
24    fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
25        GenAITaskDataFrame::new(storage_settings)
26    }
27
28    async fn get_dataframe(&self, records: ServerRecords) -> Result<DataFrame, DataFrameError> {
29        let records = records.to_genai_task_records()?;
30        let batch = self.build_batch(records)?;
31
32        let ctx = self.object_store.get_session()?;
33
34        let df = ctx.read_batches(vec![batch])?;
35
36        Ok(df)
37    }
38
39    fn storage_root(&self) -> String {
40        self.object_store.storage_settings.canonicalized_path()
41    }
42
43    fn storage_type(&self) -> StorageType {
44        self.object_store.storage_settings.storage_type.clone()
45    }
46
47    fn get_session_context(&self) -> Result<SessionContext, DataFrameError> {
48        Ok(self.object_store.get_session()?)
49    }
50
51    fn get_binned_sql(
52        &self,
53        bin: &f64,
54        start_time: &DateTime<Utc>,
55        end_time: &DateTime<Utc>,
56        entity_id: &i32,
57    ) -> String {
58        get_binned_genai_task_values_query(bin, start_time, end_time, entity_id)
59    }
60
61    fn table_name(&self) -> String {
62        BinnedTableName::GenAITask.to_string()
63    }
64}
65
66impl GenAITaskDataFrame {
67    pub fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
68        let schema = Arc::new(Schema::new(vec![
69            Field::new(
70                "created_at",
71                DataType::Timestamp(TimeUnit::Nanosecond, None),
72                false,
73            ),
74            Field::new(
75                "start_time",
76                DataType::Timestamp(TimeUnit::Nanosecond, None),
77                false,
78            ),
79            Field::new(
80                "end_time",
81                DataType::Timestamp(TimeUnit::Nanosecond, None),
82                false,
83            ),
84            Field::new("record_uid", DataType::Utf8, false),
85            Field::new("entity_id", DataType::Int32, false),
86            Field::new("task_id", DataType::Utf8, false),
87            Field::new("task_type", DataType::Utf8, false),
88            Field::new("passed", DataType::Boolean, false),
89            Field::new("value", DataType::Float64, false),
90            Field::new("field_path", DataType::Utf8, true),
91            Field::new("operator", DataType::Utf8, false),
92            Field::new("expected", DataType::Utf8, false),
93            Field::new("actual", DataType::Utf8, false),
94            Field::new("message", DataType::Utf8, false),
95            Field::new("condition", DataType::Boolean, false),
96            Field::new("stage", DataType::Int32, false),
97        ]));
98
99        let object_store = ObjectStore::new(storage_settings)?;
100
101        Ok(GenAITaskDataFrame {
102            schema,
103            object_store,
104        })
105    }
106
107    /// Builds an Arrow RecordBatch from GenAIEvalTaskResults
108    /// # Arguments
109    /// * `records` - A vector of references to GenAIEvalTaskResults
110    /// # Returns
111    /// * A RecordBatch containing the data from the records
112    /// # Errors
113    /// * DataFrameError if there is an issue creating the RecordBatch
114    fn build_batch(
115        &self,
116        records: Vec<GenAIEvalTaskResult>,
117    ) -> Result<RecordBatch, DataFrameError> {
118        let created_at_array = TimestampNanosecondArray::from_iter_values(
119            records
120                .iter()
121                .map(|r| r.created_at.timestamp_nanos_opt().unwrap_or_default()),
122        );
123
124        let start_time_array = TimestampNanosecondArray::from_iter_values(
125            records
126                .iter()
127                .map(|r| r.start_time.timestamp_nanos_opt().unwrap_or_default()),
128        );
129
130        let end_time_array = TimestampNanosecondArray::from_iter_values(
131            records
132                .iter()
133                .map(|r| r.end_time.timestamp_nanos_opt().unwrap_or_default()),
134        );
135
136        let uid_array =
137            StringArray::from_iter_values(records.iter().map(|r| r.record_uid.as_str()));
138
139        let entity_id_array = Int32Array::from_iter_values(records.iter().map(|r| r.entity_id));
140
141        let task_id_array =
142            StringArray::from_iter_values(records.iter().map(|r| r.task_id.as_str()));
143
144        let task_type_array =
145            StringArray::from_iter_values(records.iter().map(|r| r.task_type.as_str()));
146
147        let passed_array = BooleanArray::from_iter(records.iter().map(|r| r.passed));
148
149        let value_array = Float64Array::from_iter_values(records.iter().map(|r| r.value));
150
151        let field_path_array =
152            StringArray::from_iter(records.iter().map(|r| r.field_path.as_deref()));
153
154        let operator_array =
155            StringArray::from_iter_values(records.iter().map(|r| r.operator.as_str()));
156
157        let expected_array =
158            StringArray::from_iter_values(records.iter().map(|r| r.expected.to_string()));
159
160        let actual_array =
161            StringArray::from_iter_values(records.iter().map(|r| r.actual.to_string()));
162
163        let message_array =
164            StringArray::from_iter_values(records.iter().map(|r| r.message.as_str()));
165
166        let condition_array = BooleanArray::from_iter(records.iter().map(|r| r.condition));
167
168        let stage_array = Int32Array::from_iter_values(records.iter().map(|r| r.stage));
169
170        let batch = RecordBatch::try_new(
171            self.schema.clone(),
172            vec![
173                Arc::new(created_at_array),
174                Arc::new(start_time_array),
175                Arc::new(end_time_array),
176                Arc::new(uid_array),
177                Arc::new(entity_id_array),
178                Arc::new(task_id_array),
179                Arc::new(task_type_array),
180                Arc::new(passed_array),
181                Arc::new(value_array),
182                Arc::new(field_path_array),
183                Arc::new(operator_array),
184                Arc::new(expected_array),
185                Arc::new(actual_array),
186                Arc::new(message_array),
187                Arc::new(condition_array),
188                Arc::new(stage_array),
189            ],
190        )
191        .map_err(DataFrameError::ArrowError)?;
192
193        Ok(batch)
194    }
195}