xvc_pipeline/pipeline/deps/
param.rs1use crate::error::{Error, Result};
4use crate::XvcDependency;
5use serde_json::value::Value as JsonValue;
6use serde_yaml::Value as YamlValue;
7use std::ffi::OsString;
8use std::{ffi::OsStr, fmt::Display, fs, path::Path};
9use toml::Value as TomlValue;
10use xvc_core::types::diff::Diffable;
11use xvc_core::{Diff, XvcMetadata, XvcPath, XvcPathMetadataProvider, XvcRoot};
12use xvc_core::persist;
13
14use log::{error, warn};
15use serde::{Deserialize, Serialize};
16
/// A pipeline dependency on a single key inside a parameter file (JSON, YAML or TOML).
///
/// The derived `PartialOrd`/`Ord` compare fields in declaration order, so the order of
/// the fields below is significant.
#[derive(Debug, PartialOrd, Ord, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct ParamDep {
    /// Format used to parse the file; inferred from the path's extension when `new`
    /// is not given one explicitly.
    pub format: XvcParamFormat,
    /// Path of the parameter file. It is passed through `to_absolute_path` before
    /// reading, so presumably repository-relative — confirm against `XvcPath`.
    pub path: XvcPath,
    /// Dot-separated key path to the parameter, e.g. `train.epochs`.
    pub key: String,
    /// Last value read for `key`; `None` until `update_value` is called.
    pub value: Option<XvcParamValue>,
    /// File metadata; `None` until `update_metadata` is called, or when the metadata
    /// provider returns nothing for `path`.
    pub xvc_metadata: Option<XvcMetadata>,
}
32
// NOTE(review): presumably generates the store/persistence plumbing for `ParamDep`
// under the "param-dependency" name — see the `persist!` macro in xvc_core to confirm.
persist!(ParamDep, "param-dependency");
34
35impl From<ParamDep> for XvcDependency {
36 fn from(val: ParamDep) -> Self {
37 XvcDependency::Param(val)
38 }
39}
40
41impl ParamDep {
42 pub fn new(path: &XvcPath, format: Option<XvcParamFormat>, key: String) -> Result<Self> {
45 Ok(Self {
46 format: format.unwrap_or_else(|| XvcParamFormat::from_xvc_path(path)),
47 path: path.clone(),
48 key,
49 value: None,
50 xvc_metadata: None,
51 })
52 }
53
54 pub fn update_metadata(self, pmp: &XvcPathMetadataProvider) -> Result<Self> {
56 let xvc_metadata = pmp.get(&self.path);
57 Ok(Self {
58 xvc_metadata,
59 ..self
60 })
61 }
62
63 pub fn update_value(self, xvc_root: &XvcRoot) -> Result<Self> {
65 let path = self.path.to_absolute_path(xvc_root);
66 let value = Some(XvcParamValue::new_with_format(
67 &path,
68 &self.format,
69 &self.key,
70 )?);
71 Ok(Self { value, ..self })
72 }
73}
74
75impl Diffable for ParamDep {
76 type Item = Self;
77
78 fn diff_superficial(record: &Self::Item, actual: &Self::Item) -> Diff<Self::Item> {
80 assert!(record.path == actual.path);
81 match (record.xvc_metadata, actual.xvc_metadata) {
82 (Some(record_md), Some(actual_md)) => {
83 if record_md == actual_md {
84 Diff::Identical
85 } else {
86 Diff::Different {
87 record: record.clone(),
88 actual: actual.clone(),
89 }
90 }
91 }
92 (None, Some(_)) => Diff::RecordMissing {
93 actual: actual.clone(),
94 },
95 (Some(_), None) => Diff::ActualMissing {
96 record: record.clone(),
97 },
98 (None, None) => unreachable!("One of the metadata should always be present"),
99 }
100 }
101
102 fn diff_thorough(record: &Self::Item, actual: &Self::Item) -> Diff<Self::Item> {
104 assert!(record.path == actual.path);
105 match Self::diff_superficial(record, actual) {
106 Diff::Identical => Diff::Identical,
107 Diff::Different { .. } => {
108 if record.value == actual.value {
109 Diff::Identical
110 } else {
111 Diff::Different {
112 record: record.clone(),
113 actual: actual.clone(),
114 }
115 }
116 }
117 Diff::RecordMissing { .. } => Diff::RecordMissing {
118 actual: actual.clone(),
119 },
120 Diff::ActualMissing { .. } => Diff::ActualMissing {
121 record: record.clone(),
122 },
123 Diff::Skipped => Diff::Skipped,
124 }
125 }
126}
127
/// Serialization format of a parameter file.
///
/// The derived `PartialOrd`/`Ord` order variants by declaration order, so the
/// variant order below is significant.
#[derive(Debug, Clone, Copy, Eq, PartialOrd, Ord, PartialEq, Serialize, Deserialize)]
pub enum XvcParamFormat {
    /// Format could not be determined (e.g. unrecognized or missing extension).
    Unknown,
    /// YAML document.
    YAML,
    /// JSON document.
    JSON,
    /// TOML document.
    TOML,
}
140
141impl XvcParamFormat {
142 fn from_extension(ext: &OsStr) -> Self {
143 match ext.to_str().unwrap_or("") {
144 "json" | "JSON" => Self::JSON,
145 "yaml" | "yml" => Self::YAML,
146 "toml" | "tom" | "tml" => Self::TOML,
147 _ => {
148 warn!("[W0000] Unknown parameter file extension: {:?}", ext);
149 Self::Unknown
150 }
151 }
152 }
153
154 pub fn from_path(path: &Path) -> Self {
157 match path.extension() {
158 None => {
159 error!("[E0000] Params file has no extension: {:?}", path);
160 Self::Unknown
161 }
162 Some(ext) => Self::from_extension(ext),
163 }
164 }
165
166 pub fn from_xvc_path(xvc_path: &XvcPath) -> Self {
168 let extension: OsString = xvc_path
169 .extension()
170 .map(|s| s.to_owned())
171 .unwrap_or_else(|| "".to_owned())
172 .into();
173 Self::from_extension(&extension)
174 }
175}
176
/// A parameter value read from a file, tagged with the format it was parsed from.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum XvcParamValue {
    /// Value parsed from a JSON document.
    Json(JsonValue),
    /// Value parsed from a YAML document.
    Yaml(YamlValue),
    /// Value parsed from a TOML document.
    Toml(TomlValue),
}
187
188impl PartialOrd for XvcParamValue {
189 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
190 Some(self.cmp(other))
191 }
192}
193
194impl Ord for XvcParamValue {
195 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
196 let self_str = self.to_string();
197 let other_str = other.to_string();
198 self_str.cmp(&other_str)
199 }
200}
201
// `Eq` is asserted by hand rather than derived — presumably because `TomlValue`
// only implements `PartialEq`. NOTE(review): TOML floats can hold NaN, which is not
// equal to itself, so reflexivity is not strictly guaranteed; confirm this is acceptable.
impl Eq for XvcParamValue {}
203
204impl Display for XvcParamValue {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 match self {
207 XvcParamValue::Json(json) => write!(f, "{}", json),
208 XvcParamValue::Yaml(yaml) => {
209 let s =
210 serde_yaml::to_string(yaml).unwrap_or_else(|_| "Error in YAML String".into());
211 write!(f, "{}", s)
212 }
213 XvcParamValue::Toml(toml) => write!(f, "{}", toml),
214 }
215 }
216}
217
218impl XvcParamValue {
219 pub fn new_with_format(path: &Path, format: &XvcParamFormat, key: &str) -> Result<Self> {
221 let all_content = fs::read_to_string(path)?;
222
223 let res = match format {
224 XvcParamFormat::JSON => Self::parse_json(&all_content, key),
225 XvcParamFormat::YAML => Self::parse_yaml(&all_content, key),
226 XvcParamFormat::TOML => Self::parse_toml(&all_content, key),
227 XvcParamFormat::Unknown => Err(Error::UnsupportedParamFileFormat {
228 path: path.as_os_str().into(),
229 }),
230 };
231
232 match res {
233 Err(Error::KeyNotFoundInDocument { .. }) => Err(Error::KeyNotFoundInDocument {
235 key: key.to_string(),
236 path: path.to_path_buf(),
237 }),
238 Err(e) => Err(e),
239 Ok(p) => Ok(p),
240 }
241 }
242
243 fn parse_json(all_content: &str, key: &str) -> Result<Self> {
244 let json_map: JsonValue = serde_json::from_str(all_content)?;
245 let nested_keys: Vec<&str> = key.split('.').collect();
246 let mut current_scope = json_map;
247 for k in &nested_keys {
248 if let Some(current_value) = current_scope.get(*k) {
249 match current_value {
250 JsonValue::Object(_) => current_scope = current_value.clone(),
251 JsonValue::String(_)
252 | JsonValue::Number(_)
253 | JsonValue::Bool(_)
254 | JsonValue::Array(_) => return Ok(XvcParamValue::Json(current_value.clone())),
255 JsonValue::Null => {
256 return Err(Error::JsonNullValueForKey { key: key.into() });
257 }
258 }
259 } else {
260 return Err(Error::KeyNotFound { key: key.into() });
261 }
262 }
263 Ok(XvcParamValue::Json(current_scope))
265 }
266
267 fn parse_yaml(all_content: &str, key: &str) -> Result<XvcParamValue> {
269 let yaml_map: YamlValue = serde_yaml::from_str(all_content)?;
270
271 let nested_keys: Vec<&str> = key.split('.').collect();
272 let mut current_scope: YamlValue = yaml_map;
273 for k in &nested_keys {
274 if let Some(current_value) = current_scope.get(*k) {
275 match current_value {
276 YamlValue::Mapping(_) => {
277 current_scope = serde_yaml::from_value(current_value.clone())?;
278 }
279 YamlValue::Tagged(tv) => {
280 current_scope = serde_yaml::from_value(tv.value.clone())?
281 }
282 YamlValue::String(_)
283 | YamlValue::Number(_)
284 | YamlValue::Bool(_)
285 | YamlValue::Sequence(_) => {
286 return Ok(XvcParamValue::Yaml(current_value.clone()));
287 }
288 YamlValue::Null => {
289 return Err(Error::YamlNullValueForKey { key: key.into() });
290 }
291 }
292 } else {
293 return Err(Error::KeyNotFound { key: key.into() });
294 }
295 }
296 Ok(XvcParamValue::Yaml(current_scope))
298 }
299
300 fn parse_toml(all_content: &str, key: &str) -> Result<Self> {
303 let toml_map = all_content.parse::<TomlValue>()?;
304 let nested_keys: Vec<&str> = key.split('.').collect();
305 let mut current_scope: TomlValue = toml_map;
306 for k in &nested_keys {
307 if let Some(current_value) = current_scope.get(*k) {
308 match current_value {
309 TomlValue::Table(_) => {
310 current_scope = current_value.clone();
311 }
312 TomlValue::String(_)
313 | TomlValue::Integer(_)
314 | TomlValue::Float(_)
315 | TomlValue::Boolean(_)
316 | TomlValue::Datetime(_)
317 | TomlValue::Array(_) => {
318 return Ok(XvcParamValue::Toml(current_value.clone()));
319 }
320 }
321 } else {
322 return Err(Error::KeyNotFound { key: key.into() });
323 }
324 }
325 Ok(XvcParamValue::Toml(current_scope))
327 }
328}
329
#[cfg(test)]
mod tests {

    use super::*;

    const YAML_PARAMS: &str = r#"
train:
  epochs: 10
model:
  conv_units: 16
"#;

    const JSON_PARAMS: &str = r#"{"train": {"epochs": 10}, "model": {"conv_units": 16}}"#;

    const TOML_PARAMS: &str = r#"
[train]
epochs = 10

[model]
conv_units = 16
"#;

    /// A nested YAML key resolves to its numeric value.
    #[test]
    fn test_yaml_params() -> Result<()> {
        let train_epochs = XvcParamValue::parse_yaml(YAML_PARAMS, "train.epochs")?;
        if let XvcParamValue::Yaml(YamlValue::Number(n)) = train_epochs {
            assert!(n.as_u64() == Some(10u64))
        } else {
            panic!("Mismatched Yaml Type: {}", train_epochs);
        }
        Ok(())
    }

    /// A nested JSON key resolves to its numeric value.
    #[test]
    fn test_json_params() -> Result<()> {
        let conv_units = XvcParamValue::parse_json(JSON_PARAMS, "model.conv_units")?;
        assert_eq!(conv_units, XvcParamValue::Json(JsonValue::from(16u64)));
        Ok(())
    }

    /// A nested TOML key resolves to its integer value.
    #[test]
    fn test_toml_params() -> Result<()> {
        let epochs = XvcParamValue::parse_toml(TOML_PARAMS, "train.epochs")?;
        assert_eq!(epochs, XvcParamValue::Toml(TomlValue::Integer(10)));
        Ok(())
    }

    /// A key that does not exist is reported as `KeyNotFound` by every parser.
    #[test]
    fn test_missing_key() {
        assert!(matches!(
            XvcParamValue::parse_json(JSON_PARAMS, "train.batch_size"),
            Err(Error::KeyNotFound { .. })
        ));
        assert!(matches!(
            XvcParamValue::parse_yaml(YAML_PARAMS, "train.batch_size"),
            Err(Error::KeyNotFound { .. })
        ));
        assert!(matches!(
            XvcParamValue::parse_toml(TOML_PARAMS, "train.batch_size"),
            Err(Error::KeyNotFound { .. })
        ));
    }
}