xvc_pipeline/pipeline/deps/
param.rs

//! A parameter dependency is a key-value pair extracted from a parameter file in YAML, TOML
//! or JSON format.
3use crate::error::{Error, Result};
4use crate::XvcDependency;
5use serde_json::value::Value as JsonValue;
6use serde_yaml::Value as YamlValue;
7use std::ffi::OsString;
8use std::{ffi::OsStr, fmt::Display, fs, path::Path};
9use toml::Value as TomlValue;
10use xvc_core::types::diff::Diffable;
11use xvc_core::{Diff, XvcMetadata, XvcPath, XvcPathMetadataProvider, XvcRoot};
12use xvc_core::persist;
13
14use log::{error, warn};
15use serde::{Deserialize, Serialize};
16
17/// Invalidates when key in params file in path changes.
18#[derive(Debug, PartialOrd, Ord, Clone, Eq, PartialEq, Serialize, Deserialize)]
19pub struct ParamDep {
20    /// Format of the params file.
21    /// This is inferred from extension if not given.
22    pub format: XvcParamFormat,
23    /// Path of the file in the workspace
24    pub path: XvcPath,
25    /// Key like `mydict.mykey` to access the value
26    pub key: String,
27    /// The value of the key
28    pub value: Option<XvcParamValue>,
29    /// The metadata of the parameter file to detect if it has changed
30    pub xvc_metadata: Option<XvcMetadata>,
31}
32
// Invokes the `persist!` macro imported from `xvc_core` for `ParamDep` with the name
// "param-dependency".
// NOTE(review): presumably this name identifies the type in xvc's stored records, so changing
// it would likely orphan existing data — confirm against `xvc_core::persist`.
persist!(ParamDep, "param-dependency");
34
35impl From<ParamDep> for XvcDependency {
36    fn from(val: ParamDep) -> Self {
37        XvcDependency::Param(val)
38    }
39}
40
41impl ParamDep {
42    /// Creates a new ParamDep with the given path and key. If the format is None, it's inferred
43    /// from the path.
44    pub fn new(path: &XvcPath, format: Option<XvcParamFormat>, key: String) -> Result<Self> {
45        Ok(Self {
46            format: format.unwrap_or_else(|| XvcParamFormat::from_xvc_path(path)),
47            path: path.clone(),
48            key,
49            value: None,
50            xvc_metadata: None,
51        })
52    }
53
54    /// Update metada from the [XvcPathMetadataProvider]
55    pub fn update_metadata(self, pmp: &XvcPathMetadataProvider) -> Result<Self> {
56        let xvc_metadata = pmp.get(&self.path);
57        Ok(Self {
58            xvc_metadata,
59            ..self
60        })
61    }
62
63    /// Update value by reading the file
64    pub fn update_value(self, xvc_root: &XvcRoot) -> Result<Self> {
65        let path = self.path.to_absolute_path(xvc_root);
66        let value = Some(XvcParamValue::new_with_format(
67            &path,
68            &self.format,
69            &self.key,
70        )?);
71        Ok(Self { value, ..self })
72    }
73}
74
75impl Diffable for ParamDep {
76    type Item = Self;
77
78    /// ⚠️ Call actual.update_metadata before calling this function ⚠️
79    fn diff_superficial(record: &Self::Item, actual: &Self::Item) -> Diff<Self::Item> {
80        assert!(record.path == actual.path);
81        match (record.xvc_metadata, actual.xvc_metadata) {
82            (Some(record_md), Some(actual_md)) => {
83                if record_md == actual_md {
84                    Diff::Identical
85                } else {
86                    Diff::Different {
87                        record: record.clone(),
88                        actual: actual.clone(),
89                    }
90                }
91            }
92            (None, Some(_)) => Diff::RecordMissing {
93                actual: actual.clone(),
94            },
95            (Some(_), None) => Diff::ActualMissing {
96                record: record.clone(),
97            },
98            (None, None) => unreachable!("One of the metadata should always be present"),
99        }
100    }
101
102    /// ⚠️ Call actual.update_metadata and actual.update_value before calling this function ⚠️
103    fn diff_thorough(record: &Self::Item, actual: &Self::Item) -> Diff<Self::Item> {
104        assert!(record.path == actual.path);
105        match Self::diff_superficial(record, actual) {
106            Diff::Identical => Diff::Identical,
107            Diff::Different { .. } => {
108                if record.value == actual.value {
109                    Diff::Identical
110                } else {
111                    Diff::Different {
112                        record: record.clone(),
113                        actual: actual.clone(),
114                    }
115                }
116            }
117            Diff::RecordMissing { .. } => Diff::RecordMissing {
118                actual: actual.clone(),
119            },
120            Diff::ActualMissing { .. } => Diff::ActualMissing {
121                record: record.clone(),
122            },
123            Diff::Skipped => Diff::Skipped,
124        }
125    }
126}
127
128/// Parsable formats of a parameter file
129#[derive(Debug, Clone, Copy, Eq, PartialOrd, Ord, PartialEq, Serialize, Deserialize)]
130pub enum XvcParamFormat {
131    /// The default value if we cannot infer the format somehow
132    Unknown,
133    /// Yaml files are parsed with [serde_yaml]
134    YAML,
135    /// Json files are parsed with [serde_json]
136    JSON,
137    /// Toml files are parsed with [toml]
138    TOML,
139}
140
141impl XvcParamFormat {
142    fn from_extension(ext: &OsStr) -> Self {
143        match ext.to_str().unwrap_or("") {
144            "json" | "JSON" => Self::JSON,
145            "yaml" | "yml" => Self::YAML,
146            "toml" | "tom" | "tml" => Self::TOML,
147            _ => {
148                warn!("[W0000] Unknown parameter file extension: {:?}", ext);
149                Self::Unknown
150            }
151        }
152    }
153
154    /// Infer the (hyper)parameter file format from the file path, by checking
155    /// its extension.
156    pub fn from_path(path: &Path) -> Self {
157        match path.extension() {
158            None => {
159                error!("[E0000] Params file has no extension: {:?}", path);
160                Self::Unknown
161            }
162            Some(ext) => Self::from_extension(ext),
163        }
164    }
165
166    /// Infer the (hyper)parameter file format from the xvc_path's extension
167    pub fn from_xvc_path(xvc_path: &XvcPath) -> Self {
168        let extension: OsString = xvc_path
169            .extension()
170            .map(|s| s.to_owned())
171            .unwrap_or_else(|| "".to_owned())
172            .into();
173        Self::from_extension(&extension)
174    }
175}
176
177/// The value of a parameter
178#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
179pub enum XvcParamValue {
180    /// Value of a key in JSON file
181    Json(JsonValue),
182    /// Value of a key in YAML file
183    Yaml(YamlValue),
184    /// Value of a key in TOML file
185    Toml(TomlValue),
186}
187
188impl PartialOrd for XvcParamValue {
189    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
190        Some(self.cmp(other))
191    }
192}
193
194impl Ord for XvcParamValue {
195    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
196        let self_str = self.to_string();
197        let other_str = other.to_string();
198        self_str.cmp(&other_str)
199    }
200}
201
// Marker impl: equality itself comes from the derived `PartialEq` on the enum. `Eq` is
// required because `Ord` is implemented for this type.
// NOTE(review): the wrapped JSON/YAML/TOML values can presumably hold floating point numbers;
// confirm that NaN handling in those crates does not break the reflexivity `Eq` promises.
impl Eq for XvcParamValue {}
203
204impl Display for XvcParamValue {
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        match self {
207            XvcParamValue::Json(json) => write!(f, "{}", json),
208            XvcParamValue::Yaml(yaml) => {
209                let s =
210                    serde_yaml::to_string(yaml).unwrap_or_else(|_| "Error in YAML String".into());
211                write!(f, "{}", s)
212            }
213            XvcParamValue::Toml(toml) => write!(f, "{}", toml),
214        }
215    }
216}
217
218impl XvcParamValue {
219    /// Creates a new key with an empty value pointing to a file with an explicit [XvcParamFormat]
220    pub fn new_with_format(path: &Path, format: &XvcParamFormat, key: &str) -> Result<Self> {
221        let all_content = fs::read_to_string(path)?;
222
223        let res = match format {
224            XvcParamFormat::JSON => Self::parse_json(&all_content, key),
225            XvcParamFormat::YAML => Self::parse_yaml(&all_content, key),
226            XvcParamFormat::TOML => Self::parse_toml(&all_content, key),
227            XvcParamFormat::Unknown => Err(Error::UnsupportedParamFileFormat {
228                path: path.as_os_str().into(),
229            }),
230        };
231
232        match res {
233            // Adding the path here, normally there should be two different error messages
234            Err(Error::KeyNotFoundInDocument { .. }) => Err(Error::KeyNotFoundInDocument {
235                key: key.to_string(),
236                path: path.to_path_buf(),
237            }),
238            Err(e) => Err(e),
239            Ok(p) => Ok(p),
240        }
241    }
242
243    fn parse_json(all_content: &str, key: &str) -> Result<Self> {
244        let json_map: JsonValue = serde_json::from_str(all_content)?;
245        let nested_keys: Vec<&str> = key.split('.').collect();
246        let mut current_scope = json_map;
247        for k in &nested_keys {
248            if let Some(current_value) = current_scope.get(*k) {
249                match current_value {
250                    JsonValue::Object(_) => current_scope = current_value.clone(),
251                    JsonValue::String(_)
252                    | JsonValue::Number(_)
253                    | JsonValue::Bool(_)
254                    | JsonValue::Array(_) => return Ok(XvcParamValue::Json(current_value.clone())),
255                    JsonValue::Null => {
256                        return Err(Error::JsonNullValueForKey { key: key.into() });
257                    }
258                }
259            } else {
260                return Err(Error::KeyNotFound { key: key.into() });
261            }
262        }
263        // If we consumed all key elements and come to here, we consider the current scope as value
264        Ok(XvcParamValue::Json(current_scope))
265    }
266
267    /// Loads the key (in the form of a.b.c) from a YAML document
268    fn parse_yaml(all_content: &str, key: &str) -> Result<XvcParamValue> {
269        let yaml_map: YamlValue = serde_yaml::from_str(all_content)?;
270
271        let nested_keys: Vec<&str> = key.split('.').collect();
272        let mut current_scope: YamlValue = yaml_map;
273        for k in &nested_keys {
274            if let Some(current_value) = current_scope.get(*k) {
275                match current_value {
276                    YamlValue::Mapping(_) => {
277                        current_scope = serde_yaml::from_value(current_value.clone())?;
278                    }
279                    YamlValue::Tagged(tv) => {
280                        current_scope = serde_yaml::from_value(tv.value.clone())?
281                    }
282                    YamlValue::String(_)
283                    | YamlValue::Number(_)
284                    | YamlValue::Bool(_)
285                    | YamlValue::Sequence(_) => {
286                        return Ok(XvcParamValue::Yaml(current_value.clone()));
287                    }
288                    YamlValue::Null => {
289                        return Err(Error::YamlNullValueForKey { key: key.into() });
290                    }
291                }
292            } else {
293                return Err(Error::KeyNotFound { key: key.into() });
294            }
295        }
296        // If we consumed the key without errors, we consider the resulting scope as the value
297        Ok(XvcParamValue::Yaml(current_scope))
298    }
299
300    /// Loads a TOML file and returns the `XvcParamPair::TOML(TomlValue)`
301    /// associated with the key
302    fn parse_toml(all_content: &str, key: &str) -> Result<Self> {
303        let toml_map = all_content.parse::<TomlValue>()?;
304        let nested_keys: Vec<&str> = key.split('.').collect();
305        let mut current_scope: TomlValue = toml_map;
306        for k in &nested_keys {
307            if let Some(current_value) = current_scope.get(*k) {
308                match current_value {
309                    TomlValue::Table(_) => {
310                        current_scope = current_value.clone();
311                    }
312                    TomlValue::String(_)
313                    | TomlValue::Integer(_)
314                    | TomlValue::Float(_)
315                    | TomlValue::Boolean(_)
316                    | TomlValue::Datetime(_)
317                    | TomlValue::Array(_) => {
318                        return Ok(XvcParamValue::Toml(current_value.clone()));
319                    }
320                }
321            } else {
322                return Err(Error::KeyNotFound { key: key.into() });
323            }
324        }
325        // If we consumed the key without errors, we consider the resulting scope as the value
326        Ok(XvcParamValue::Toml(current_scope))
327    }
328}
329
#[cfg(test)]
mod tests {

    use super::*;

    /// Two top-level mappings, each holding one numeric parameter.
    const YAML_PARAMS: &str = r#"
train:
  epochs: 10
model:
  conv_units: 16
"#;

    /// A nested numeric key resolves to the YAML number stored under it.
    #[test]
    fn test_yaml_params() -> Result<()> {
        let train_epochs = XvcParamValue::parse_yaml(YAML_PARAMS, "train.epochs")?;
        match &train_epochs {
            XvcParamValue::Yaml(YamlValue::Number(epochs)) => {
                assert!(epochs.as_u64() == Some(10u64));
            }
            _ => panic!("Mismatched Yaml Type: {}", train_epochs),
        }
        Ok(())
    }
}