trs_mlflow/
lib.rs

//! This crate contains an unofficial asynchronous MLflow client which uses the
//! MLflow 2.0 REST API.
3//!
4//! The whole API description can be found here: <https://mlflow.org/docs/latest/rest-api.html>
5//!
6//! Usage example:
7//! ```rust, no_run
8//! # #[tokio::main]
9//! # async fn main() {
10//! use trs_mlflow::{run::CreateRun, Client};
11//!
12//! let client = Client::new("http://localhost:5000/api");
13//!
14//! let experiment_id = client
15//!     .create_experiment(&"foo", vec![])
16//!     .await
17//!     .expect("BUG: Cannot create experiment");
18//!
19//! let create = CreateRun::new()
20//!     .run_name("bar")
21//!     .experiment_id(&experiment_id)
22//!     .build();
23//!
24//! let run = client
25//!     .create_run(create)
26//!     .await
27//!     .expect("BUG: Cannot create run");
28//!
29//! # }
30//! ```
31//!
32//!
33//! # Testing
34//!
//! To run the tests you need an MLflow server running locally, which can be
//! started with this command:
37//! ```sh
38//! ./server/run.sh
39//! ```
40//!
//! This command creates a new virtualenv, installs the server into it and then
//! starts it.
42//!
43//! # Disclaimer
44//!
//! This crate is very much a work in progress, as I'm still trying to figure out
//! how to wrap some API methods.
47pub mod artifact;
48pub mod config;
49mod error;
50pub mod experiment;
51mod http;
52pub mod modelversion;
53pub mod registered;
54pub mod run;
55
56use crate::config::{BasicAuth, ClientConfig, DownloadConfig, MlflowConfig};
57use crate::http::Response;
58use crate::{artifact::*, experiment::*, modelversion::*, registered::*, run::*};
59use anyhow::{anyhow, Context, Result};
60use backon::{ExponentialBuilder, Retryable};
61use reqwest::{Method, RequestBuilder};
62use serde::{de::DeserializeOwned, Deserialize, Serialize};
63use serde_json::json;
64use std::collections::HashMap;
65use std::path::Path;
66use tokio::{
67    fs,
68    fs::{create_dir_all, File},
69    io::AsyncWriteExt,
70    task::JoinSet,
71};
72use tracing::{debug, info, warn};
73
/// Low-level asynchronous client which mirrors the MLflow 2.0 REST API.
///
/// Every request goes through [`reqwest::Client`] to an endpoint below
/// `mlflow.urlbase`, with basic auth attached when `mlflow.auth` is set.
/// The whole client is cloned into each worker task spawned by the bulk
/// artifact download/upload methods.
///
/// `Debug` is derived, so credentials stored in `mlflow.auth` must redact
/// themselves; the `test_client_debug` test checks that they do not show up
/// in the debug output.
#[derive(Clone, Debug)]
pub struct Client {
    /// HTTP client used for every request sent to the MLflow server.
    pub client: reqwest::Client,
    /// MLflow connection settings: the API root URL (`urlbase`) and optional
    /// basic auth credentials (`auth`).
    pub mlflow: MlflowConfig,
    /// Settings for bulk artifact transfers (worker count, blacklist, local
    /// caching, file size checks). The worker count is also used for uploads.
    pub download: DownloadConfig,
}
83
/// Envelope types used to deserialize MLflow JSON responses.
///
/// Most endpoints wrap their payload in a single-key object (for example
/// `{"run": {...}}`). These structs model that outer object so callers can
/// deserialize it and then unwrap the inner value with `.map(|v| v.field)`.
pub(crate) mod bs {
    use serde::Deserialize;

    /// Deserializes from a JSON object while ignoring all of its fields; used
    /// where only the success of a request matters (artifact uploads).
    #[derive(Deserialize)]
    pub struct Dummy {}

    /// `{"run": ...}` envelope.
    #[derive(Deserialize)]
    pub struct Run {
        pub run: super::Run,
    }

    /// `{"runs": [...]}` envelope.
    #[derive(Deserialize)]
    pub struct Runs {
        // Defaults to an empty list when the key is missing from the response
        // (presumably when no run matched) — verify against the server.
        #[serde(default)]
        pub runs: Vec<super::Run>,
    }

    /// `{"run_info": ...}` envelope.
    #[derive(Deserialize)]
    pub struct RunInfo {
        pub run_info: super::RunInfo,
    }

    /// `{"model_version": ...}` envelope.
    #[derive(Deserialize)]
    pub struct ModelVersion {
        pub model_version: super::ModelVersion,
    }

    /// `{"registered_model": ...}` envelope.
    #[derive(Deserialize)]
    pub struct RegisteredModel {
        pub registered_model: super::RegisteredModel,
    }

    /// `{"registered_models": [...]}` envelope.
    #[derive(Deserialize)]
    pub struct RegisteredModels {
        // Defaults to an empty list when the key is missing from the response.
        #[serde(default)]
        pub registered_models: Vec<super::RegisteredModel>,
    }
}
122
123// TODO(ast): add retry config
124impl Client {
125    /// Creates a new client which uses a default [`reqwest::Client`] without
126    /// additional settings.
127    pub fn new(urlbase: impl AsRef<str>) -> Self {
128        Self {
129            client: reqwest::Client::new(),
130            mlflow: MlflowConfig {
131                urlbase: urlbase.as_ref().to_owned(),
132                auth: None,
133            },
134            download: DownloadConfig {
135                retry: false,
136                tasks: 16,
137                blacklist: Default::default(),
138                cache_local_artifacts: false,
139                file_size_check: true,
140            },
141        }
142    }
143
144    /// Creates a new client using provided configuration.
145    pub fn new_from_config(config: ClientConfig) -> Self {
146        Self {
147            client: reqwest::Client::new(),
148            mlflow: config.mlflow,
149            download: config.download,
150        }
151    }
152
153    /// Sets basic auth which will be send to mlflow server.
154    pub fn with_auth(mut self, auth: BasicAuth) -> Self {
155        self.mlflow.auth = Some(auth);
156        self
157    }
158
159    fn make_url(&self, path: impl AsRef<str>) -> String {
160        format!("{}/{}", self.mlflow.urlbase, path.as_ref())
161    }
162
163    /// Send request to mlflow and handles returned errors.
164    async fn send_json_request_no_retry<P: Serialize>(
165        &self,
166        method: Method,
167        path: impl AsRef<str>,
168        payload: &P,
169    ) -> Result<Response> {
170        let mut request = self
171            .client
172            .request(method, self.make_url(path.as_ref()))
173            .json(&payload);
174
175        request = self.assign_basic_auth(request);
176
177        Response(request.send().await?).error_for_status().await
178    }
179
180    /// Send request to mlflow and handles returned errors.
181    async fn send_request<P: Serialize>(
182        &self,
183        method: Method,
184        path: impl AsRef<str>,
185        payload: P,
186    ) -> Result<Response> {
187        let future = || async {
188            self.send_json_request_no_retry(method.clone(), path.as_ref(), &payload)
189                .await
190        };
191
192        future
193            .retry(ExponentialBuilder::default())
194            .notify(|error, duration| {
195                warn!("Retrying in {:?}, got an error: {:?}", error, duration);
196            })
197            .await
198            .with_context(|| format!("Cannot send request to: {:?}", path.as_ref()))
199    }
200
201    /// Assigns basic auth if client uses it into provided request builder.
202    fn assign_basic_auth(&self, request: RequestBuilder) -> RequestBuilder {
203        if let Some(auth) = self.mlflow.auth.as_ref() {
204            request.basic_auth(
205                auth.user.expose_secret().to_owned(),
206                Some(auth.password.expose_secret().to_owned()),
207            )
208        } else {
209            request
210        }
211    }
212
213    /// Create an experiment with a name. Returns the ID of the newly created
214    /// experiment. Validates that another experiment with the same name does
215    /// not already exist and fails if another experiment with the same name
216    /// already exists.
217    pub async fn create_experiment(&self, name: &str, tags: Vec<KeyValue>) -> Result<String> {
218        let path = "2.0/mlflow/experiments/create";
219        let request = CreateExperiment {
220            name: name.to_owned(),
221            tags,
222        };
223
224        self.send_request(Method::POST, path, request)
225            .await?
226            .json::<ExperimentId>()
227            .await
228            .map(|response| response.experiment_id)
229    }
230
231    /// Mark an experiment and associated metadata, runs, metrics, params, and
232    /// tags for deletion. If the experiment uses FileStore, artifacts
233    /// associated with experiment are also deleted.
234    pub async fn delete_experiment(&self, experiment_id: &str) -> Result<()> {
235        let path = "2.0/mlflow/experiments/delete";
236        let request = ExperimentId {
237            experiment_id: experiment_id.to_owned(),
238        };
239
240        self.send_request(Method::POST, path, request)
241            .await
242            .map(|_| ())
243    }
244
245    /// Get metadata for an experiment. This method works on deleted
246    /// experiments.
247    pub async fn get_experiment(&self, experiment_id: &str) -> Result<Experiment> {
248        let path = "2.0/mlflow/experiments/get";
249        let request = ExperimentId {
250            experiment_id: experiment_id.to_owned(),
251        };
252
253        #[derive(Deserialize)]
254        struct Response {
255            experiment: Experiment,
256        }
257
258        self.send_request(Method::GET, path, request)
259            .await?
260            .json::<Response>()
261            .await
262            .map(|response| response.experiment)
263    }
264
265    /// Get metadata for an experiment.
266    ///
267    /// This endpoint will return deleted experiments, but prefers the active
268    /// experiment if an active and deleted experiment share the same name. If
269    /// multiple deleted experiments share the same name, the API will return
270    /// one of them.
271    pub async fn get_experiment_by_name(&self, name: &str) -> Result<Experiment> {
272        let path = "2.0/mlflow/experiments/get-by-name";
273        let request = json!({
274            "experiment_name": name
275        });
276
277        #[derive(Deserialize)]
278        struct Response {
279            experiment: Experiment,
280        }
281
282        self.send_request(Method::GET, path, request)
283            .await?
284            .json::<Response>()
285            .await
286            .map(|response| response.experiment)
287    }
288
289    /// Loads experiment with specified name or create one with that name and without
290    /// tags.
291    pub async fn get_or_create_experiment(
292        &self,
293        name: &str,
294        tags: Vec<KeyValue>,
295    ) -> Result<Experiment> {
296        let path = "2.0/mlflow/experiments/get-by-name";
297        let request = json!({
298            "experiment_name": name
299        });
300
301        #[derive(Deserialize)]
302        struct Response {
303            experiment: Experiment,
304        }
305
306        let result = self
307            .send_json_request_no_retry(Method::GET, path, &request)
308            .await;
309
310        match result {
311            Err(_) => {
312                let experiment_id = self.create_experiment(name, tags).await?;
313                self.get_experiment(&experiment_id).await
314            }
315            Ok(response) => response
316                .json::<Response>()
317                .await
318                .map(|response| response.experiment),
319        }
320    }
321
322    /// Create a new run within an experiment. A run is usually a single
323    /// execution of a machine learning or data ETL pipeline. MLflow uses runs
324    /// to track Param, Metric, and RunTag associated with a single execution.
325    pub async fn create_run(&self, create: CreateRun) -> Result<Run> {
326        let path = "2.0/mlflow/runs/create";
327
328        self.send_request(Method::POST, path, create)
329            .await?
330            .json::<bs::Run>()
331            .await
332            .map(|value| value.run)
333    }
334
335    /// Get metadata, metrics, params, and tags for a run. In the case where
336    /// multiple metrics with the same key are logged for a run, return only the
337    /// value with the latest timestamp. If there are multiple values with the
338    /// latest timestamp, return the maximum of these values.
339    pub async fn get_run(&self, run_id: &str) -> Result<Run> {
340        let path = "2.0/mlflow/runs/get";
341        let request = Request {
342            run_id: run_id.to_string(),
343        };
344
345        #[derive(Serialize)]
346        struct Request {
347            run_id: String,
348        }
349
350        self.send_request(Method::GET, path, request)
351            .await?
352            .json::<bs::Run>()
353            .await
354            .map(|value| value.run)
355    }
356
357    /// Search for runs that satisfy expressions. Search expressions can use
358    /// Metric and Param keys.
359    pub async fn search_runs(&self, search: SearchRuns) -> Result<Vec<Run>> {
360        let path = "2.0/mlflow/runs/search";
361
362        self.send_request(Method::POST, path, search)
363            .await?
364            .json::<bs::Runs>()
365            .await
366            .map(|response| response.runs)
367    }
368
369    /// Update run metadata.
370    pub async fn update_run(&self, update: UpdateRun) -> Result<RunInfo> {
371        let path = "2.0/mlflow/runs/update";
372
373        self.send_request(Method::POST, path, update)
374            .await?
375            .json::<bs::RunInfo>()
376            .await
377            .map(|value| value.run_info)
378    }
379
380    /// Log a batch of metrics, params, and tags for a run. If any data failed
381    /// to be persisted, the server will respond with an error (non-200 status
382    /// code). In case of error (due to internal server error or an invalid
383    /// request), partial data may be written.
384    pub async fn add_run_meta(&self, run_id: &str, meta: RunData) -> Result<()> {
385        let path = "2.0/mlflow/runs/log-batch";
386        let request = json!({
387            "run_id": run_id.to_owned(),
388            "metrics": meta.metrics,
389            "params": meta.params,
390            "tags": meta.tags
391        });
392
393        self.send_request(Method::POST, path, request)
394            .await
395            .map(|_| ())
396    }
397
398    pub async fn delete_run(&self, run_id: &str) -> Result<()> {
399        let path = "2.0/mlflow/runs/delete";
400
401        self.send_request(Method::POST, path, json!({"run_id": run_id}))
402            .await
403            .map(|_| ())
404    }
405
406    pub async fn add_run_inputs(&self, run_id: &str, inputs: Vec<DataSetInput>) -> Result<()> {
407        let path = "2.0/mlflow/runs/log-inputs";
408        let request = json!({
409            "run_id": run_id.to_owned(),
410            "datasets": inputs,
411        });
412
413        self.send_request(Method::POST, path, request)
414            .await
415            .map(|_| ())
416    }
417
418    pub async fn delete_registered_model(&self, name: &str) -> Result<()> {
419        let path = "2.0/mlflow/registered-models/delete";
420        let request = json!({"name": name.to_owned()});
421
422        self.send_request(Method::DELETE, path, request)
423            .await
424            .map(|_| ())
425    }
426
427    pub async fn register_model(&self, request: RegisterModel) -> Result<RegisteredModel> {
428        let path = "2.0/mlflow/registered-models/create";
429
430        self.send_request(Method::POST, path, request)
431            .await?
432            .json::<bs::RegisteredModel>()
433            .await
434            .map(|response| response.registered_model)
435    }
436
437    pub async fn get_registered_model(&self, name: &str) -> Result<RegisteredModel> {
438        let path = "2.0/mlflow/registered-models/get";
439        let request = json!({"name": name.to_owned()});
440
441        self.send_request(Method::GET, path, request)
442            .await?
443            .json::<bs::RegisteredModel>()
444            .await
445            .map(|response| response.registered_model)
446    }
447
448    pub async fn search_registered_models(&self, filter: &str) -> Result<Vec<RegisteredModel>> {
449        let path = "2.0/mlflow/registered-models/search";
450        let request = json!({"filter": filter.to_owned()});
451
452        self.send_request(Method::GET, path, request)
453            .await?
454            .json::<bs::RegisteredModels>()
455            .await
456            .map(|response| response.registered_models)
457    }
458
459    pub async fn create_model_version(&self, request: CreateModelVersion) -> Result<ModelVersion> {
460        let path = "2.0/mlflow/model-versions/create";
461
462        self.send_request(Method::POST, path, request)
463            .await?
464            .json::<bs::ModelVersion>()
465            .await
466            .map(|response| response.model_version)
467    }
468
469    pub async fn transition_model_version_stage(
470        &self,
471        request: TransitionModelVersionStage,
472    ) -> Result<ModelVersion> {
473        let path = "2.0/mlflow/model-versions/transition-stage";
474
475        self.send_request(Method::POST, path, request)
476            .await?
477            .json::<bs::ModelVersion>()
478            .await
479            .map(|response| response.model_version)
480    }
481
482    /// Lists all artifacts (recursively) for a run.
483    pub async fn list_run_artifacts(&self, run_id: &str) -> Result<ListedArtifacts> {
484        let path = "2.0/mlflow/artifacts/list";
485        let request = json!({
486            "run_id": run_id
487        });
488
489        let mut response = self
490            .send_request(Method::GET, path, request)
491            .await?
492            .json::<ListedArtifacts>()
493            .await?;
494
495        let mut tocheck = response.files.drain(..).collect::<Vec<_>>();
496
497        while let Some(info) = tocheck.pop() {
498            // file is not a directory, there is nothing to do
499            if !info.is_dir {
500                response.files.push(info);
501                continue;
502            }
503
504            debug!("Listing directory: {:?}", info.path);
505
506            let request = json!({
507                "run_id": run_id,
508                "path": info.path,
509            });
510
511            let mut files = self
512                .send_request(Method::GET, path, request)
513                .await?
514                .json::<ListedArtifacts>()
515                .await?
516                .files;
517
518            tocheck.append(&mut files);
519        }
520
521        Ok(response)
522    }
523
524    /// Downloads multiple artifacts from multiple runs.
525    pub async fn download_artifacts(
526        &self,
527        downloads: Vec<DownloadRunArtifacts>,
528    ) -> Result<Vec<RunArtifacts>> {
529        // Create DownloadRunArtifact for earch artifact and each download
530        // command
531        let mut downloads = downloads
532            .into_iter()
533            .flat_map(|downloads| downloads.as_single_downloads())
534            .collect::<Vec<_>>();
535
536        // Filter all blacklisted artifacts
537        downloads.retain(|download| {
538            if self.download.is_blacklisted(&download.file) {
539                debug!("Filtering out blacklisted file: {}", download.file);
540                false
541            } else {
542                true
543            }
544        });
545
546        let chunk_size = if downloads.len() <= self.download.tasks {
547            1
548        } else {
549            downloads.len() / self.download.tasks
550        };
551
552        // Creates batch of artifacts for each download worker:
553        //    Artifacts = [1, 2, 3, 4, 5], Tasks=2
554        //    Task1 = [1, 2, 3]
555        //    Task2 = [4, 5]
556        //    Batches = [Task1, Task2]
557        let batches = downloads.chunks(chunk_size).map(|chunks| chunks.to_vec());
558
559        let mut tasks = JoinSet::<Result<Vec<_>>>::new();
560
561        for (idx, batch) in batches.enumerate() {
562            let client = self.clone();
563
564            tasks.spawn(async move {
565                debug!("Worker {} has got #{} files", idx, batch.len());
566
567                let mut completed = Vec::with_capacity(batch.len());
568
569                for download in batch {
570                    debug!("Starting {:?}", download);
571                    let size = client.download_artifact(&download).await?;
572                    debug!(
573                        "Downloaded artifact {} [size={}, expected={}]",
574                        download.file, size, download.expected_size
575                    );
576                    completed.push(download);
577                }
578
579                Ok(completed)
580            });
581        }
582
583        let mut downloaded = HashMap::new();
584
585        while let Some(result) = tasks.join_next().await {
586            // NOTE: we all tasks are aborted when error occurs so we don't need
587            //       to stop them explicitly
588            let completed = result
589                .context("We cannot join download task")?
590                .context("Download task has returned an error")?;
591
592            // now, we must assign each downloaded artifact into a correct
593            // run
594            for download in completed {
595                downloaded
596                    .entry(download.run_id.clone())
597                    .and_modify(|artifacts: &mut RunArtifacts| {
598                        artifacts.paths.push(download.path());
599                    })
600                    .or_insert_with(|| RunArtifacts {
601                        paths: vec![download.path()],
602                        experiment_id: download.experiment_id,
603                        run_id: download.run_id,
604                        root: download.destination,
605                    });
606            }
607        }
608
609        Ok(downloaded.into_values().collect())
610    }
611
612    /// Prepares download of artifacts from specified run.
613    pub async fn prepare_run_download(
614        &self,
615        run_id: &str,
616        directory: impl AsRef<Path>,
617    ) -> Result<DownloadRunArtifacts> {
618        // Get information about files inside of specified run
619        let list = self
620            .list_run_artifacts(run_id)
621            .await
622            .context("Cannot list run artifacts")?;
623
624        let run = self
625            .get_run(run_id)
626            .await
627            .context("Cannot get run from mlflow")?;
628
629        Ok(DownloadRunArtifacts::new_from_run(directory, run, list))
630    }
631
632    /// Downloads all artifacts for one run.
633    pub async fn download_run_artifacts(
634        &self,
635        download: DownloadRunArtifacts,
636    ) -> Result<RunArtifacts> {
637        debug!("Starting: {:#?}", download);
638
639        self.download_artifacts(vec![download.clone()])
640            .await
641            .with_context(|| format!("Cannot download artifacts for {:#?}", download))?
642            .pop()
643            .context("BUG: We've not received any RunArtifacts")
644    }
645
646    /// Uploads json artifact using mlflow tracking server.
647    pub async fn upload_json_artifact_no_retry(
648        &self,
649        data: &impl Serialize,
650        artifact: &Artifact,
651    ) -> Result<()> {
652        let path = format!(
653            "2.0/mlflow-artifacts/artifacts/{}/{}/artifacts/{}",
654            artifact.experiment_id,
655            artifact.run_id,
656            artifact.path.to_string_lossy()
657        );
658
659        let mut request = self.client.put(self.make_url(path)).json(data);
660
661        request = self.assign_basic_auth(request);
662
663        Response(request.send().await?)
664            .error_for_status()
665            .await?
666            .json::<bs::Dummy>()
667            .await
668            .map(|_| ())
669    }
670
671    /// Uploads json artifact using mlflow tracking server.
672    pub async fn upload_json_artifact(
673        &self,
674        data: &impl Serialize,
675        artifact: &Artifact,
676    ) -> Result<()> {
677        let future = || async { self.upload_json_artifact_no_retry(data, artifact).await };
678
679        future
680            .retry(ExponentialBuilder::default())
681            .notify(|error, duration| {
682                warn!(
683                    "Retrying upload in {:?}, got an error: {:?}",
684                    error, duration
685                );
686            })
687            .await
688            .with_context(|| format!("Cannot upload: {:?}", artifact))
689    }
690
691    /// Downloads artifact and tries it to parse as a json.
692    pub async fn download_json_artifact_no_retry<T: DeserializeOwned>(
693        &self,
694        artifact: &Artifact,
695    ) -> Result<T> {
696        let path = format!(
697            "2.0/mlflow-artifacts/artifacts/{}/{}/artifacts/{}",
698            artifact.experiment_id,
699            artifact.run_id,
700            artifact.path.to_string_lossy()
701        );
702
703        let mut request = self.client.get(self.make_url(path));
704
705        request = self.assign_basic_auth(request);
706
707        Response(request.send().await?)
708            .error_for_status()
709            .await?
710            .json::<T>()
711            .await
712    }
713
714    /// Downloads artifact and tries it to parse as a json.
715    pub async fn download_json_artifact<T: DeserializeOwned>(
716        &self,
717        artifact: &Artifact,
718    ) -> Result<T> {
719        let future = || async { self.download_json_artifact_no_retry(artifact).await };
720
721        future
722            .retry(ExponentialBuilder::default())
723            .notify(|error, duration| {
724                warn!(
725                    "Retrying download in {:?}, got an error: {:?}",
726                    error, duration
727                );
728            })
729            .await
730            .with_context(|| format!("Cannot download: {:?}", artifact))
731    }
732
733    /// Uploads artifact using mlflow tracking server.
734    pub async fn upload_artifact_no_retry(
735        &self,
736        source: impl AsRef<Path>,
737        artifact: &Artifact,
738    ) -> Result<()> {
739        let path = format!(
740            "2.0/mlflow-artifacts/artifacts/{}/{}/artifacts/{}",
741            artifact.experiment_id,
742            artifact.run_id,
743            artifact.path.to_string_lossy()
744        );
745
746        let file = File::open(source.as_ref())
747            .await
748            .with_context(|| format!("Cannot open artifact file: {:?}", source.as_ref()))?;
749
750        let mut request = self.client.put(self.make_url(path)).body(file);
751
752        request = self.assign_basic_auth(request);
753
754        Response(request.send().await?)
755            .error_for_status()
756            .await?
757            .json::<bs::Dummy>()
758            .await
759            .map(|_| ())
760    }
761
762    /// Uploads multiple artifacts at once.
763    pub async fn upload_artifacts(&self, uploads: Vec<UploadArtifact>) -> Result<()> {
764        let chunk_size = if uploads.len() <= self.download.tasks {
765            1
766        } else {
767            uploads.len() / self.download.tasks
768        };
769
770        let batches = uploads.chunks(chunk_size).map(|chunks| chunks.to_vec());
771
772        let mut tasks = JoinSet::<Result<()>>::new();
773
774        for (idx, batch) in batches.enumerate() {
775            let client = self.clone();
776
777            tasks.spawn(async move {
778                debug!("Worker {} has got #{} files", idx, batch.len());
779
780                for upload in batch {
781                    debug!("Starting {:?}", upload);
782                    client
783                        .upload_artifact(&upload.local, &upload.remote)
784                        .await?;
785                }
786
787                Ok(())
788            });
789        }
790
791        while let Some(result) = tasks.join_next().await {
792            result
793                .context("We cannot join download task")?
794                .context("Upload task has returned an error")?;
795        }
796
797        Ok(())
798    }
799
800    /// Uploads artifact using mlflow tracking server.
801    pub async fn upload_artifact(
802        &self,
803        source: impl AsRef<Path>,
804        artifact: &Artifact,
805    ) -> Result<()> {
806        let future = || async {
807            self.upload_artifact_no_retry(source.as_ref(), artifact)
808                .await
809        };
810
811        future
812            .retry(ExponentialBuilder::default())
813            .notify(|error, duration| {
814                warn!(
815                    "Retrying upload in {:?}, got an error: {:?}",
816                    error, duration
817                );
818            })
819            .await
820            .with_context(|| format!("Cannot upload: {:?}", artifact))
821    }
822
823    /// Downloads artifact FILE from mlflow tracking server into specified path.
824    ///
825    /// Returns a size of downloaded artifact in bytes.
826    pub async fn download_artifact_no_retry(&self, download: &DownloadRunArtifact) -> Result<u64> {
827        let path = format!(
828            "2.0/mlflow-artifacts/artifacts/{}/{}/artifacts/{}",
829            download.experiment_id, download.run_id, download.file,
830        );
831
832        if let Some(directory) = download.path().parent() {
833            create_dir_all(directory).await.with_context(|| {
834                format!(
835                    "Unable to create a directory {:?} for the artifact",
836                    directory
837                )
838            })?;
839        }
840
841        // skip file download if we are caching files and file is present
842        if self.download.cache_local_artifacts && download.path().exists() {
843            let path = download.path();
844
845            info!("Path {path:?} does exists and caching is enabled, skipping download");
846
847            let file = File::open(&path)
848                .await
849                .with_context(|| format!("Unable to open cached artifact: {path:?} for reading"))?;
850
851            return file
852                .metadata()
853                .await
854                .with_context(|| format!("Unable to get metadata for: {path:?}"))
855                .map(|metadata| metadata.len());
856        }
857
858        let mut file = File::create(&download.path())
859            .await
860            .with_context(|| format!("Cannot create artifact file: {:?}", download.path()))?;
861
862        let mut request = self.client.get(self.make_url(path));
863
864        request = self.assign_basic_auth(request);
865
866        let mut response = Response(request.send().await?).error_for_status().await?.0;
867
868        debug!(
869            "Content-Length of {:?} is {:?}",
870            download.file,
871            response.content_length()
872        );
873
874        let mut size = 0;
875
876        while let Some(chunk) = response
877            .chunk()
878            .await
879            .context("Cannot read artifact data")?
880        {
881            file.write_all(&chunk)
882                .await
883                .context("Cannot write data to artifact file")?;
884            size += chunk.len() as u64;
885        }
886
887        file.flush()
888            .await
889            .context("Unable to flush artifact file to disk")?;
890
891        file.sync_all()
892            .await
893            .context("Unable to write file metadata to disk")?;
894
895        Ok(size)
896    }
897
898    /// Downloads artifact FILE from mlflow tracking server into specified path.
899    pub async fn download_artifact(&self, download: &DownloadRunArtifact) -> Result<u64> {
900        let future = || async {
901            let size = self.download_artifact_no_retry(download).await?;
902            if self.download.file_size_check && size != download.expected_size {
903                fs::remove_file(download.path()).await?; // remove so it won't be considered "cached"
904                Err(anyhow!(
905                    "Expected size {} and downloaded {} do not match for {}",
906                    size,
907                    download.expected_size,
908                    download.file,
909                ))
910            } else {
911                Ok(size)
912            }
913        };
914
915        future
916            .retry(ExponentialBuilder::default())
917            .notify(|error, duration| {
918                warn!(
919                    "Retrying download in {:?}, got an error: {:?}",
920                    error, duration
921                );
922            })
923            .await
924            .with_context(|| format!("Cannot download: {:?}", download))
925    }
926}
927
#[cfg(test)]
mod test {
    //! Integration tests for [`Client`].
    //!
    //! These tests require a running MLflow server (see the crate-level
    //! "Testing" section). The server is expected at `localhost:5000`, or at
    //! host `mlflow` when the `CI` environment variable is set.
    use super::*;
    use redact::Secret;
    use rstest::{fixture, rstest};
    use std::path::PathBuf;
    use tracing_test::traced_test;

    // Random run name, so that repeated test runs are unlikely to collide.
    #[fixture]
    fn run_name() -> String {
        format!("run-{}", rand::random::<u64>())
    }

    // Random registered model name, so that repeated test runs are unlikely
    // to collide.
    #[fixture]
    fn model_name() -> String {
        format!("registered-model-{}", rand::random::<u64>())
    }

    // Random experiment name, so that repeated test runs are unlikely to
    // collide.
    #[fixture]
    fn experiment_name() -> String {
        format!("experiment-{}", rand::random::<u64>())
    }

    // Client for the test server, authenticated with basic auth. The host is
    // `mlflow` when the `CI` environment variable is set, `localhost`
    // otherwise.
    #[fixture]
    fn client() -> Client {
        let host = match std::env::var("CI") {
            Ok(_) => "mlflow",
            Err(_) => "localhost",
        };

        Client::new(format!("http://{}:5000/api", host)).with_auth(BasicAuth {
            user: Secret::from("kokot"),
            password: Secret::from("bezpeci2021"),
        })
    }

    #[rstest]
    fn test_client_debug(client: Client) {
        let formatted = format!("{client:?}");

        // check that credentials are not visible in debug
        assert!(!formatted.contains("kokot"));
        assert!(!formatted.contains("bezpeci2021"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_get_or_create_experiment(experiment_name: String, client: Client) {
        client
            .get_or_create_experiment(&experiment_name, vec![])
            .await
            .expect("BUG: Cannot create experiment");
    }

    #[rstest]
    #[tokio::test]
    async fn test_create_get_delete_experiment(experiment_name: String, client: Client) {
        let id = client
            .create_experiment(&experiment_name, vec![])
            .await
            .expect("BUG: Cannot create experiment");

        let experiment = client
            .get_experiment(&id)
            .await
            .expect("BUG: Cannot get experiment");

        assert_eq!(experiment.name, experiment_name);

        client
            .delete_experiment(&id)
            .await
            .expect("BUG: Cannot delete experiment");

        // Creating an experiment with the name of a deleted one is expected
        // to fail.
        client
            .create_experiment(&experiment_name, vec![])
            .await
            .expect_err("BUG: Can create experiment with the name of a deleted experiment");
    }

    #[rstest]
    #[tokio::test]
    async fn test_get_experiment_by_name(experiment_name: String, client: Client) {
        let experiment_id = client
            .create_experiment(&experiment_name, vec![])
            .await
            .expect("BUG: Cannot create experiment");

        let experiment = client
            .get_experiment_by_name(&experiment_name)
            .await
            .expect("BUG: Cannot get created experiment");

        assert_eq!(experiment.experiment_id, experiment_id);
    }

    #[rstest]
    #[tokio::test]
    #[awt]
    async fn test_run_search(client: Client, #[future] run: Run) {
        let search = SearchRuns::new()
            .experiment_ids(vec![run.info.experiment_id.clone()])
            .max_results(Some(16))
            .view(ViewType::All)
            .build();

        let runs = client
            .search_runs(search)
            .await
            .expect("BUG: Cannot search runs");

        assert!(!runs.is_empty());

        // skip max-results and check that we've got something
        let search = SearchRuns::new()
            .experiment_ids(vec![run.info.experiment_id.clone()])
            .view(ViewType::All)
            .build();

        let runs = client
            .search_runs(search)
            .await
            .expect("BUG: Cannot search runs");

        assert!(!runs.is_empty());
    }

    #[rstest]
    #[tokio::test]
    async fn test_run_create(experiment_name: String, run_name: String, client: Client) {
        let experiment_id = client
            .create_experiment(&experiment_name, vec![])
            .await
            .expect("BUG: Cannot create experiment");

        let create = CreateRun::new()
            .run_name(&run_name)
            .experiment_id(&experiment_id)
            .build();

        let run = client
            .create_run(create)
            .await
            .expect("BUG: Cannot create run");

        let run1 = client
            .get_run(&run.info.run_id)
            .await
            .expect("BUG: Cannot get run");

        assert_eq!(run, run1);
    }

    #[rstest]
    #[tokio::test]
    async fn test_run_delete(experiment_name: String, run_name: String, client: Client) {
        let experiment_id = client
            .create_experiment(&experiment_name, vec![])
            .await
            .expect("BUG: Cannot create experiment");

        let create = CreateRun::new()
            .run_name(&run_name)
            .experiment_id(&experiment_id)
            .build();

        let run = client
            .create_run(create)
            .await
            .expect("BUG: Cannot create run");

        client
            .delete_run(&run.info.run_id)
            .await
            .expect("BUG: Cannot delete run");

        // Deleting an already deleted run is expected to succeed.
        client
            .delete_run(&run.info.run_id)
            .await
            .expect("BUG: Cannot delete already deleted run");
    }

    #[rstest]
    #[tokio::test]
    async fn test_run_update_data(experiment_name: String, run_name: String, client: Client) {
        let experiment_id = client
            .create_experiment(&experiment_name, vec![])
            .await
            .expect("BUG: Cannot create experiment");

        let create = CreateRun::new()
            .run_name(&run_name)
            .experiment_id(&experiment_id)
            .build();

        let run = client
            .create_run(create)
            .await
            .expect("BUG: Cannot create run");

        let mut update = RunData::new()
            .metrics(vec![Metric::new().key("m").value(37.).step(0).build()])
            .params(vec![KeyValue::new().key("p").value("42").build()])
            .tags(vec![KeyValue::new().key("t").value("73").build()])
            .build();

        client
            .add_run_meta(&run.info.run_id, update.clone())
            .await
            .expect("BUG: Cannot update run data");

        // The fetched run is expected to also carry the `mlflow.runName` tag,
        // as its first tag.
        update.tags.insert(
            0,
            KeyValue::new()
                .key("mlflow.runName")
                .value(run_name)
                .build(),
        );

        let run1 = client
            .get_run(&run.info.run_id)
            .await
            .expect("BUG: Cannot get run");

        assert_eq!(run1.data, update);
    }

    #[rstest]
    #[tokio::test]
    async fn test_run_add_inputs(experiment_name: String, run_name: String, client: Client) {
        let experiment_id = client
            .create_experiment(&experiment_name, vec![])
            .await
            .expect("BUG: Cannot create experiment");

        let create = CreateRun::new()
            .run_name(&run_name)
            .experiment_id(&experiment_id)
            .build();

        let run = client
            .create_run(create)
            .await
            .expect("BUG: Cannot create run");

        let input = DataSetInput::new()
            .tags(vec![KeyValue::new().key("a").value("x").build()])
            .dataset(
                DataSet::new()
                    .name("kokot")
                    .digest("123")
                    .source_type("kokot1")
                    .source("s3")
                    .schema("{}")
                    .profile("{\"rows\": 22}")
                    .build(),
            )
            .build();

        client
            .add_run_inputs(&run.info.run_id, vec![input.clone()])
            .await
            .expect("BUG: Unable to add inputs to run");

        let run1 = client
            .get_run(&run.info.run_id)
            .await
            .expect("BUG: Cannot get run");

        assert_eq!(
            run1.inputs,
            RunInputs {
                inputs: vec![input]
            }
        );
    }

    // Creates a fresh experiment and a run inside it, and returns that run.
    #[fixture]
    async fn run(experiment_name: String, run_name: String, client: Client) -> Run {
        let experiment_id = client
            .create_experiment(&experiment_name, vec![])
            .await
            .expect("BUG: Cannot create experiment");

        let create = CreateRun::new()
            .run_name(&run_name)
            .experiment_id(&experiment_id)
            .build();

        client
            .create_run(create)
            .await
            .expect("BUG: Cannot create run")
    }

    #[rstest]
    #[tokio::test]
    #[awt]
    async fn test_run_update(client: Client, #[future] run: Run) {
        let end_time_ms: u64 = 1733565663000;

        let update = UpdateRun::new()
            .run_id(&run.info.run_id)
            .status(RunStatus::Killed)
            .end_time(end_time_ms as i64)
            .experiment_id(&run.info.experiment_id)
            .build();

        client
            .update_run(update)
            .await
            .expect("BUG: Unable to update run");

        let run1 = client
            .get_run(&run.info.run_id)
            .await
            .expect("BUG: Unable to fetch run");

        assert_eq!(run1.info.end_time, Some(end_time_ms));
        assert_eq!(run1.info.status, RunStatus::Killed);
    }

    #[rstest]
    #[tokio::test]
    #[awt]
    #[traced_test]
    async fn test_artifacts(client: Client, #[future] run: Run) {
        let artifact = Artifact {
            experiment_id: run.info.experiment_id.clone(),
            run_id: run.info.run_id.clone(),
            path: PathBuf::from("abc/lock"),
        };

        let mut artifact1 = artifact.clone();
        artifact1.path = PathBuf::from("abc/lock1");

        let artifacts = vec![
            UploadArtifact {
                local: "Cargo.lock".into(),
                remote: artifact.clone(),
            },
            UploadArtifact {
                local: "Cargo.lock".into(),
                remote: artifact1.clone(),
            },
        ];

        client
            .upload_artifacts(artifacts)
            .await
            .expect("BUG: Unable to upload artifacts");

        let list = client
            .list_run_artifacts(&run.info.run_id)
            .await
            .expect("BUG: Unable to list run artifacts");

        assert_eq!(list.files.len(), 2);

        let download = client
            .prepare_run_download(&run.info.run_id, "/tmp")
            .await
            .expect("BUG: Unable to prepare run download");

        let mut artifacts = client
            .download_run_artifacts(download)
            .await
            .expect("BUG: Cannot download artifact");

        artifacts.paths.sort();

        assert_eq!(
            artifacts.paths,
            vec![
                PathBuf::from("/tmp/abc/lock"),
                PathBuf::from("/tmp/abc/lock1"),
            ]
        );
    }

    #[rstest]
    #[tokio::test]
    #[awt]
    async fn test_artifacts_cache(mut client: Client, #[future] run: Run) {
        // let's upload a file of non-zero size
        let artifact = Artifact {
            experiment_id: run.info.experiment_id.clone(),
            run_id: run.info.run_id.clone(),
            path: PathBuf::from("run.json"),
        };

        client
            .upload_json_artifact(&run, &artifact)
            .await
            .expect("BUG: Cannot upload artifact");

        // now enable caching, create an empty local file and check that the
        // download did not overwrite it
        client.download.cache_local_artifacts = true;

        let filedir = PathBuf::from(format!("/tmp/{}/", run.info.run_id));
        let filepath = filedir.join("run.json");

        create_dir_all(&filedir)
            .await
            .expect("BUG: Unable to create destdir");

        let file = File::create(&filepath)
            .await
            .expect("BUG: Unable to create dummy file");

        let download = client
            .prepare_run_download(&run.info.run_id, filedir)
            .await
            .expect("BUG: Unable to prepare run download");

        let artifacts = client
            .download_run_artifacts(download)
            .await
            .expect("BUG: Cannot download artifact");

        assert_eq!(artifacts.paths, vec![filepath]);

        // the dummy file is still empty, so the uploaded content was not
        // written over it
        assert_eq!(
            file.metadata()
                .await
                .expect("BUG: Unable to get file metadata")
                .len(),
            0
        );
    }

    #[rstest]
    #[tokio::test]
    #[awt]
    async fn test_json_artifact(client: Client, #[future] run: Run) {
        let artifact = Artifact {
            experiment_id: run.info.experiment_id.clone(),
            run_id: run.info.run_id.clone(),
            path: PathBuf::from("run.json"),
        };

        client
            .upload_json_artifact(&run, &artifact)
            .await
            .expect("BUG: Cannot upload artifact");

        let run1 = client
            .download_json_artifact::<Run>(&artifact)
            .await
            .expect("BUG: Cannot download artifact");

        assert_eq!(run, run1);
    }

    #[rstest]
    #[tokio::test]
    #[awt]
    async fn test_registered_model(client: Client, model_name: String) {
        let registered = client
            .register_model(
                RegisterModel::new()
                    .name(&model_name)
                    .description("")
                    .build(),
            )
            .await
            .expect("BUG: Unable to register model");

        let registered1 = client
            .get_registered_model(&model_name)
            .await
            .expect("BUG: Cannot get registered model");

        assert_eq!(registered.name, registered1.name);
        assert_eq!(
            registered.creation_timestamp,
            registered1.creation_timestamp
        );

        let _ = client
            .delete_registered_model(&model_name)
            .await
            .expect("BUG: Unable to delete registered model");
    }

    #[rstest]
    #[tokio::test]
    #[awt]
    async fn test_search_registered_models(client: Client, #[future] run: Run, model_name: String) {
        let _registered = client
            .register_model(
                RegisterModel::new()
                    .name(&model_name)
                    .description("yep")
                    .tags(vec![KeyValue::new()
                        .key("kokot")
                        .value(&model_name)
                        .build()])
                    .build(),
            )
            .await
            .expect("BUG: Unable to register model");

        let create = CreateModelVersion::new()
            .registered_model_name(&model_name)
            .artifacts_url("s3:///kokot")
            .run_id(&run.info.run_id)
            .description("xxx")
            .build();

        client
            .create_model_version(create)
            .await
            .expect("BUG: Cannot create model version");

        let mut models = client
            .search_registered_models(&format!("tags.kokot like '({}|abc)'", model_name))
            .await
            .expect("BUG: Cannot search registered models");

        let mut model = models
            .pop()
            .expect("BUG: We must get at least one registered model");

        let latest = model
            .latest_versions
            .pop()
            .expect("BUG: Model must have at least one latest version");

        assert_eq!(latest.run_id, run.info.run_id);
    }

    #[rstest]
    #[tokio::test]
    #[awt]
    async fn test_add_model_version(
        client: Client,
        #[future] run: Run,
        #[future]
        #[from(run)]
        run1: Run,
        model_name: String,
    ) {
        let _registered = client
            .register_model(
                RegisterModel::new()
                    .name(&model_name)
                    .description("yep")
                    .build(),
            )
            .await
            .expect("BUG: Unable to register model");

        let create = CreateModelVersion::new()
            .registered_model_name(&model_name)
            .artifacts_url("s3:///kokot")
            .run_id(&run.info.run_id)
            .description("xxx")
            .build();

        // same model version request, but for the second run
        let mut create1 = create.clone();
        create1.run_id = run1.info.run_id;

        client
            .create_model_version(create)
            .await
            .expect("BUG: Cannot create model version");

        let version = client
            .create_model_version(create1)
            .await
            .expect("BUG: Cannot create model version");

        assert_eq!(version.current_stage, "None");

        let transition = TransitionModelVersionStage::new()
            .name(&model_name)
            .version(version.version)
            .stage(ModelVersionStage::Production)
            .archive_existing_versions(false)
            .build();

        let version1 = client
            .transition_model_version_stage(transition)
            .await
            .expect("BUG: Cannot transition model stage");

        assert_eq!(version1.current_stage, "Production");
    }
}