1pub mod artifact;
48pub mod config;
49mod error;
50pub mod experiment;
51mod http;
52pub mod modelversion;
53pub mod registered;
54pub mod run;
55
56use crate::config::{BasicAuth, ClientConfig, DownloadConfig, MlflowConfig};
57use crate::http::Response;
58use crate::{artifact::*, experiment::*, modelversion::*, registered::*, run::*};
59use anyhow::{anyhow, Context, Result};
60use backon::{ExponentialBuilder, Retryable};
61use reqwest::{Method, RequestBuilder};
62use serde::{de::DeserializeOwned, Deserialize, Serialize};
63use serde_json::json;
64use std::collections::HashMap;
65use std::path::Path;
66use tokio::{
67 fs,
68 fs::{create_dir_all, File},
69 io::AsyncWriteExt,
70 task::JoinSet,
71};
72use tracing::{debug, info, warn};
73
74#[derive(Clone, Debug)]
76pub struct Client {
77 pub client: reqwest::Client,
78 pub mlflow: MlflowConfig,
80 pub download: DownloadConfig,
82}
83
84pub(crate) mod bs {
85 use serde::Deserialize;
86
87 #[derive(Deserialize)]
88 pub struct Dummy {}
89
90 #[derive(Deserialize)]
91 pub struct Run {
92 pub run: super::Run,
93 }
94
95 #[derive(Deserialize)]
96 pub struct Runs {
97 #[serde(default)]
98 pub runs: Vec<super::Run>,
99 }
100
101 #[derive(Deserialize)]
102 pub struct RunInfo {
103 pub run_info: super::RunInfo,
104 }
105
106 #[derive(Deserialize)]
107 pub struct ModelVersion {
108 pub model_version: super::ModelVersion,
109 }
110
111 #[derive(Deserialize)]
112 pub struct RegisteredModel {
113 pub registered_model: super::RegisteredModel,
114 }
115
116 #[derive(Deserialize)]
117 pub struct RegisteredModels {
118 #[serde(default)]
119 pub registered_models: Vec<super::RegisteredModel>,
120 }
121}
122
123impl Client {
125 pub fn new(urlbase: impl AsRef<str>) -> Self {
128 Self {
129 client: reqwest::Client::new(),
130 mlflow: MlflowConfig {
131 urlbase: urlbase.as_ref().to_owned(),
132 auth: None,
133 },
134 download: DownloadConfig {
135 retry: false,
136 tasks: 16,
137 blacklist: Default::default(),
138 cache_local_artifacts: false,
139 file_size_check: true,
140 },
141 }
142 }
143
144 pub fn new_from_config(config: ClientConfig) -> Self {
146 Self {
147 client: reqwest::Client::new(),
148 mlflow: config.mlflow,
149 download: config.download,
150 }
151 }
152
153 pub fn with_auth(mut self, auth: BasicAuth) -> Self {
155 self.mlflow.auth = Some(auth);
156 self
157 }
158
159 fn make_url(&self, path: impl AsRef<str>) -> String {
160 format!("{}/{}", self.mlflow.urlbase, path.as_ref())
161 }
162
163 async fn send_json_request_no_retry<P: Serialize>(
165 &self,
166 method: Method,
167 path: impl AsRef<str>,
168 payload: &P,
169 ) -> Result<Response> {
170 let mut request = self
171 .client
172 .request(method, self.make_url(path.as_ref()))
173 .json(&payload);
174
175 request = self.assign_basic_auth(request);
176
177 Response(request.send().await?).error_for_status().await
178 }
179
180 async fn send_request<P: Serialize>(
182 &self,
183 method: Method,
184 path: impl AsRef<str>,
185 payload: P,
186 ) -> Result<Response> {
187 let future = || async {
188 self.send_json_request_no_retry(method.clone(), path.as_ref(), &payload)
189 .await
190 };
191
192 future
193 .retry(ExponentialBuilder::default())
194 .notify(|error, duration| {
195 warn!("Retrying in {:?}, got an error: {:?}", error, duration);
196 })
197 .await
198 .with_context(|| format!("Cannot send request to: {:?}", path.as_ref()))
199 }
200
201 fn assign_basic_auth(&self, request: RequestBuilder) -> RequestBuilder {
203 if let Some(auth) = self.mlflow.auth.as_ref() {
204 request.basic_auth(
205 auth.user.expose_secret().to_owned(),
206 Some(auth.password.expose_secret().to_owned()),
207 )
208 } else {
209 request
210 }
211 }
212
213 pub async fn create_experiment(&self, name: &str, tags: Vec<KeyValue>) -> Result<String> {
218 let path = "2.0/mlflow/experiments/create";
219 let request = CreateExperiment {
220 name: name.to_owned(),
221 tags,
222 };
223
224 self.send_request(Method::POST, path, request)
225 .await?
226 .json::<ExperimentId>()
227 .await
228 .map(|response| response.experiment_id)
229 }
230
231 pub async fn delete_experiment(&self, experiment_id: &str) -> Result<()> {
235 let path = "2.0/mlflow/experiments/delete";
236 let request = ExperimentId {
237 experiment_id: experiment_id.to_owned(),
238 };
239
240 self.send_request(Method::POST, path, request)
241 .await
242 .map(|_| ())
243 }
244
245 pub async fn get_experiment(&self, experiment_id: &str) -> Result<Experiment> {
248 let path = "2.0/mlflow/experiments/get";
249 let request = ExperimentId {
250 experiment_id: experiment_id.to_owned(),
251 };
252
253 #[derive(Deserialize)]
254 struct Response {
255 experiment: Experiment,
256 }
257
258 self.send_request(Method::GET, path, request)
259 .await?
260 .json::<Response>()
261 .await
262 .map(|response| response.experiment)
263 }
264
265 pub async fn get_experiment_by_name(&self, name: &str) -> Result<Experiment> {
272 let path = "2.0/mlflow/experiments/get-by-name";
273 let request = json!({
274 "experiment_name": name
275 });
276
277 #[derive(Deserialize)]
278 struct Response {
279 experiment: Experiment,
280 }
281
282 self.send_request(Method::GET, path, request)
283 .await?
284 .json::<Response>()
285 .await
286 .map(|response| response.experiment)
287 }
288
289 pub async fn get_or_create_experiment(
292 &self,
293 name: &str,
294 tags: Vec<KeyValue>,
295 ) -> Result<Experiment> {
296 let path = "2.0/mlflow/experiments/get-by-name";
297 let request = json!({
298 "experiment_name": name
299 });
300
301 #[derive(Deserialize)]
302 struct Response {
303 experiment: Experiment,
304 }
305
306 let result = self
307 .send_json_request_no_retry(Method::GET, path, &request)
308 .await;
309
310 match result {
311 Err(_) => {
312 let experiment_id = self.create_experiment(name, tags).await?;
313 self.get_experiment(&experiment_id).await
314 }
315 Ok(response) => response
316 .json::<Response>()
317 .await
318 .map(|response| response.experiment),
319 }
320 }
321
322 pub async fn create_run(&self, create: CreateRun) -> Result<Run> {
326 let path = "2.0/mlflow/runs/create";
327
328 self.send_request(Method::POST, path, create)
329 .await?
330 .json::<bs::Run>()
331 .await
332 .map(|value| value.run)
333 }
334
335 pub async fn get_run(&self, run_id: &str) -> Result<Run> {
340 let path = "2.0/mlflow/runs/get";
341 let request = Request {
342 run_id: run_id.to_string(),
343 };
344
345 #[derive(Serialize)]
346 struct Request {
347 run_id: String,
348 }
349
350 self.send_request(Method::GET, path, request)
351 .await?
352 .json::<bs::Run>()
353 .await
354 .map(|value| value.run)
355 }
356
357 pub async fn search_runs(&self, search: SearchRuns) -> Result<Vec<Run>> {
360 let path = "2.0/mlflow/runs/search";
361
362 self.send_request(Method::POST, path, search)
363 .await?
364 .json::<bs::Runs>()
365 .await
366 .map(|response| response.runs)
367 }
368
369 pub async fn update_run(&self, update: UpdateRun) -> Result<RunInfo> {
371 let path = "2.0/mlflow/runs/update";
372
373 self.send_request(Method::POST, path, update)
374 .await?
375 .json::<bs::RunInfo>()
376 .await
377 .map(|value| value.run_info)
378 }
379
380 pub async fn add_run_meta(&self, run_id: &str, meta: RunData) -> Result<()> {
385 let path = "2.0/mlflow/runs/log-batch";
386 let request = json!({
387 "run_id": run_id.to_owned(),
388 "metrics": meta.metrics,
389 "params": meta.params,
390 "tags": meta.tags
391 });
392
393 self.send_request(Method::POST, path, request)
394 .await
395 .map(|_| ())
396 }
397
398 pub async fn delete_run(&self, run_id: &str) -> Result<()> {
399 let path = "2.0/mlflow/runs/delete";
400
401 self.send_request(Method::POST, path, json!({"run_id": run_id}))
402 .await
403 .map(|_| ())
404 }
405
406 pub async fn add_run_inputs(&self, run_id: &str, inputs: Vec<DataSetInput>) -> Result<()> {
407 let path = "2.0/mlflow/runs/log-inputs";
408 let request = json!({
409 "run_id": run_id.to_owned(),
410 "datasets": inputs,
411 });
412
413 self.send_request(Method::POST, path, request)
414 .await
415 .map(|_| ())
416 }
417
418 pub async fn delete_registered_model(&self, name: &str) -> Result<()> {
419 let path = "2.0/mlflow/registered-models/delete";
420 let request = json!({"name": name.to_owned()});
421
422 self.send_request(Method::DELETE, path, request)
423 .await
424 .map(|_| ())
425 }
426
427 pub async fn register_model(&self, request: RegisterModel) -> Result<RegisteredModel> {
428 let path = "2.0/mlflow/registered-models/create";
429
430 self.send_request(Method::POST, path, request)
431 .await?
432 .json::<bs::RegisteredModel>()
433 .await
434 .map(|response| response.registered_model)
435 }
436
437 pub async fn get_registered_model(&self, name: &str) -> Result<RegisteredModel> {
438 let path = "2.0/mlflow/registered-models/get";
439 let request = json!({"name": name.to_owned()});
440
441 self.send_request(Method::GET, path, request)
442 .await?
443 .json::<bs::RegisteredModel>()
444 .await
445 .map(|response| response.registered_model)
446 }
447
448 pub async fn search_registered_models(&self, filter: &str) -> Result<Vec<RegisteredModel>> {
449 let path = "2.0/mlflow/registered-models/search";
450 let request = json!({"filter": filter.to_owned()});
451
452 self.send_request(Method::GET, path, request)
453 .await?
454 .json::<bs::RegisteredModels>()
455 .await
456 .map(|response| response.registered_models)
457 }
458
459 pub async fn create_model_version(&self, request: CreateModelVersion) -> Result<ModelVersion> {
460 let path = "2.0/mlflow/model-versions/create";
461
462 self.send_request(Method::POST, path, request)
463 .await?
464 .json::<bs::ModelVersion>()
465 .await
466 .map(|response| response.model_version)
467 }
468
469 pub async fn transition_model_version_stage(
470 &self,
471 request: TransitionModelVersionStage,
472 ) -> Result<ModelVersion> {
473 let path = "2.0/mlflow/model-versions/transition-stage";
474
475 self.send_request(Method::POST, path, request)
476 .await?
477 .json::<bs::ModelVersion>()
478 .await
479 .map(|response| response.model_version)
480 }
481
482 pub async fn list_run_artifacts(&self, run_id: &str) -> Result<ListedArtifacts> {
484 let path = "2.0/mlflow/artifacts/list";
485 let request = json!({
486 "run_id": run_id
487 });
488
489 let mut response = self
490 .send_request(Method::GET, path, request)
491 .await?
492 .json::<ListedArtifacts>()
493 .await?;
494
495 let mut tocheck = response.files.drain(..).collect::<Vec<_>>();
496
497 while let Some(info) = tocheck.pop() {
498 if !info.is_dir {
500 response.files.push(info);
501 continue;
502 }
503
504 debug!("Listing directory: {:?}", info.path);
505
506 let request = json!({
507 "run_id": run_id,
508 "path": info.path,
509 });
510
511 let mut files = self
512 .send_request(Method::GET, path, request)
513 .await?
514 .json::<ListedArtifacts>()
515 .await?
516 .files;
517
518 tocheck.append(&mut files);
519 }
520
521 Ok(response)
522 }
523
524 pub async fn download_artifacts(
526 &self,
527 downloads: Vec<DownloadRunArtifacts>,
528 ) -> Result<Vec<RunArtifacts>> {
529 let mut downloads = downloads
532 .into_iter()
533 .flat_map(|downloads| downloads.as_single_downloads())
534 .collect::<Vec<_>>();
535
536 downloads.retain(|download| {
538 if self.download.is_blacklisted(&download.file) {
539 debug!("Filtering out blacklisted file: {}", download.file);
540 false
541 } else {
542 true
543 }
544 });
545
546 let chunk_size = if downloads.len() <= self.download.tasks {
547 1
548 } else {
549 downloads.len() / self.download.tasks
550 };
551
552 let batches = downloads.chunks(chunk_size).map(|chunks| chunks.to_vec());
558
559 let mut tasks = JoinSet::<Result<Vec<_>>>::new();
560
561 for (idx, batch) in batches.enumerate() {
562 let client = self.clone();
563
564 tasks.spawn(async move {
565 debug!("Worker {} has got #{} files", idx, batch.len());
566
567 let mut completed = Vec::with_capacity(batch.len());
568
569 for download in batch {
570 debug!("Starting {:?}", download);
571 let size = client.download_artifact(&download).await?;
572 debug!(
573 "Downloaded artifact {} [size={}, expected={}]",
574 download.file, size, download.expected_size
575 );
576 completed.push(download);
577 }
578
579 Ok(completed)
580 });
581 }
582
583 let mut downloaded = HashMap::new();
584
585 while let Some(result) = tasks.join_next().await {
586 let completed = result
589 .context("We cannot join download task")?
590 .context("Download task has returned an error")?;
591
592 for download in completed {
595 downloaded
596 .entry(download.run_id.clone())
597 .and_modify(|artifacts: &mut RunArtifacts| {
598 artifacts.paths.push(download.path());
599 })
600 .or_insert_with(|| RunArtifacts {
601 paths: vec![download.path()],
602 experiment_id: download.experiment_id,
603 run_id: download.run_id,
604 root: download.destination,
605 });
606 }
607 }
608
609 Ok(downloaded.into_values().collect())
610 }
611
612 pub async fn prepare_run_download(
614 &self,
615 run_id: &str,
616 directory: impl AsRef<Path>,
617 ) -> Result<DownloadRunArtifacts> {
618 let list = self
620 .list_run_artifacts(run_id)
621 .await
622 .context("Cannot list run artifacts")?;
623
624 let run = self
625 .get_run(run_id)
626 .await
627 .context("Cannot get run from mlflow")?;
628
629 Ok(DownloadRunArtifacts::new_from_run(directory, run, list))
630 }
631
632 pub async fn download_run_artifacts(
634 &self,
635 download: DownloadRunArtifacts,
636 ) -> Result<RunArtifacts> {
637 debug!("Starting: {:#?}", download);
638
639 self.download_artifacts(vec![download.clone()])
640 .await
641 .with_context(|| format!("Cannot download artifacts for {:#?}", download))?
642 .pop()
643 .context("BUG: We've not received any RunArtifacts")
644 }
645
646 pub async fn upload_json_artifact_no_retry(
648 &self,
649 data: &impl Serialize,
650 artifact: &Artifact,
651 ) -> Result<()> {
652 let path = format!(
653 "2.0/mlflow-artifacts/artifacts/{}/{}/artifacts/{}",
654 artifact.experiment_id,
655 artifact.run_id,
656 artifact.path.to_string_lossy()
657 );
658
659 let mut request = self.client.put(self.make_url(path)).json(data);
660
661 request = self.assign_basic_auth(request);
662
663 Response(request.send().await?)
664 .error_for_status()
665 .await?
666 .json::<bs::Dummy>()
667 .await
668 .map(|_| ())
669 }
670
671 pub async fn upload_json_artifact(
673 &self,
674 data: &impl Serialize,
675 artifact: &Artifact,
676 ) -> Result<()> {
677 let future = || async { self.upload_json_artifact_no_retry(data, artifact).await };
678
679 future
680 .retry(ExponentialBuilder::default())
681 .notify(|error, duration| {
682 warn!(
683 "Retrying upload in {:?}, got an error: {:?}",
684 error, duration
685 );
686 })
687 .await
688 .with_context(|| format!("Cannot upload: {:?}", artifact))
689 }
690
691 pub async fn download_json_artifact_no_retry<T: DeserializeOwned>(
693 &self,
694 artifact: &Artifact,
695 ) -> Result<T> {
696 let path = format!(
697 "2.0/mlflow-artifacts/artifacts/{}/{}/artifacts/{}",
698 artifact.experiment_id,
699 artifact.run_id,
700 artifact.path.to_string_lossy()
701 );
702
703 let mut request = self.client.get(self.make_url(path));
704
705 request = self.assign_basic_auth(request);
706
707 Response(request.send().await?)
708 .error_for_status()
709 .await?
710 .json::<T>()
711 .await
712 }
713
714 pub async fn download_json_artifact<T: DeserializeOwned>(
716 &self,
717 artifact: &Artifact,
718 ) -> Result<T> {
719 let future = || async { self.download_json_artifact_no_retry(artifact).await };
720
721 future
722 .retry(ExponentialBuilder::default())
723 .notify(|error, duration| {
724 warn!(
725 "Retrying download in {:?}, got an error: {:?}",
726 error, duration
727 );
728 })
729 .await
730 .with_context(|| format!("Cannot download: {:?}", artifact))
731 }
732
733 pub async fn upload_artifact_no_retry(
735 &self,
736 source: impl AsRef<Path>,
737 artifact: &Artifact,
738 ) -> Result<()> {
739 let path = format!(
740 "2.0/mlflow-artifacts/artifacts/{}/{}/artifacts/{}",
741 artifact.experiment_id,
742 artifact.run_id,
743 artifact.path.to_string_lossy()
744 );
745
746 let file = File::open(source.as_ref())
747 .await
748 .with_context(|| format!("Cannot open artifact file: {:?}", source.as_ref()))?;
749
750 let mut request = self.client.put(self.make_url(path)).body(file);
751
752 request = self.assign_basic_auth(request);
753
754 Response(request.send().await?)
755 .error_for_status()
756 .await?
757 .json::<bs::Dummy>()
758 .await
759 .map(|_| ())
760 }
761
762 pub async fn upload_artifacts(&self, uploads: Vec<UploadArtifact>) -> Result<()> {
764 let chunk_size = if uploads.len() <= self.download.tasks {
765 1
766 } else {
767 uploads.len() / self.download.tasks
768 };
769
770 let batches = uploads.chunks(chunk_size).map(|chunks| chunks.to_vec());
771
772 let mut tasks = JoinSet::<Result<()>>::new();
773
774 for (idx, batch) in batches.enumerate() {
775 let client = self.clone();
776
777 tasks.spawn(async move {
778 debug!("Worker {} has got #{} files", idx, batch.len());
779
780 for upload in batch {
781 debug!("Starting {:?}", upload);
782 client
783 .upload_artifact(&upload.local, &upload.remote)
784 .await?;
785 }
786
787 Ok(())
788 });
789 }
790
791 while let Some(result) = tasks.join_next().await {
792 result
793 .context("We cannot join download task")?
794 .context("Upload task has returned an error")?;
795 }
796
797 Ok(())
798 }
799
800 pub async fn upload_artifact(
802 &self,
803 source: impl AsRef<Path>,
804 artifact: &Artifact,
805 ) -> Result<()> {
806 let future = || async {
807 self.upload_artifact_no_retry(source.as_ref(), artifact)
808 .await
809 };
810
811 future
812 .retry(ExponentialBuilder::default())
813 .notify(|error, duration| {
814 warn!(
815 "Retrying upload in {:?}, got an error: {:?}",
816 error, duration
817 );
818 })
819 .await
820 .with_context(|| format!("Cannot upload: {:?}", artifact))
821 }
822
823 pub async fn download_artifact_no_retry(&self, download: &DownloadRunArtifact) -> Result<u64> {
827 let path = format!(
828 "2.0/mlflow-artifacts/artifacts/{}/{}/artifacts/{}",
829 download.experiment_id, download.run_id, download.file,
830 );
831
832 if let Some(directory) = download.path().parent() {
833 create_dir_all(directory).await.with_context(|| {
834 format!(
835 "Unable to create a directory {:?} for the artifact",
836 directory
837 )
838 })?;
839 }
840
841 if self.download.cache_local_artifacts && download.path().exists() {
843 let path = download.path();
844
845 info!("Path {path:?} does exists and caching is enabled, skipping download");
846
847 let file = File::open(&path)
848 .await
849 .with_context(|| format!("Unable to open cached artifact: {path:?} for reading"))?;
850
851 return file
852 .metadata()
853 .await
854 .with_context(|| format!("Unable to get metadata for: {path:?}"))
855 .map(|metadata| metadata.len());
856 }
857
858 let mut file = File::create(&download.path())
859 .await
860 .with_context(|| format!("Cannot create artifact file: {:?}", download.path()))?;
861
862 let mut request = self.client.get(self.make_url(path));
863
864 request = self.assign_basic_auth(request);
865
866 let mut response = Response(request.send().await?).error_for_status().await?.0;
867
868 debug!(
869 "Content-Length of {:?} is {:?}",
870 download.file,
871 response.content_length()
872 );
873
874 let mut size = 0;
875
876 while let Some(chunk) = response
877 .chunk()
878 .await
879 .context("Cannot read artifact data")?
880 {
881 file.write_all(&chunk)
882 .await
883 .context("Cannot write data to artifact file")?;
884 size += chunk.len() as u64;
885 }
886
887 file.flush()
888 .await
889 .context("Unable to flush artifact file to disk")?;
890
891 file.sync_all()
892 .await
893 .context("Unable to write file metadata to disk")?;
894
895 Ok(size)
896 }
897
898 pub async fn download_artifact(&self, download: &DownloadRunArtifact) -> Result<u64> {
900 let future = || async {
901 let size = self.download_artifact_no_retry(download).await?;
902 if self.download.file_size_check && size != download.expected_size {
903 fs::remove_file(download.path()).await?; Err(anyhow!(
905 "Expected size {} and downloaded {} do not match for {}",
906 size,
907 download.expected_size,
908 download.file,
909 ))
910 } else {
911 Ok(size)
912 }
913 };
914
915 future
916 .retry(ExponentialBuilder::default())
917 .notify(|error, duration| {
918 warn!(
919 "Retrying download in {:?}, got an error: {:?}",
920 error, duration
921 );
922 })
923 .await
924 .with_context(|| format!("Cannot download: {:?}", download))
925 }
926}
927
928#[cfg(test)]
929mod test {
930 use super::*;
931 use redact::Secret;
932 use rstest::{fixture, rstest};
933 use std::path::PathBuf;
934 use tracing_test::traced_test;
935
936 #[fixture]
937 fn run_name() -> String {
938 format!("run-{}", rand::random::<u64>())
939 }
940
941 #[fixture]
942 fn model_name() -> String {
943 format!("registered-model-{}", rand::random::<u64>())
944 }
945
946 #[fixture]
947 fn experiment_name() -> String {
948 format!("experiment-{}", rand::random::<u64>())
949 }
950
951 #[fixture]
952 fn client() -> Client {
953 let host = match std::env::var("CI") {
954 Ok(_) => "mlflow",
955 Err(_) => "localhost",
956 };
957
958 Client::new(format!("http://{}:5000/api", host)).with_auth(BasicAuth {
959 user: Secret::from("kokot"),
960 password: Secret::from("bezpeci2021"),
961 })
962 }
963
964 #[rstest]
965 fn test_client_debug(client: Client) {
966 let formatted = format!("{client:?}");
967
968 assert!(!formatted.contains("kokot"));
970 assert!(!formatted.contains("bezpeci2021"));
971 }
972
973 #[rstest]
974 #[tokio::test]
975 async fn test_get_or_create_experiment(experiment_name: String, client: Client) {
976 client
977 .get_or_create_experiment(&experiment_name, vec![])
978 .await
979 .expect("BUG: Cannot create experiment");
980 }
981
982 #[rstest]
983 #[tokio::test]
984 async fn test_create_get_delete_experiment(experiment_name: String, client: Client) {
985 let id = client
986 .create_experiment(&experiment_name, vec![])
987 .await
988 .expect("BUG: Cannot create experiment");
989
990 let experiment = client
991 .get_experiment(&id)
992 .await
993 .expect("BUG: Cannot get experiment");
994
995 assert_eq!(experiment.name, experiment_name);
996
997 client
998 .delete_experiment(&id)
999 .await
1000 .expect("BUG: Cannot update experiment");
1001
1002 client
1003 .create_experiment(&experiment_name, vec![])
1004 .await
1005 .expect_err("BUG: Cannot create deleted experiment");
1006 }
1007
1008 #[rstest]
1009 #[tokio::test]
1010 async fn test_get_experiment_by_name(experiment_name: String, client: Client) {
1011 let experiment_id = client
1012 .create_experiment(&experiment_name, vec![])
1013 .await
1014 .expect("BUG: Cannot create experiment");
1015
1016 let experiment = client
1017 .get_experiment_by_name(&experiment_name)
1018 .await
1019 .expect("BUG: Cannot get crated experiment");
1020
1021 assert_eq!(experiment.experiment_id, experiment_id);
1022 }
1023
1024 #[rstest]
1025 #[tokio::test]
1026 #[awt]
1027 async fn test_run_search(client: Client, #[future] run: Run) {
1028 let search = SearchRuns::new()
1029 .experiment_ids(vec![run.info.experiment_id.clone()])
1030 .max_results(Some(16))
1031 .view(ViewType::All)
1032 .build();
1033
1034 let runs = client
1035 .search_runs(search)
1036 .await
1037 .expect("BUG: Cannot search runs");
1038
1039 assert!(!runs.is_empty());
1040
1041 let search = SearchRuns::new()
1043 .experiment_ids(vec![run.info.experiment_id.clone()])
1044 .view(ViewType::All)
1045 .build();
1046
1047 let runs = client
1048 .search_runs(search)
1049 .await
1050 .expect("BUG: Cannot search runs");
1051
1052 assert!(!runs.is_empty());
1053 }
1054
1055 #[rstest]
1056 #[tokio::test]
1057 async fn test_run_create(experiment_name: String, run_name: String, client: Client) {
1058 let experiment_id = client
1059 .create_experiment(&experiment_name, vec![])
1060 .await
1061 .expect("BUG: Cannot create experiment");
1062
1063 let create = CreateRun::new()
1064 .run_name(&run_name)
1065 .experiment_id(&experiment_id)
1066 .build();
1067
1068 let run = client
1069 .create_run(create)
1070 .await
1071 .expect("BUG: Cannot create run");
1072
1073 let run1 = client
1074 .get_run(&run.info.run_id)
1075 .await
1076 .expect("BUG: Cannot get run");
1077
1078 assert_eq!(run, run1);
1079 }
1080
1081 #[rstest]
1082 #[tokio::test]
1083 async fn test_run_delete(experiment_name: String, run_name: String, client: Client) {
1084 let experiment_id = client
1085 .create_experiment(&experiment_name, vec![])
1086 .await
1087 .expect("BUG: Cannot create experiment");
1088
1089 let create = CreateRun::new()
1090 .run_name(&run_name)
1091 .experiment_id(&experiment_id)
1092 .build();
1093
1094 let run = client
1095 .create_run(create)
1096 .await
1097 .expect("BUG: Cannot create run");
1098
1099 client
1100 .delete_run(&run.info.run_id)
1101 .await
1102 .expect("BUG: Cannot delete run");
1103
1104 client
1105 .delete_run(&run.info.run_id)
1106 .await
1107 .expect("BUG: Can delete non-existing run");
1108 }
1109
1110 #[rstest]
1111 #[tokio::test]
1112 async fn test_run_update_data(experiment_name: String, run_name: String, client: Client) {
1113 let experiment_id = client
1114 .create_experiment(&experiment_name, vec![])
1115 .await
1116 .expect("BUG: Cannot create experiment");
1117
1118 let create = CreateRun::new()
1119 .run_name(&run_name)
1120 .experiment_id(&experiment_id)
1121 .build();
1122
1123 let run = client
1124 .create_run(create)
1125 .await
1126 .expect("BUG: Cannot create run");
1127
1128 let mut update = RunData::new()
1129 .metrics(vec![Metric::new().key("m").value(37.).step(0).build()])
1130 .params(vec![KeyValue::new().key("p").value("42").build()])
1131 .tags(vec![KeyValue::new().key("t").value("73").build()])
1132 .build();
1133
1134 client
1135 .add_run_meta(&run.info.run_id, update.clone())
1136 .await
1137 .expect("BUG: Cannot update run data");
1138
1139 update.tags.insert(
1140 0,
1141 KeyValue::new()
1142 .key("mlflow.runName")
1143 .value(run_name)
1144 .build(),
1145 );
1146
1147 let run1 = client
1148 .get_run(&run.info.run_id)
1149 .await
1150 .expect("BUG: Cannot get run");
1151
1152 assert_eq!(run1.data, update);
1153 }
1154
1155 #[rstest]
1156 #[tokio::test]
1157 async fn test_run_add_inputs(experiment_name: String, run_name: String, client: Client) {
1158 let experiment_id = client
1159 .create_experiment(&experiment_name, vec![])
1160 .await
1161 .expect("BUG: Cannot create experiment");
1162
1163 let create = CreateRun::new()
1164 .run_name(&run_name)
1165 .experiment_id(&experiment_id)
1166 .build();
1167
1168 let run = client
1169 .create_run(create)
1170 .await
1171 .expect("BUG: Cannot create run");
1172
1173 let input = DataSetInput::new()
1174 .tags(vec![KeyValue::new().key("a").value("x").build()])
1175 .dataset(
1176 DataSet::new()
1177 .name("kokot")
1178 .digest("123")
1179 .source_type("kokot1")
1180 .source("s3")
1181 .schema("{}")
1182 .profile("{\"rows\": 22}")
1183 .build(),
1184 )
1185 .build();
1186
1187 client
1188 .add_run_inputs(&run.info.run_id, vec![input.clone()])
1189 .await
1190 .expect("BUG: Unable to add inputs to run");
1191
1192 let run1 = client
1193 .get_run(&run.info.run_id)
1194 .await
1195 .expect("BUG: Cannot get run");
1196
1197 assert_eq!(
1198 run1.inputs,
1199 RunInputs {
1200 inputs: vec![input]
1201 }
1202 );
1203 }
1204
1205 #[fixture]
1206 async fn run(experiment_name: String, run_name: String, client: Client) -> Run {
1207 let experiment_id = client
1208 .create_experiment(&experiment_name, vec![])
1209 .await
1210 .expect("BUG: Cannot create experiment");
1211
1212 let create = CreateRun::new()
1213 .run_name(&run_name)
1214 .experiment_id(&experiment_id)
1215 .build();
1216
1217 client
1218 .create_run(create)
1219 .await
1220 .expect("BUG: Cannot create run")
1221 }
1222
1223 #[rstest]
1224 #[tokio::test]
1225 #[awt]
1226 async fn test_run_update(client: Client, #[future] run: Run) {
1227 let end_time_ms: u64 = 1733565663000;
1228
1229 let update = UpdateRun::new()
1230 .run_id(&run.info.run_id)
1231 .status(RunStatus::Killed)
1232 .end_time(end_time_ms as i64)
1233 .experiment_id(&run.info.experiment_id)
1234 .build();
1235
1236 client
1237 .update_run(update)
1238 .await
1239 .expect("BUG: Unable to update run");
1240
1241 let run1 = client
1242 .get_run(&run.info.run_id)
1243 .await
1244 .expect("BUG: Unable to fetch run");
1245
1246 assert_eq!(run1.info.end_time, Some(end_time_ms));
1247 assert_eq!(run1.info.status, RunStatus::Killed);
1248 }
1249
1250 #[rstest]
1251 #[tokio::test]
1252 #[awt]
1253 #[traced_test]
1254 async fn test_artifacts(client: Client, #[future] run: Run) {
1255 let artifact = Artifact {
1256 experiment_id: run.info.experiment_id.clone(),
1257 run_id: run.info.run_id.clone(),
1258 path: PathBuf::from("abc/lock"),
1259 };
1260
1261 let mut artifact1 = artifact.clone();
1262 artifact1.path = PathBuf::from("abc/lock1");
1263
1264 let artifacts = vec![
1265 UploadArtifact {
1266 local: "Cargo.lock".into(),
1267 remote: artifact.clone(),
1268 },
1269 UploadArtifact {
1270 local: "Cargo.lock".into(),
1271 remote: artifact1.clone(),
1272 },
1273 ];
1274
1275 client
1276 .upload_artifacts(artifacts)
1277 .await
1278 .expect("BUG: Unable to upload artifacts");
1279
1280 let list = client
1281 .list_run_artifacts(&run.info.run_id)
1282 .await
1283 .expect("BUG: Unable to list run artifacts");
1284
1285 assert_eq!(list.files.len(), 2);
1286
1287 let download = client
1288 .prepare_run_download(&run.info.run_id, "/tmp")
1289 .await
1290 .expect("BUG: Unable to prepare run download");
1291
1292 let mut artifacts = client
1293 .download_run_artifacts(download)
1294 .await
1295 .expect("BUG: Cannot download artifact");
1296
1297 artifacts.paths.sort();
1298
1299 assert_eq!(
1300 artifacts.paths,
1301 vec![
1302 PathBuf::from("/tmp/abc/lock"),
1303 PathBuf::from("/tmp/abc/lock1"),
1304 ]
1305 );
1306 }
1307
1308 #[rstest]
1309 #[tokio::test]
1310 #[awt]
1311 async fn test_artifacts_cache(mut client: Client, #[future] run: Run) {
1312 let artifact = Artifact {
1314 experiment_id: run.info.experiment_id.clone(),
1315 run_id: run.info.run_id.clone(),
1316 path: PathBuf::from("run.json"),
1317 };
1318
1319 client
1320 .upload_json_artifact(&run, &artifact)
1321 .await
1322 .expect("BUG: Cannot upload artifact");
1323
1324 client.download.cache_local_artifacts = true;
1327
1328 let filedir = PathBuf::from(format!("/tmp/{}/", run.info.run_id));
1329 let filepath = filedir.join("run.json");
1330
1331 create_dir_all(&filedir)
1332 .await
1333 .expect("BUG: Unable to create destdir");
1334
1335 let file = File::create(&filepath)
1336 .await
1337 .expect("BUG: Unable to create dummy file");
1338
1339 let download = client
1340 .prepare_run_download(&run.info.run_id, filedir)
1341 .await
1342 .expect("BUG: Unable to prepare run download");
1343
1344 let artifacts = client
1345 .download_run_artifacts(download)
1346 .await
1347 .expect("BUG: Cannot download artifact");
1348
1349 assert_eq!(artifacts.paths, vec![filepath]);
1350
1351 assert_eq!(
1352 file.metadata()
1353 .await
1354 .expect("BUG: Unable to get file metadata")
1355 .len(),
1356 0
1357 );
1358 }
1359
1360 #[rstest]
1361 #[tokio::test]
1362 #[awt]
1363 async fn test_json_artifact(client: Client, #[future] run: Run) {
1364 let artifact = Artifact {
1365 experiment_id: run.info.experiment_id.clone(),
1366 run_id: run.info.run_id.clone(),
1367 path: PathBuf::from("run.json"),
1368 };
1369
1370 client
1371 .upload_json_artifact(&run, &artifact)
1372 .await
1373 .expect("BUG: Cannot upload artifact");
1374
1375 let run1 = client
1376 .download_json_artifact::<Run>(&artifact)
1377 .await
1378 .expect("BUG: Cannot download artifact");
1379
1380 assert_eq!(run, run1);
1381 }
1382
1383 #[rstest]
1384 #[tokio::test]
1385 #[awt]
1386 async fn test_registered_model(client: Client, model_name: String) {
1387 let registered = client
1388 .register_model(
1389 RegisterModel::new()
1390 .name(&model_name)
1391 .description("")
1392 .build(),
1393 )
1394 .await
1395 .expect("BUG: Unable to register model");
1396
1397 let registered1 = client
1398 .get_registered_model(&model_name)
1399 .await
1400 .expect("BUG: Cannot get registered model");
1401
1402 assert_eq!(registered.name, registered1.name);
1403 assert_eq!(
1404 registered.creation_timestamp,
1405 registered1.creation_timestamp
1406 );
1407
1408 let _ = client
1409 .delete_registered_model(&model_name)
1410 .await
1411 .expect("BUG: Unable to delete registered model");
1412 }
1413
1414 #[rstest]
1415 #[tokio::test]
1416 #[awt]
1417 async fn test_search_registered_models(client: Client, #[future] run: Run, model_name: String) {
1418 let _registered = client
1419 .register_model(
1420 RegisterModel::new()
1421 .name(&model_name)
1422 .description("yep")
1423 .tags(vec![KeyValue::new()
1424 .key("kokot")
1425 .value(&model_name)
1426 .build()])
1427 .build(),
1428 )
1429 .await
1430 .expect("BUG: Unable to register model");
1431
1432 let create = CreateModelVersion::new()
1433 .registered_model_name(&model_name)
1434 .artifacts_url("s3:///kokot")
1435 .run_id(&run.info.run_id)
1436 .description("xxx")
1437 .build();
1438
1439 client
1440 .create_model_version(create)
1441 .await
1442 .expect("BUG: Cannot create model version");
1443
1444 let mut models = client
1445 .search_registered_models(&format!("tags.kokot like '({}|abc)'", model_name))
1446 .await
1447 .expect("BUG: Cannot search registered models");
1448
1449 let mut model = models
1450 .pop()
1451 .expect("BUG: We must get at least one registred model");
1452
1453 let latest = model
1454 .latest_versions
1455 .pop()
1456 .expect("BUG: Model must have at least one latest version");
1457
1458 assert_eq!(latest.run_id, run.info.run_id);
1459 }
1460
1461 #[rstest]
1462 #[tokio::test]
1463 #[awt]
1464 async fn test_add_model_version(
1465 client: Client,
1466 #[future] run: Run,
1467 #[future]
1468 #[from(run)]
1469 run1: Run,
1470 model_name: String,
1471 ) {
1472 let _registered = client
1473 .register_model(
1474 RegisterModel::new()
1475 .name(&model_name)
1476 .description("yep")
1477 .build(),
1478 )
1479 .await
1480 .expect("BUG: Unable to register model");
1481
1482 let create = CreateModelVersion::new()
1483 .registered_model_name(&model_name)
1484 .artifacts_url("s3:///kokot")
1485 .run_id(&run.info.run_id)
1486 .description("xxx")
1487 .build();
1488
1489 let mut create1 = create.clone();
1490 create1.run_id = run1.info.run_id;
1491
1492 client
1493 .create_model_version(create)
1494 .await
1495 .expect("BUG: Cannot create model version");
1496
1497 let version = client
1498 .create_model_version(create1)
1499 .await
1500 .expect("BUG: Cannot create model version");
1501
1502 assert_eq!(version.current_stage, "None");
1503
1504 let transition = TransitionModelVersionStage::new()
1505 .name(&model_name)
1506 .version(version.version)
1507 .stage(ModelVersionStage::Production)
1508 .archive_existing_versions(false)
1509 .build();
1510
1511 let version1 = client
1512 .transition_model_version_stage(transition)
1513 .await
1514 .expect("BUG: Cannot transition model stage");
1515
1516 assert_eq!(version1.current_stage, "Production");
1517 }
1518}