scouter_sql/sql/
schema.rs

1use crate::sql::error::SqlError;
2use chrono::{DateTime, Utc};
3use potato_head::create_uuid7;
4use scouter_types::psi::DistributionData;
5use scouter_types::BoxedLLMDriftServerRecord;
6use scouter_types::LLMDriftServerRecord;
7use scouter_types::{
8    alert::Alert, get_utc_datetime, psi::FeatureBinProportionResult, BinnedMetric,
9    BinnedMetricStats, RecordType,
10};
11use scouter_types::{EntityType, LLMRecord};
12use semver::{BuildMetadata, Prerelease, Version};
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use sqlx::{postgres::PgRow, Error, FromRow, Row};
16use std::collections::BTreeMap;
17use std::collections::HashMap;
18
19#[derive(Serialize, Deserialize, Debug, Clone)]
20pub struct DriftRecord {
21    pub created_at: DateTime<Utc>,
22    pub name: String,
23    pub space: String,
24    pub version: String,
25    pub feature: String,
26    pub value: f64,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct SpcFeatureResult {
31    pub feature: String,
32    pub created_at: Vec<DateTime<Utc>>,
33    pub values: Vec<f64>,
34}
35
36impl<'r> FromRow<'r, PgRow> for SpcFeatureResult {
37    fn from_row(row: &'r PgRow) -> Result<Self, Error> {
38        Ok(SpcFeatureResult {
39            feature: row.try_get("feature")?,
40            created_at: row.try_get("created_at")?,
41            values: row.try_get("values")?,
42        })
43    }
44}
45
46#[derive(Debug)]
47pub struct FeatureDistributionWrapper(pub String, pub DistributionData);
48
49impl<'r> FromRow<'r, PgRow> for FeatureDistributionWrapper {
50    fn from_row(row: &'r PgRow) -> Result<Self, Error> {
51        let feature: String = row.try_get("feature")?;
52        let sample_size: i64 = row.try_get("sample_size")?;
53        let bins_json: serde_json::Value = row.try_get("bins")?;
54        let bins: BTreeMap<usize, f64> =
55            serde_json::from_value(bins_json).map_err(|e| Error::Decode(e.into()))?;
56
57        Ok(FeatureDistributionWrapper(
58            feature,
59            DistributionData {
60                sample_size: sample_size as u64,
61                bins,
62            },
63        ))
64    }
65}
66
67pub struct BinnedMetricWrapper(pub BinnedMetric);
68
69impl<'r> FromRow<'r, PgRow> for BinnedMetricWrapper {
70    fn from_row(row: &'r PgRow) -> Result<Self, Error> {
71        let stats_json: Vec<serde_json::Value> = row.try_get("stats")?;
72
73        let stats: Vec<BinnedMetricStats> = stats_json
74            .into_iter()
75            .map(|value| serde_json::from_value(value).unwrap_or_default())
76            .collect();
77
78        Ok(BinnedMetricWrapper(BinnedMetric {
79            metric: row.try_get("metric")?,
80            created_at: row.try_get("created_at")?,
81            stats,
82        }))
83    }
84}
85
86pub struct AlertWrapper(pub Alert);
87
88impl<'r> FromRow<'r, PgRow> for AlertWrapper {
89    fn from_row(row: &'r PgRow) -> Result<Self, Error> {
90        let alert_value: serde_json::Value = row.try_get("alert")?;
91        let alert: BTreeMap<String, String> =
92            serde_json::from_value(alert_value).unwrap_or_default();
93
94        Ok(AlertWrapper(Alert {
95            created_at: row.try_get("created_at")?,
96            name: row.try_get("name")?,
97            space: row.try_get("space")?,
98            version: row.try_get("version")?,
99            alert,
100            entity_name: row.try_get("entity_name")?,
101            id: row.try_get("id")?,
102            drift_type: row.try_get("drift_type")?,
103            active: row.try_get("active")?,
104        }))
105    }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct TaskRequest {
110    pub name: String,
111    pub space: String,
112    pub version: String,
113    pub profile: String,
114    pub drift_type: String,
115    pub previous_run: DateTime<Utc>,
116    pub schedule: String,
117    pub uid: String,
118}
119
120impl<'r> FromRow<'r, PgRow> for TaskRequest {
121    fn from_row(row: &'r PgRow) -> Result<Self, Error> {
122        let profile: serde_json::Value = row.try_get("profile")?;
123
124        Ok(TaskRequest {
125            name: row.try_get("name")?,
126            space: row.try_get("space")?,
127            version: row.try_get("version")?,
128            profile: profile.to_string(),
129            drift_type: row.try_get("drift_type")?,
130            previous_run: row.try_get("previous_run")?,
131            schedule: row.try_get("schedule")?,
132            uid: row.try_get("uid")?,
133        })
134    }
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct ObservabilityResult {
139    pub route_name: String,
140    pub created_at: Vec<DateTime<Utc>>,
141    pub p5: Vec<f64>,
142    pub p25: Vec<f64>,
143    pub p50: Vec<f64>,
144    pub p95: Vec<f64>,
145    pub p99: Vec<f64>,
146    pub total_request_count: Vec<i64>,
147    pub total_error_count: Vec<i64>,
148    pub error_latency: Vec<f64>,
149    pub status_counts: Vec<HashMap<String, i64>>,
150}
151
152impl<'r> FromRow<'r, PgRow> for ObservabilityResult {
153    fn from_row(row: &'r PgRow) -> Result<Self, Error> {
154        // decode status counts to vec of jsonb
155        let status_counts: Vec<serde_json::Value> = row.try_get("status_counts")?;
156
157        // convert vec of jsonb to vec of hashmaps
158        let status_counts: Vec<HashMap<String, i64>> = status_counts
159            .into_iter()
160            .map(|value| serde_json::from_value(value).unwrap_or_default())
161            .collect();
162
163        Ok(ObservabilityResult {
164            route_name: row.try_get("route_name")?,
165            created_at: row.try_get("created_at")?,
166            p5: row.try_get("p5")?,
167            p25: row.try_get("p25")?,
168            p50: row.try_get("p50")?,
169            p95: row.try_get("p95")?,
170            p99: row.try_get("p99")?,
171            total_request_count: row.try_get("total_request_count")?,
172            total_error_count: row.try_get("total_error_count")?,
173            error_latency: row.try_get("error_latency")?,
174            status_counts,
175        })
176    }
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct BinProportion {
181    pub bin_id: usize,
182    pub proportion: f64,
183}
184
185#[derive(Debug)]
186pub struct FeatureBinProportionResultWrapper(pub FeatureBinProportionResult);
187
188impl<'r> FromRow<'r, PgRow> for FeatureBinProportionResultWrapper {
189    fn from_row(row: &'r PgRow) -> Result<Self, Error> {
190        // Extract the bin_proportions as a Vec of tuples
191        let bin_proportions_json: Vec<serde_json::Value> = row.try_get("bin_proportions")?;
192
193        // Convert the Vec of tuples into a Vec of BinProportion structs
194        let bin_proportions: Vec<BTreeMap<usize, f64>> = bin_proportions_json
195            .into_iter()
196            .map(|json| serde_json::from_value(json).unwrap_or_default())
197            .collect();
198
199        let overall_proportions_json: serde_json::Value = row.try_get("overall_proportions")?;
200        let overall_proportions: BTreeMap<usize, f64> =
201            serde_json::from_value(overall_proportions_json).unwrap_or_default();
202
203        Ok(FeatureBinProportionResultWrapper(
204            FeatureBinProportionResult {
205                feature: row.try_get("feature")?,
206                created_at: row.try_get("created_at")?,
207                bin_proportions,
208                overall_proportions,
209            },
210        ))
211    }
212}
213#[derive(Debug, Clone, FromRow)]
214pub struct Entity {
215    pub space: String,
216    pub name: String,
217    pub version: String,
218    pub begin_timestamp: DateTime<Utc>,
219    pub end_timestamp: DateTime<Utc>,
220}
221
222impl Entity {
223    pub fn get_write_path(&self, record_type: &RecordType) -> String {
224        format!(
225            "{}/{}/{}/{}",
226            self.space, self.name, self.version, record_type
227        )
228    }
229}
230
231#[derive(Debug, Serialize, Deserialize, Clone)]
232pub struct User {
233    pub id: Option<i32>,
234    pub created_at: DateTime<Utc>,
235    pub active: bool,
236    pub username: String,
237    pub password_hash: String,
238    pub hashed_recovery_codes: Vec<String>,
239    pub permissions: Vec<String>,
240    pub group_permissions: Vec<String>,
241    pub role: String,
242    pub favorite_spaces: Vec<String>,
243    pub refresh_token: Option<String>,
244    pub email: String,
245    pub updated_at: DateTime<Utc>,
246}
247
248impl User {
249    #[allow(clippy::too_many_arguments)]
250    pub fn new(
251        username: String,
252        password_hash: String,
253        email: String,
254        hashed_recovery_codes: Vec<String>,
255        permissions: Option<Vec<String>>,
256        group_permissions: Option<Vec<String>>,
257        role: Option<String>,
258        favorite_spaces: Option<Vec<String>>,
259    ) -> Self {
260        let created_at = get_utc_datetime();
261
262        User {
263            id: None,
264            created_at,
265            active: true,
266            username,
267            password_hash,
268            hashed_recovery_codes,
269            permissions: permissions.unwrap_or(vec!["read:all".to_string()]),
270            group_permissions: group_permissions.unwrap_or(vec!["user".to_string()]),
271            favorite_spaces: favorite_spaces.unwrap_or_default(),
272            role: role.unwrap_or("user".to_string()),
273            refresh_token: None,
274            email,
275            updated_at: created_at,
276        }
277    }
278}
279
280impl FromRow<'_, PgRow> for User {
281    fn from_row(row: &PgRow) -> Result<Self, sqlx::Error> {
282        let id = row.try_get("id")?;
283        let created_at = row.try_get("created_at")?;
284        let updated_at = row.try_get("updated_at")?;
285        let active = row.try_get("active")?;
286        let username = row.try_get("username")?;
287        let password_hash = row.try_get("password_hash")?;
288        let email = row.try_get("email")?;
289        let role = row.try_get("role")?;
290        let refresh_token = row.try_get("refresh_token")?;
291
292        let group_permissions: Vec<String> =
293            serde_json::from_value(row.try_get("group_permissions")?).unwrap_or_default();
294
295        let permissions: Vec<String> =
296            serde_json::from_value(row.try_get("permissions")?).unwrap_or_default();
297
298        let hashed_recovery_codes: Vec<String> =
299            serde_json::from_value(row.try_get("hashed_recovery_codes")?).unwrap_or_default();
300
301        let favorite_spaces: Vec<String> =
302            serde_json::from_value(row.try_get("favorite_spaces")?).unwrap_or_default();
303
304        Ok(User {
305            id,
306            created_at,
307            updated_at,
308            active,
309            username,
310            password_hash,
311            email,
312            role,
313            refresh_token,
314            hashed_recovery_codes,
315            permissions,
316            group_permissions,
317            favorite_spaces,
318        })
319    }
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
323pub struct UpdateAlertResult {
324    pub id: i32,
325    pub active: bool,
326    pub updated_at: DateTime<Utc>,
327}
328
329#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
330pub struct LLMDriftServerSQLRecord {
331    pub uid: String,
332
333    pub created_at: chrono::DateTime<Utc>,
334
335    pub space: String,
336
337    pub name: String,
338
339    pub version: String,
340
341    pub prompt: Option<Value>,
342
343    pub context: Value,
344
345    pub score: Value,
346
347    pub status: String,
348
349    pub id: i64,
350
351    pub updated_at: Option<DateTime<Utc>>,
352
353    pub processing_started_at: Option<DateTime<Utc>>,
354
355    pub processing_ended_at: Option<DateTime<Utc>>,
356
357    pub processing_duration: Option<i32>,
358}
359
360impl LLMDriftServerSQLRecord {
361    /// Method use when server receives a record from the client
362    pub fn from_server_record(record: &LLMDriftServerRecord) -> Self {
363        LLMDriftServerSQLRecord {
364            created_at: record.created_at,
365            space: record.space.clone(),
366            name: record.name.clone(),
367            version: record.version.clone(),
368            prompt: record.prompt.clone(),
369            context: record.context.clone(),
370            score: record.score.clone(),
371            status: record.status.to_string(),
372            id: 0,               // This is a placeholder, as the ID will be set by the database
373            uid: create_uuid7(), // This is also a placeholder, as the UID will be set by the database
374            updated_at: None,
375            processing_started_at: None,
376            processing_ended_at: None,
377            processing_duration: None, // This will be set when the record is processed
378        }
379    }
380}
381
382impl From<LLMDriftServerSQLRecord> for LLMDriftServerRecord {
383    fn from(sql_record: LLMDriftServerSQLRecord) -> Self {
384        Self {
385            id: sql_record.id,
386            created_at: sql_record.created_at,
387            space: sql_record.space,
388            name: sql_record.name,
389            version: sql_record.version,
390            context: sql_record.context,
391            score: sql_record.score,
392            prompt: sql_record.prompt,
393            status: sql_record.status.parse().unwrap_or_default(), // Handle parsing appropriately
394            processing_started_at: sql_record.processing_started_at,
395            processing_ended_at: sql_record.processing_ended_at,
396            processing_duration: sql_record.processing_duration,
397            updated_at: sql_record.updated_at,
398            uid: sql_record.uid,
399        }
400    }
401}
402
403/// Converts a `PgRow` to a `BoxedLLMDriftServerRecord`
404/// Conversion is done by first converting the row to an `LLMDriftServerSQLRecord`
405/// and then converting that to an `LLMDriftServerRecord`.
406pub fn llm_drift_record_from_row(row: &PgRow) -> Result<BoxedLLMDriftServerRecord, SqlError> {
407    let sql_record = LLMDriftServerSQLRecord::from_row(row)?;
408    let record = LLMDriftServerRecord::from(sql_record);
409
410    Ok(BoxedLLMDriftServerRecord {
411        record: Box::new(record),
412    })
413}
414
415pub fn llm_drift_metric_from_row(row: &PgRow) -> Result<BoxedLLMDriftServerRecord, SqlError> {
416    let sql_record = LLMDriftServerSQLRecord::from_row(row)?;
417    let record = LLMDriftServerRecord::from(sql_record);
418
419    Ok(BoxedLLMDriftServerRecord {
420        record: Box::new(record),
421    })
422}
423
424pub struct LLMRecordWrapper(pub LLMRecord);
425
426impl<'r> FromRow<'r, PgRow> for LLMRecordWrapper {
427    fn from_row(row: &'r PgRow) -> Result<Self, sqlx::Error> {
428        let llm_record = LLMRecord {
429            uid: row.try_get("uid")?,
430            created_at: row.try_get("created_at")?,
431            space: row.try_get("space")?,
432            name: row.try_get("name")?,
433            version: row.try_get("version")?,
434            context: row.try_get("context")?,
435            prompt: row.try_get("prompt")?,
436            score: row.try_get("score")?,
437            entity_type: EntityType::LLM,
438        };
439        Ok(Self(llm_record))
440    }
441}
442
443#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
444pub struct VersionResult {
445    pub created_at: DateTime<Utc>,
446    pub name: String,
447    pub space: String,
448    pub major: i32,
449    pub minor: i32,
450    pub patch: i32,
451    pub pre_tag: Option<String>,
452    pub build_tag: Option<String>,
453}
454
455impl VersionResult {
456    pub fn to_version(&self) -> Result<Version, SqlError> {
457        let mut version = Version::new(self.major as u64, self.minor as u64, self.patch as u64);
458
459        if self.pre_tag.is_some() {
460            version.pre = Prerelease::new(self.pre_tag.as_ref().unwrap())?;
461        }
462
463        if self.build_tag.is_some() {
464            version.build = BuildMetadata::new(self.build_tag.as_ref().unwrap())?;
465        }
466
467        Ok(version)
468    }
469}