1#![allow(clippy::useless_conversion)]
2use crate::error::ClientError;
3use pyo3::{prelude::*, IntoPyObjectExt};
4use scouter_settings::http::HTTPConfig;
5use scouter_types::contracts::{
6 DriftAlertRequest, DriftRequest, GetProfileRequest, ProfileRequest, ProfileStatusRequest,
7};
8use scouter_types::http::{RequestType, Routes};
9use scouter_types::RegisteredProfileResponse;
10
11use crate::http::HTTPClient;
12use scouter_types::{
13 alert::Alert, psi::BinnedPsiFeatureMetrics, spc::SpcDriftFeatures, BinnedMetrics, DriftProfile,
14 DriftType, ProfileFuncs,
15};
16use std::path::PathBuf;
17use tracing::{debug, error};
18
/// Chunk size constant equal to 5 * 2^20 (5 MiB if interpreted as bytes).
///
/// NOTE(review): not referenced in this part of the file; presumably a byte count
/// used to size chunked downloads — confirm the unit at its call sites.
pub const DOWNLOAD_CHUNK_SIZE: usize = 5 * (1 << 20);
20
21#[derive(Debug, Clone)]
22pub struct ScouterClient {
23 client: HTTPClient,
24}
25
26impl ScouterClient {
27 pub fn new(config: Option<HTTPConfig>) -> Result<Self, ClientError> {
28 let client = HTTPClient::new(config.unwrap_or_default())?;
29
30 Ok(ScouterClient { client })
31 }
32
33 pub fn insert_profile(
35 &self,
36 request: &ProfileRequest,
37 ) -> Result<RegisteredProfileResponse, ClientError> {
38 let response = self.client.request(
39 Routes::Profile,
40 RequestType::Post,
41 Some(serde_json::to_value(request).unwrap()),
42 None,
43 None,
44 )?;
45
46 if response.status().is_success() {
47 let body = response.bytes()?;
48 let profile_response: RegisteredProfileResponse = serde_json::from_slice(&body)?;
49
50 debug!("Profile inserted successfully: {:?}", profile_response);
51 Ok(profile_response)
52 } else {
53 Err(ClientError::InsertProfileError)
54 }
55 }
56
57 pub fn update_profile_status(
58 &self,
59 request: &ProfileStatusRequest,
60 ) -> Result<bool, ClientError> {
61 let response = self.client.request(
62 Routes::ProfileStatus,
63 RequestType::Put,
64 Some(serde_json::to_value(request).unwrap()),
65 None,
66 None,
67 )?;
68
69 if response.status().is_success() {
70 Ok(true)
71 } else {
72 Err(ClientError::UpdateProfileError)
73 }
74 }
75
76 pub fn get_alerts(&self, request: &DriftAlertRequest) -> Result<Vec<Alert>, ClientError> {
77 debug!("Getting alerts for: {:?}", request);
78
79 let query_string = serde_qs::to_string(request)?;
80
81 let response = self.client.request(
82 Routes::Alerts,
83 RequestType::Get,
84 None,
85 Some(query_string),
86 None,
87 )?;
88
89 if !response.status().is_success() {
91 return Err(ClientError::GetDriftAlertError);
92 }
93
94 let body: serde_json::Value = response.json()?;
96
97 let alerts = body
99 .get("alerts")
100 .map(|alerts| {
101 serde_json::from_value::<Vec<Alert>>(alerts.clone()).inspect_err(|e| {
102 error!(
103 "Failed to parse drift alerts {:?}. Error: {:?}",
104 &request, e
105 );
106 })
107 })
108 .unwrap_or_else(|| {
109 error!("No alerts found in response");
110 Ok(Vec::new())
111 })?;
112
113 Ok(alerts)
114 }
115
116 pub fn get_drift_profile(
117 &self,
118 request: GetProfileRequest,
119 ) -> Result<DriftProfile, ClientError> {
120 let query_string = serde_qs::to_string(&request)?;
121
122 let response = self.client.request(
123 Routes::Profile,
124 RequestType::Get,
125 None,
126 Some(query_string),
127 None,
128 )?;
129
130 if !response.status().is_success() {
132 error!("Failed to get profile. Status: {:?}", response.status());
133 return Err(ClientError::GetDriftProfileError);
134 }
135
136 let body = response.bytes()?;
138
139 let profile: DriftProfile = serde_json::from_slice(&body)?;
141
142 Ok(profile)
143 }
144
145 pub fn check_service_health(&self) -> Result<bool, ClientError> {
147 let response = self
148 .client
149 .request(Routes::Healthcheck, RequestType::Get, None, None, None)
150 .inspect_err(|e| {
151 error!("Failed to check scouter health {}", e);
152 })?;
153
154 if response.status() == 200 {
155 Ok(true)
156 } else {
157 Ok(false)
158 }
159 }
160}
161
162#[pyclass(name = "ScouterClient")]
163pub struct PyScouterClient {
164 client: ScouterClient,
165}
166#[pymethods]
167impl PyScouterClient {
168 #[new]
169 #[pyo3(signature = (config=None))]
170 pub fn new(config: Option<&Bound<'_, PyAny>>) -> Result<Self, ClientError> {
171 let config = config.map_or(Ok(HTTPConfig::default()), |unwrapped| {
172 if unwrapped.is_instance_of::<HTTPConfig>() {
173 unwrapped.extract::<HTTPConfig>()
174 } else {
175 Err(ClientError::InvalidConfigTypeError.into())
176 }
177 })?;
178
179 let client = ScouterClient::new(Some(config.clone()))?;
180
181 Ok(PyScouterClient { client })
182 }
183
184 #[pyo3(signature = (profile, set_active=false, deactivate_others=false))]
194 pub fn register_profile(
195 &self,
196 profile: &Bound<'_, PyAny>,
197 set_active: bool,
198 deactivate_others: bool,
199 ) -> Result<bool, ClientError> {
200 let request = profile
201 .call_method0("create_profile_request")?
202 .extract::<ProfileRequest>()?;
203
204 let profile_response = self.client.insert_profile(&request)?;
205
206 profile.call_method1(
208 "update_config_args",
209 (
210 Some(profile_response.space),
211 Some(profile_response.name),
212 Some(profile_response.version),
213 ),
214 )?;
215
216 debug!("Profile inserted successfully");
217 if set_active {
218 let name = profile
219 .getattr("config")?
220 .getattr("name")?
221 .extract::<String>()?;
222
223 let space = profile
224 .getattr("config")?
225 .getattr("space")?
226 .extract::<String>()?;
227
228 let version = profile
229 .getattr("config")?
230 .getattr("version")?
231 .extract::<String>()?;
232
233 let drift_type = profile
234 .getattr("config")?
235 .getattr("drift_type")?
236 .extract::<DriftType>()?;
237
238 let request = ProfileStatusRequest {
239 name,
240 space,
241 version,
242 active: true,
243 drift_type: Some(drift_type),
244 deactivate_others,
245 };
246
247 self.client.update_profile_status(&request)?;
248 }
249
250 Ok(true)
251 }
252
253 pub fn update_profile_status(
261 &self,
262 request: ProfileStatusRequest,
263 ) -> Result<bool, ClientError> {
264 self.client.update_profile_status(&request)
265 }
266
267 pub fn get_binned_drift<'py>(
277 &self,
278 py: Python<'py>,
279 drift_request: DriftRequest,
280 ) -> Result<Bound<'py, PyAny>, ClientError> {
281 match drift_request.drift_type {
282 DriftType::Spc => {
283 PyScouterClient::get_spc_binned_drift(py, &self.client.client, drift_request)
284 }
285 DriftType::Psi => {
286 PyScouterClient::get_psi_binned_drift(py, &self.client.client, drift_request)
287 }
288 DriftType::Custom => {
289 PyScouterClient::get_custom_binned_drift(py, &self.client.client, drift_request)
290 }
291 DriftType::LLM => {
292 PyScouterClient::get_llm_metric_binned_drift(py, &self.client.client, drift_request)
293 }
294 }
295 }
296
297 pub fn get_alerts(&self, request: DriftAlertRequest) -> Result<Vec<Alert>, ClientError> {
298 debug!("Getting alerts for: {:?}", request);
299
300 let alerts = self.client.get_alerts(&request)?;
301
302 Ok(alerts)
303 }
304
305 #[pyo3(signature = (request, path))]
306 pub fn download_profile(
307 &self,
308 request: GetProfileRequest,
309 path: Option<PathBuf>,
310 ) -> Result<String, ClientError> {
311 debug!("Downloading profile: {:?}", request);
312
313 let filename = format!(
314 "{}_{}_{}_{}.json",
315 request.name, request.space, request.version, request.drift_type
316 );
317
318 let profile = self.client.get_drift_profile(request)?;
319
320 ProfileFuncs::save_to_json(profile, path.clone(), &filename)?;
321
322 Ok(path.map_or(filename, |p| p.to_string_lossy().to_string()))
323 }
324}
325
326impl PyScouterClient {
327 fn get_spc_binned_drift<'py>(
328 py: Python<'py>,
329 client: &HTTPClient,
330 drift_request: DriftRequest,
331 ) -> Result<Bound<'py, PyAny>, ClientError> {
332 let query_string = serde_qs::to_string(&drift_request)?;
333
334 let response = client.request(
335 Routes::SpcDrift,
336 RequestType::Get,
337 None,
338 Some(query_string),
339 None,
340 )?;
341
342 if response.status().is_client_error() || response.status().is_server_error() {
343 return Err(ClientError::GetDriftDataError);
344 }
345
346 let body = response.bytes()?;
347
348 let results: SpcDriftFeatures = serde_json::from_slice(&body)?;
349
350 Ok(results.into_bound_py_any(py).unwrap())
351 }
352 fn get_psi_binned_drift<'py>(
353 py: Python<'py>,
354 client: &HTTPClient,
355 drift_request: DriftRequest,
356 ) -> Result<Bound<'py, PyAny>, ClientError> {
357 let query_string = serde_qs::to_string(&drift_request)?;
358
359 let response = client.request(
360 Routes::PsiDrift,
361 RequestType::Get,
362 None,
363 Some(query_string),
364 None,
365 )?;
366
367 if response.status().is_client_error() || response.status().is_server_error() {
368 error!(
370 "Failed to get PSI drift data. Status: {:?}",
371 response.status()
372 );
373 error!("Response text: {:?}", response.text());
374 return Err(ClientError::GetDriftDataError);
375 }
376
377 let body = response.bytes()?;
378
379 let results: BinnedPsiFeatureMetrics = serde_json::from_slice(&body)?;
380
381 Ok(results.into_bound_py_any(py).unwrap())
382 }
383
384 fn get_custom_binned_drift<'py>(
385 py: Python<'py>,
386 client: &HTTPClient,
387 drift_request: DriftRequest,
388 ) -> Result<Bound<'py, PyAny>, ClientError> {
389 let query_string = serde_qs::to_string(&drift_request)?;
390
391 let response = client.request(
392 Routes::CustomDrift,
393 RequestType::Get,
394 None,
395 Some(query_string),
396 None,
397 )?;
398
399 if response.status().is_client_error() || response.status().is_server_error() {
400 return Err(ClientError::GetDriftDataError);
401 }
402
403 let body = response.bytes()?;
404
405 let results: BinnedMetrics = serde_json::from_slice(&body)?;
406
407 Ok(results.into_bound_py_any(py).unwrap())
408 }
409
410 fn get_llm_metric_binned_drift<'py>(
411 py: Python<'py>,
412 client: &HTTPClient,
413 drift_request: DriftRequest,
414 ) -> Result<Bound<'py, PyAny>, ClientError> {
415 let query_string = serde_qs::to_string(&drift_request)?;
416
417 let response = client.request(
418 Routes::LLMDrift,
419 RequestType::Get,
420 None,
421 Some(query_string),
422 None,
423 )?;
424
425 if response.status().is_client_error() || response.status().is_server_error() {
426 return Err(ClientError::GetDriftDataError);
427 }
428
429 let body = response.bytes()?;
430
431 let results: BinnedMetrics = serde_json::from_slice(&body)?;
432
433 Ok(results.into_bound_py_any(py).unwrap())
434 }
435}