use crate::error::{Error, Result};
use crate::XvcDependency;
use serde_json::value::Value as JsonValue;
use serde_yaml::Value as YamlValue;
use std::ffi::OsString;
use std::{ffi::OsStr, fmt::Display, fs, path::Path};
use toml::Value as TomlValue;
use xvc_core::types::diff::Diffable;
use xvc_core::{Diff, XvcMetadata, XvcPath, XvcPathMetadataMap, XvcRoot};
use xvc_ecs::persist;
use xvc_logging::watch;
use log::{error, warn};
use serde::{Deserialize, Serialize};
/// A pipeline dependency on a single (possibly nested) key inside a structured
/// parameter file (JSON, YAML or TOML).
///
/// NOTE: `PartialOrd`/`Ord` are derived, so instances compare field by field in
/// declaration order; do not reorder the fields casually.
#[derive(Debug, PartialOrd, Ord, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct ParamDep {
    /// Format used to parse the file. When not given explicitly to `ParamDep::new`,
    /// it is inferred from the file extension.
    pub format: XvcParamFormat,
    /// Path of the parameter file; converted with `to_absolute_path` before reading.
    pub path: XvcPath,
    /// Key to look up. The parsers split it on `.` to address nested entries
    /// (e.g. `train.epochs`).
    pub key: String,
    /// Value found under `key`; `None` until `update_value` is called.
    pub value: Option<XvcParamValue>,
    /// Metadata of the parameter file; `None` until `update_metadata` finds the
    /// path in the metadata map.
    pub xvc_metadata: Option<XvcMetadata>,
}
// Presumably generates the xvc_ecs persistence/storable impls for `ParamDep`
// under the name "param-dependency" — see `xvc_ecs::persist` to confirm.
persist!(ParamDep, "param-dependency");
/// Wraps a [`ParamDep`] in the [`XvcDependency::Param`] variant.
///
/// Implemented as `From` rather than `Into`: the standard blanket impl derives
/// `Into<XvcDependency> for ParamDep` from this, so existing `.into()` call sites
/// keep working and `XvcDependency::from(dep)` becomes available too.
impl From<ParamDep> for XvcDependency {
    fn from(dep: ParamDep) -> Self {
        XvcDependency::Param(dep)
    }
}
impl ParamDep {
    /// Builds a parameter dependency on `key` inside the file at `path`.
    ///
    /// When `format` is `None`, it is inferred from the file extension via
    /// `XvcParamFormat::from_xvc_path`. Neither the value nor the metadata is read
    /// here; call `update_value` / `update_metadata` to fill them in.
    pub fn new(path: &XvcPath, format: Option<XvcParamFormat>, key: String) -> Result<Self> {
        let resolved_format = match format {
            Some(explicit) => explicit,
            None => XvcParamFormat::from_xvc_path(path),
        };
        let dep = Self {
            format: resolved_format,
            path: path.clone(),
            key,
            value: None,
            xvc_metadata: None,
        };
        Ok(dep)
    }

    /// Returns a copy whose `xvc_metadata` is the entry for `self.path` in `pmm`.
    /// The metadata becomes `None` if the path is absent from the map.
    pub fn update_metadata(self, pmm: &XvcPathMetadataMap) -> Result<Self> {
        let mut updated = self;
        updated.xvc_metadata = pmm.get(&updated.path).cloned();
        Ok(updated)
    }

    /// Returns a copy whose `value` is freshly read from the parameter file.
    ///
    /// # Errors
    ///
    /// Propagates any error from `XvcParamValue::new_with_format`, such as I/O,
    /// parse, unsupported-format or missing-key errors.
    pub fn update_value(self, xvc_root: &XvcRoot) -> Result<Self> {
        let absolute = self.path.to_absolute_path(xvc_root);
        let parsed = XvcParamValue::new_with_format(&absolute, &self.format, &self.key)?;
        Ok(Self {
            value: Some(parsed),
            ..self
        })
    }
}
impl Diffable for ParamDep {
    type Item = Self;

    /// Compares only the recorded and actual file metadata.
    ///
    /// # Panics
    ///
    /// Panics if the two paths differ, or if neither side carries metadata.
    fn diff_superficial(record: &Self::Item, actual: &Self::Item) -> Diff<Self::Item> {
        assert!(record.path == actual.path);
        watch!(record);
        watch!(actual);
        match (&record.xvc_metadata, &actual.xvc_metadata) {
            (Some(recorded_md), Some(actual_md)) if recorded_md == actual_md => Diff::Identical,
            (Some(_), Some(_)) => Diff::Different {
                record: record.clone(),
                actual: actual.clone(),
            },
            (None, Some(_)) => Diff::RecordMissing {
                actual: actual.clone(),
            },
            (Some(_), None) => Diff::ActualMissing {
                record: record.clone(),
            },
            (None, None) => unreachable!("One of the metadata should always be present"),
        }
    }

    /// Compares metadata first. When the metadata differ, it also compares the
    /// parsed values.
    ///
    /// A metadata change that leaves this key's value unchanged is reported as
    /// `Diff::Identical`. Every other superficial result is returned as-is.
    ///
    /// # Panics
    ///
    /// Panics if the two paths differ, or if neither side carries metadata.
    fn diff_thorough(record: &Self::Item, actual: &Self::Item) -> Diff<Self::Item> {
        assert!(record.path == actual.path);
        watch!(record);
        watch!(actual);
        match Self::diff_superficial(record, actual) {
            Diff::Different { .. } if record.value == actual.value => Diff::Identical,
            // Different / RecordMissing / ActualMissing already carry the clones of
            // `record`/`actual`; Identical and Skipped pass through unchanged.
            other => other,
        }
    }
}
/// Serialization format of a parameter file.
///
/// Usually inferred from the file extension (see the constructors on this type).
/// NOTE: `Ord` is derived, so the variant declaration order defines the ordering.
#[derive(Debug, Clone, Copy, Eq, PartialOrd, Ord, PartialEq, Serialize, Deserialize)]
pub enum XvcParamFormat {
    /// Extension missing or not recognized. `XvcParamValue::new_with_format`
    /// rejects this with `Error::UnsupportedParamFileFormat`.
    Unknown,
    /// YAML document.
    YAML,
    /// JSON document.
    JSON,
    /// TOML document.
    TOML,
}
impl XvcParamFormat {
fn from_extension(ext: &OsStr) -> Self {
match ext.to_str().unwrap_or("") {
"json" | "JSON" => Self::JSON,
"yaml" | "yml" => Self::YAML,
"toml" | "tom" | "tml" => Self::TOML,
_ => {
warn!("[W0000] Unknown parameter file extension: {:?}", ext);
Self::Unknown
}
}
}
pub fn from_path(path: &Path) -> Self {
match path.extension() {
None => {
error!("[E0000] Params file has no extension: {:?}", path);
Self::Unknown
}
Some(ext) => Self::from_extension(ext),
}
}
pub fn from_xvc_path(xvc_path: &XvcPath) -> Self {
let extension: OsString = xvc_path
.extension()
.map(|s| s.to_owned())
.unwrap_or_else(|| "".to_owned())
.into();
Self::from_extension(&extension)
}
}
/// A value extracted from a parameter file, tagged with the format it was parsed from.
///
/// Only `PartialEq` is derived; the ordering and `Eq` impls are written by hand
/// because the wrapped third-party value types do not all provide them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum XvcParamValue {
    /// Value read from a JSON document.
    Json(JsonValue),
    /// Value read from a YAML document.
    Yaml(YamlValue),
    /// Value read from a TOML document.
    Toml(TomlValue),
}
// The ordering is defined once, in `Ord`. `PartialOrd` delegates to it so the two
// can never disagree, as the `Ord` contract requires. Previously `partial_cmp`
// compared YAML structurally and returned `None` across formats, while `cmp`
// compared rendered strings.
impl PartialOrd for XvcParamValue {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Total order on parameter values: first by format (`Json` < `Yaml` < `Toml`),
/// then by their `Display` rendering.
///
/// Comparing the format first means values of different formats never compare
/// as `Equal`. Such values are never `==`, yet they can render identically
/// (e.g. JSON `10` and TOML `10`).
impl Ord for XvcParamValue {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        /// Position of the variant in declaration order.
        fn format_rank(value: &XvcParamValue) -> u8 {
            match value {
                XvcParamValue::Json(_) => 0,
                XvcParamValue::Yaml(_) => 1,
                XvcParamValue::Toml(_) => 2,
            }
        }
        format_rank(self)
            .cmp(&format_rank(other))
            .then_with(|| self.to_string().cmp(&other.to_string()))
    }
}

// NOTE(review): `PartialEq` is derived from the wrapped values. YAML/TOML floats
// may be NaN, for which equality is not reflexive; `Eq` is asserted regardless.
impl Eq for XvcParamValue {}
/// Renders the wrapped value in its own format.
///
/// JSON and TOML use their `Display` impls. YAML is serialized with
/// `serde_yaml::to_string`, falling back to a fixed message if serialization fails.
impl Display for XvcParamValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Json(inner) => write!(f, "{}", inner),
            Self::Toml(inner) => write!(f, "{}", inner),
            Self::Yaml(inner) => {
                let rendered = match serde_yaml::to_string(inner) {
                    Ok(text) => text,
                    Err(_) => "Error in YAML String".into(),
                };
                write!(f, "{}", rendered)
            }
        }
    }
}
impl XvcParamValue {
pub fn new_with_format(path: &Path, format: &XvcParamFormat, key: &str) -> Result<Self> {
let all_content = fs::read_to_string(path)?;
let res = match format {
XvcParamFormat::JSON => Self::parse_json(&all_content, key),
XvcParamFormat::YAML => Self::parse_yaml(&all_content, key),
XvcParamFormat::TOML => Self::parse_toml(&all_content, key),
XvcParamFormat::Unknown => Err(Error::UnsupportedParamFileFormat {
path: path.as_os_str().into(),
}),
};
match res {
Err(Error::KeyNotFoundInDocument { .. }) => Err(Error::KeyNotFoundInDocument {
key: key.to_string(),
path: path.to_path_buf(),
}),
Err(e) => Err(e),
Ok(p) => Ok(p),
}
}
fn parse_json(all_content: &str, key: &str) -> Result<Self> {
let json_map: JsonValue = serde_json::from_str(all_content)?;
let nested_keys: Vec<&str> = key.split('.').collect();
let mut current_scope = json_map;
for k in &nested_keys {
if let Some(current_value) = current_scope.get(*k) {
match current_value {
JsonValue::Object(_) => current_scope = current_value.clone(),
JsonValue::String(_)
| JsonValue::Number(_)
| JsonValue::Bool(_)
| JsonValue::Array(_) => return Ok(XvcParamValue::Json(current_value.clone())),
JsonValue::Null => {
return Err(Error::JsonNullValueForKey { key: key.into() });
}
}
} else {
return Err(Error::KeyNotFound { key: key.into() });
}
}
Ok(XvcParamValue::Json(current_scope))
}
fn parse_yaml(all_content: &str, key: &str) -> Result<XvcParamValue> {
let yaml_map: YamlValue = serde_yaml::from_str(all_content)?;
let nested_keys: Vec<&str> = key.split('.').collect();
let mut current_scope: YamlValue = yaml_map;
for k in &nested_keys {
if let Some(current_value) = current_scope.get(*k) {
match current_value {
YamlValue::Mapping(_) => {
current_scope = serde_yaml::from_value(current_value.clone())?;
}
YamlValue::Tagged(tv) => {
current_scope = serde_yaml::from_value(tv.value.clone())?
}
YamlValue::String(_)
| YamlValue::Number(_)
| YamlValue::Bool(_)
| YamlValue::Sequence(_) => {
return Ok(XvcParamValue::Yaml(current_value.clone()));
}
YamlValue::Null => {
return Err(Error::YamlNullValueForKey { key: key.into() });
}
}
} else {
return Err(Error::KeyNotFound { key: key.into() });
}
}
Ok(XvcParamValue::Yaml(current_scope))
}
fn parse_toml(all_content: &str, key: &str) -> Result<Self> {
let toml_map = all_content.parse::<TomlValue>()?;
let nested_keys: Vec<&str> = key.split('.').collect();
let mut current_scope: TomlValue = toml_map;
for k in &nested_keys {
if let Some(current_value) = current_scope.get(*k) {
match current_value {
TomlValue::Table(_) => {
current_scope = current_value.clone();
}
TomlValue::String(_)
| TomlValue::Integer(_)
| TomlValue::Float(_)
| TomlValue::Boolean(_)
| TomlValue::Datetime(_)
| TomlValue::Array(_) => {
return Ok(XvcParamValue::Toml(current_value.clone()));
}
}
} else {
return Err(Error::KeyNotFound { key: key.into() });
}
}
Ok(XvcParamValue::Toml(current_scope))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Flow-style mappings make the nesting explicit without relying on significant
    // indentation. An unindented block-style fixture would parse `train` as null.
    const YAML_PARAMS: &str = "train: {epochs: 10}\nmodel: {conv_units: 16}\n";

    #[test]
    fn test_yaml_params() -> Result<()> {
        let train_epochs = XvcParamValue::parse_yaml(YAML_PARAMS, "train.epochs")?;
        match train_epochs {
            XvcParamValue::Yaml(YamlValue::Number(n)) => assert_eq!(n.as_u64(), Some(10u64)),
            other => panic!("Mismatched Yaml Type: {}", other),
        }
        Ok(())
    }

    #[test]
    fn test_yaml_params_other_section() -> Result<()> {
        let conv_units = XvcParamValue::parse_yaml(YAML_PARAMS, "model.conv_units")?;
        match conv_units {
            XvcParamValue::Yaml(YamlValue::Number(n)) => assert_eq!(n.as_u64(), Some(16u64)),
            other => panic!("Mismatched Yaml Type: {}", other),
        }
        Ok(())
    }

    #[test]
    fn test_yaml_missing_key() {
        let missing = XvcParamValue::parse_yaml(YAML_PARAMS, "train.batch_size");
        assert!(matches!(missing, Err(Error::KeyNotFound { .. })));
    }
}