schema_registry_api/domain/schema/
version.rs1use std::fmt::Display;
2use std::num::NonZeroU32;
3use std::str::FromStr;
4
5use serde::de::Visitor;
6use serde::{Deserialize, Serialize};
7
8use crate::SchemaVersionError;
9
/// A schema version selector: either a concrete version number or "latest".
///
/// The derived ordering follows declaration order, so every
/// [`SchemaVersion::Version`] compares lower than [`SchemaVersion::Latest`],
/// and two `Version`s compare by their number.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SchemaVersion {
    /// A specific version number; zero is unrepresentable.
    Version(NonZeroU32),
    /// The most recent version (spelled `latest` or `-1` when parsed).
    Latest,
}
41
42impl FromStr for SchemaVersion {
43 type Err = SchemaVersionError;
44
45 fn from_str(s: &str) -> Result<Self, Self::Err> {
46 if s == LATEST_STRING || s == "-1" {
47 return Ok(Self::Latest);
48 }
49 let number = s
50 .parse::<NonZeroU32>()
51 .map_err(|_| SchemaVersionError(s.to_string()))?;
52 Ok(Self::Version(number))
53 }
54}
55
/// Textual spelling of [`SchemaVersion::Latest`], shared by parsing, display and serde.
const LATEST_STRING: &str = "latest";
57
58impl Display for SchemaVersion {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 match self {
61 Self::Version(v) => write!(f, "{v}"),
62 Self::Latest => write!(f, "{LATEST_STRING}"),
63 }
64 }
65}
66
67impl Serialize for SchemaVersion {
68 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
69 where
70 S: serde::Serializer,
71 {
72 match self {
73 SchemaVersion::Version(v) => serializer.serialize_u32(v.get()),
74 SchemaVersion::Latest => serializer.serialize_str(LATEST_STRING),
75 }
76 }
77}
78
79struct SchemaVersionVisitor;
80impl<'de> Visitor<'de> for SchemaVersionVisitor {
81 type Value = SchemaVersion;
82
83 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
84 formatter.write_str("Expected an u32 version number of the \"latest\" string")
85 }
86
87 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
88 where
89 E: serde::de::Error,
90 {
91 let i = u32::try_from(v).map_err(serde::de::Error::custom)?;
92 let version = NonZeroU32::try_from(i).map_err(serde::de::Error::custom)?;
93 Ok(Self::Value::Version(version))
94 }
95
96 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
97 where
98 E: serde::de::Error,
99 {
100 if v == -1 {
101 Ok(Self::Value::Latest)
102 } else {
103 let msg = format!("Invalid negative value, only support -1 as 'latest', got {v}");
104 Err(serde::de::Error::custom(msg))
105 }
106 }
107
108 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
109 where
110 E: serde::de::Error,
111 {
112 if v == LATEST_STRING {
113 Ok(Self::Value::Latest)
114 } else {
115 let msg = format!("Expected 'latest', got {v}");
116 Err(serde::de::Error::custom(msg))
117 }
118 }
119}
120
121impl<'de> Deserialize<'de> for SchemaVersion {
122 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
123 where
124 D: serde::Deserializer<'de>,
125 {
126 deserializer.deserialize_any(SchemaVersionVisitor)
127 }
128}
129
#[cfg(test)]
mod tests {
    use assert2::{check, let_assert};

    use super::*;

    /// A JSON integer round-trips as a numbered version.
    #[test]
    fn should_serde_version() {
        let input = "42";
        let parsed: SchemaVersion = serde_json::from_str(input).unwrap();
        let expected = SchemaVersion::Version(NonZeroU32::new(42).unwrap());
        check!(parsed == expected);
        let output = serde_json::to_string(&parsed).unwrap();
        check!(output == input);
    }

    /// The JSON string `"latest"` round-trips as the latest version.
    #[test]
    fn should_serde_latest_version() {
        let input = "\"latest\"";
        let parsed: SchemaVersion = serde_json::from_str(input).unwrap();
        check!(parsed == SchemaVersion::Latest);
        let output = serde_json::to_string(&parsed).unwrap();
        check!(output == input);
    }

    /// The JSON integer `-1` is read as the latest version.
    #[test]
    fn should_serde_latest_version_minus_one() {
        let input = "-1";
        let parsed: SchemaVersion = serde_json::from_str(input).unwrap();
        check!(parsed == SchemaVersion::Latest);
    }

    /// Malformed or out-of-range inputs are rejected.
    #[rstest::rstest]
    #[case::bad_string("plop")]
    #[case::zero("0")]
    #[case::negative("-2")]
    #[case::too_big("4294967296")]
    fn should_not_serde(#[case] json: &str) {
        let outcome: Result<SchemaVersion, _> = serde_json::from_str(json);
        let_assert!(Err(_) = outcome);
    }
}