scouter_client/http/
client.rs

1#![allow(clippy::useless_conversion)]
2use crate::error::ClientError;
3use pyo3::{prelude::*, IntoPyObjectExt};
4use scouter_settings::http::HTTPConfig;
5use scouter_types::contracts::{
6    DriftAlertRequest, DriftRequest, GetProfileRequest, ProfileRequest, ProfileStatusRequest,
7};
8use scouter_types::http::{RequestType, Routes};
9use scouter_types::RegisteredProfileResponse;
10
11use crate::http::HTTPClient;
12use scouter_types::{
13    alert::Alert, psi::BinnedPsiFeatureMetrics, spc::SpcDriftFeatures, BinnedMetrics, DriftProfile,
14    DriftType, ProfileFuncs,
15};
16use std::path::PathBuf;
17use tracing::{debug, error};
18
/// Chunk size used for downloads: `5 * 1024 * 1024` = 5_242_880 (5 MiB when counted in bytes).
pub const DOWNLOAD_CHUNK_SIZE: usize = 5 * 1024 * 1024;
20
21#[derive(Debug, Clone)]
22pub struct ScouterClient {
23    client: HTTPClient,
24}
25
26impl ScouterClient {
27    pub fn new(config: Option<HTTPConfig>) -> Result<Self, ClientError> {
28        let client = HTTPClient::new(config.unwrap_or_default())?;
29
30        Ok(ScouterClient { client })
31    }
32
33    /// Insert a profile into the scouter server
34    pub fn insert_profile(
35        &self,
36        request: &ProfileRequest,
37    ) -> Result<RegisteredProfileResponse, ClientError> {
38        let response = self.client.request(
39            Routes::Profile,
40            RequestType::Post,
41            Some(serde_json::to_value(request).unwrap()),
42            None,
43            None,
44        )?;
45
46        if response.status().is_success() {
47            let body = response.bytes()?;
48            let profile_response: RegisteredProfileResponse = serde_json::from_slice(&body)?;
49
50            debug!("Profile inserted successfully: {:?}", profile_response);
51            Ok(profile_response)
52        } else {
53            Err(ClientError::InsertProfileError)
54        }
55    }
56
57    pub fn update_profile_status(
58        &self,
59        request: &ProfileStatusRequest,
60    ) -> Result<bool, ClientError> {
61        let response = self.client.request(
62            Routes::ProfileStatus,
63            RequestType::Put,
64            Some(serde_json::to_value(request).unwrap()),
65            None,
66            None,
67        )?;
68
69        if response.status().is_success() {
70            Ok(true)
71        } else {
72            Err(ClientError::UpdateProfileError)
73        }
74    }
75
76    pub fn get_alerts(&self, request: &DriftAlertRequest) -> Result<Vec<Alert>, ClientError> {
77        debug!("Getting alerts for: {:?}", request);
78
79        let query_string = serde_qs::to_string(request)?;
80
81        let response = self.client.request(
82            Routes::Alerts,
83            RequestType::Get,
84            None,
85            Some(query_string),
86            None,
87        )?;
88
89        // Check response status
90        if !response.status().is_success() {
91            return Err(ClientError::GetDriftAlertError);
92        }
93
94        // Parse response body
95        let body: serde_json::Value = response.json()?;
96
97        // Extract alerts from response
98        let alerts = body
99            .get("alerts")
100            .map(|alerts| {
101                serde_json::from_value::<Vec<Alert>>(alerts.clone()).inspect_err(|e| {
102                    error!(
103                        "Failed to parse drift alerts {:?}. Error: {:?}",
104                        &request, e
105                    );
106                })
107            })
108            .unwrap_or_else(|| {
109                error!("No alerts found in response");
110                Ok(Vec::new())
111            })?;
112
113        Ok(alerts)
114    }
115
116    pub fn get_drift_profile(
117        &self,
118        request: GetProfileRequest,
119    ) -> Result<DriftProfile, ClientError> {
120        let query_string = serde_qs::to_string(&request)?;
121
122        let response = self.client.request(
123            Routes::Profile,
124            RequestType::Get,
125            None,
126            Some(query_string),
127            None,
128        )?;
129
130        // Early return for error status codes
131        if !response.status().is_success() {
132            error!("Failed to get profile. Status: {:?}", response.status());
133            return Err(ClientError::GetDriftProfileError);
134        }
135
136        // Get response body
137        let body = response.bytes()?;
138
139        // Parse JSON response
140        let profile: DriftProfile = serde_json::from_slice(&body)?;
141
142        Ok(profile)
143    }
144
145    /// Check if the scouter server is healthy
146    pub fn check_service_health(&self) -> Result<bool, ClientError> {
147        let response = self
148            .client
149            .request(Routes::Healthcheck, RequestType::Get, None, None, None)
150            .inspect_err(|e| {
151                error!("Failed to check scouter health {}", e);
152            })?;
153
154        if response.status() == 200 {
155            Ok(true)
156        } else {
157            Ok(false)
158        }
159    }
160}
161
162#[pyclass(name = "ScouterClient")]
163pub struct PyScouterClient {
164    client: ScouterClient,
165}
166#[pymethods]
167impl PyScouterClient {
168    #[new]
169    #[pyo3(signature = (config=None))]
170    pub fn new(config: Option<&Bound<'_, PyAny>>) -> Result<Self, ClientError> {
171        let config = config.map_or(Ok(HTTPConfig::default()), |unwrapped| {
172            if unwrapped.is_instance_of::<HTTPConfig>() {
173                unwrapped.extract::<HTTPConfig>()
174            } else {
175                Err(ClientError::InvalidConfigTypeError.into())
176            }
177        })?;
178
179        let client = ScouterClient::new(Some(config.clone()))?;
180
181        Ok(PyScouterClient { client })
182    }
183
184    /// Insert a profile into the scouter server
185    ///
186    /// # Arguments
187    ///
188    /// * `profile` - A profile object to insert
189    ///
190    /// # Returns
191    ///
192    /// * `Ok(())` if the profile was inserted successfully
193    #[pyo3(signature = (profile, set_active=false, deactivate_others=false))]
194    pub fn register_profile(
195        &self,
196        profile: &Bound<'_, PyAny>,
197        set_active: bool,
198        deactivate_others: bool,
199    ) -> Result<bool, ClientError> {
200        let request = profile
201            .call_method0("create_profile_request")?
202            .extract::<ProfileRequest>()?;
203
204        let profile_response = self.client.insert_profile(&request)?;
205
206        // update config args
207        profile.call_method1(
208            "update_config_args",
209            (
210                Some(profile_response.space),
211                Some(profile_response.name),
212                Some(profile_response.version),
213            ),
214        )?;
215
216        debug!("Profile inserted successfully");
217        if set_active {
218            let name = profile
219                .getattr("config")?
220                .getattr("name")?
221                .extract::<String>()?;
222
223            let space = profile
224                .getattr("config")?
225                .getattr("space")?
226                .extract::<String>()?;
227
228            let version = profile
229                .getattr("config")?
230                .getattr("version")?
231                .extract::<String>()?;
232
233            let drift_type = profile
234                .getattr("config")?
235                .getattr("drift_type")?
236                .extract::<DriftType>()?;
237
238            let request = ProfileStatusRequest {
239                name,
240                space,
241                version,
242                active: true,
243                drift_type: Some(drift_type),
244                deactivate_others,
245            };
246
247            self.client.update_profile_status(&request)?;
248        }
249
250        Ok(true)
251    }
252
253    /// Update the status of a profile
254    ///
255    /// # Arguments
256    /// * `request` - A profile status request object
257    ///
258    /// # Returns
259    /// * `Ok(())` if the profile status was updated successfully
260    pub fn update_profile_status(
261        &self,
262        request: ProfileStatusRequest,
263    ) -> Result<bool, ClientError> {
264        self.client.update_profile_status(&request)
265    }
266
267    /// Get binned drift data from the scouter server
268    ///
269    /// # Arguments
270    ///
271    /// * `drift_request` - A drift request object
272    ///
273    /// # Returns
274    ///
275    /// * A binned drift object
276    pub fn get_binned_drift<'py>(
277        &self,
278        py: Python<'py>,
279        drift_request: DriftRequest,
280    ) -> Result<Bound<'py, PyAny>, ClientError> {
281        match drift_request.drift_type {
282            DriftType::Spc => {
283                PyScouterClient::get_spc_binned_drift(py, &self.client.client, drift_request)
284            }
285            DriftType::Psi => {
286                PyScouterClient::get_psi_binned_drift(py, &self.client.client, drift_request)
287            }
288            DriftType::Custom => {
289                PyScouterClient::get_custom_binned_drift(py, &self.client.client, drift_request)
290            }
291            DriftType::LLM => {
292                PyScouterClient::get_llm_metric_binned_drift(py, &self.client.client, drift_request)
293            }
294        }
295    }
296
297    pub fn get_alerts(&self, request: DriftAlertRequest) -> Result<Vec<Alert>, ClientError> {
298        debug!("Getting alerts for: {:?}", request);
299
300        let alerts = self.client.get_alerts(&request)?;
301
302        Ok(alerts)
303    }
304
305    #[pyo3(signature = (request, path))]
306    pub fn download_profile(
307        &self,
308        request: GetProfileRequest,
309        path: Option<PathBuf>,
310    ) -> Result<String, ClientError> {
311        debug!("Downloading profile: {:?}", request);
312
313        let filename = format!(
314            "{}_{}_{}_{}.json",
315            request.name, request.space, request.version, request.drift_type
316        );
317
318        let profile = self.client.get_drift_profile(request)?;
319
320        ProfileFuncs::save_to_json(profile, path.clone(), &filename)?;
321
322        Ok(path.map_or(filename, |p| p.to_string_lossy().to_string()))
323    }
324}
325
326impl PyScouterClient {
327    fn get_spc_binned_drift<'py>(
328        py: Python<'py>,
329        client: &HTTPClient,
330        drift_request: DriftRequest,
331    ) -> Result<Bound<'py, PyAny>, ClientError> {
332        let query_string = serde_qs::to_string(&drift_request)?;
333
334        let response = client.request(
335            Routes::SpcDrift,
336            RequestType::Get,
337            None,
338            Some(query_string),
339            None,
340        )?;
341
342        if response.status().is_client_error() || response.status().is_server_error() {
343            return Err(ClientError::GetDriftDataError);
344        }
345
346        let body = response.bytes()?;
347
348        let results: SpcDriftFeatures = serde_json::from_slice(&body)?;
349
350        Ok(results.into_bound_py_any(py).unwrap())
351    }
352    fn get_psi_binned_drift<'py>(
353        py: Python<'py>,
354        client: &HTTPClient,
355        drift_request: DriftRequest,
356    ) -> Result<Bound<'py, PyAny>, ClientError> {
357        let query_string = serde_qs::to_string(&drift_request)?;
358
359        let response = client.request(
360            Routes::PsiDrift,
361            RequestType::Get,
362            None,
363            Some(query_string),
364            None,
365        )?;
366
367        if response.status().is_client_error() || response.status().is_server_error() {
368            // print response text
369            error!(
370                "Failed to get PSI drift data. Status: {:?}",
371                response.status()
372            );
373            error!("Response text: {:?}", response.text());
374            return Err(ClientError::GetDriftDataError);
375        }
376
377        let body = response.bytes()?;
378
379        let results: BinnedPsiFeatureMetrics = serde_json::from_slice(&body)?;
380
381        Ok(results.into_bound_py_any(py).unwrap())
382    }
383
384    fn get_custom_binned_drift<'py>(
385        py: Python<'py>,
386        client: &HTTPClient,
387        drift_request: DriftRequest,
388    ) -> Result<Bound<'py, PyAny>, ClientError> {
389        let query_string = serde_qs::to_string(&drift_request)?;
390
391        let response = client.request(
392            Routes::CustomDrift,
393            RequestType::Get,
394            None,
395            Some(query_string),
396            None,
397        )?;
398
399        if response.status().is_client_error() || response.status().is_server_error() {
400            return Err(ClientError::GetDriftDataError);
401        }
402
403        let body = response.bytes()?;
404
405        let results: BinnedMetrics = serde_json::from_slice(&body)?;
406
407        Ok(results.into_bound_py_any(py).unwrap())
408    }
409
410    fn get_llm_metric_binned_drift<'py>(
411        py: Python<'py>,
412        client: &HTTPClient,
413        drift_request: DriftRequest,
414    ) -> Result<Bound<'py, PyAny>, ClientError> {
415        let query_string = serde_qs::to_string(&drift_request)?;
416
417        let response = client.request(
418            Routes::LLMDrift,
419            RequestType::Get,
420            None,
421            Some(query_string),
422            None,
423        )?;
424
425        if response.status().is_client_error() || response.status().is_server_error() {
426            return Err(ClientError::GetDriftDataError);
427        }
428
429        let body = response.bytes()?;
430
431        let results: BinnedMetrics = serde_json::from_slice(&body)?;
432
433        Ok(results.into_bound_py_any(py).unwrap())
434    }
435}