shopify_sdk/auth/
scopes.rs1use crate::error::ConfigError;
7use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
8use std::collections::HashSet;
9use std::fmt;
10use std::str::FromStr;
11
/// A de-duplicated set of access-scope names such as `read_products`.
///
/// Sets built through `FromStr` or `From<Vec<String>>` also contain the read
/// scope implied by every write scope (`write_X` implies `read_X`, and
/// `unauthenticated_write_X` implies `unauthenticated_read_X`). Iteration
/// order is unspecified; the `Display` output is sorted and therefore stable.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AuthScopes {
    /// Individual scope names, stored without surrounding whitespace.
    scopes: HashSet<String>,
}
54
55impl AuthScopes {
56 #[must_use]
58 pub fn new() -> Self {
59 Self::default()
60 }
61
62 #[must_use]
64 pub fn is_empty(&self) -> bool {
65 self.scopes.is_empty()
66 }
67
68 #[must_use]
73 pub fn covers(&self, other: &Self) -> bool {
74 other.scopes.iter().all(|s| self.scopes.contains(s))
75 }
76
77 pub fn iter(&self) -> impl Iterator<Item = &str> {
79 self.scopes.iter().map(String::as_str)
80 }
81
82 fn add_implied_scopes(&mut self) {
87 let implied: Vec<String> = self
88 .scopes
89 .iter()
90 .filter_map(|scope| Self::get_implied_scope(scope))
91 .collect();
92
93 for scope in implied {
94 self.scopes.insert(scope);
95 }
96 }
97
98 fn get_implied_scope(scope: &str) -> Option<String> {
99 scope
100 .strip_prefix("unauthenticated_write_")
101 .map(|rest| format!("unauthenticated_read_{rest}"))
102 .or_else(|| {
103 scope
104 .strip_prefix("write_")
105 .map(|rest| format!("read_{rest}"))
106 })
107 }
108}
109
110impl FromStr for AuthScopes {
111 type Err = ConfigError;
112
113 fn from_str(s: &str) -> Result<Self, Self::Err> {
114 let mut scopes = HashSet::new();
115
116 for scope in s.split(',') {
117 let scope = scope.trim();
118 if scope.is_empty() {
119 continue;
120 }
121
122 if !scope.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
124 return Err(ConfigError::InvalidScopes {
125 reason: format!("Invalid characters in scope: '{scope}'"),
126 });
127 }
128
129 scopes.insert(scope.to_string());
130 }
131
132 let mut auth_scopes = Self { scopes };
133 auth_scopes.add_implied_scopes();
134
135 Ok(auth_scopes)
136 }
137}
138
139impl From<Vec<String>> for AuthScopes {
140 fn from(scopes: Vec<String>) -> Self {
141 let scopes: HashSet<String> = scopes
142 .into_iter()
143 .map(|s| s.trim().to_string())
144 .filter(|s| !s.is_empty())
145 .collect();
146
147 let mut auth_scopes = Self { scopes };
148 auth_scopes.add_implied_scopes();
149
150 auth_scopes
151 }
152}
153
154impl fmt::Display for AuthScopes {
155 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156 let mut scopes: Vec<&str> = self.scopes.iter().map(String::as_str).collect();
157 scopes.sort_unstable();
158 write!(f, "{}", scopes.join(","))
159 }
160}
161
162impl Serialize for AuthScopes {
163 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
164 where
165 S: Serializer,
166 {
167 serializer.serialize_str(&self.to_string())
169 }
170}
171
172impl<'de> Deserialize<'de> for AuthScopes {
173 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
174 where
175 D: Deserializer<'de>,
176 {
177 let s = String::deserialize(deserializer)?;
178 s.parse().map_err(de::Error::custom)
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_scopes_parses_comma_separated() {
        let scopes: AuthScopes = "read_products, write_orders".parse().unwrap();
        assert!(scopes.iter().any(|s| s == "read_products"));
        assert!(scopes.iter().any(|s| s == "write_orders"));
    }

    #[test]
    fn test_auth_scopes_skips_blank_entries() {
        let scopes: AuthScopes = " , read_products ,, ".parse().unwrap();
        assert_eq!(scopes.to_string(), "read_products");

        let empty: AuthScopes = "".parse().unwrap();
        assert!(empty.is_empty());
    }

    #[test]
    fn test_auth_scopes_rejects_invalid_characters() {
        assert!("read products".parse::<AuthScopes>().is_err());
        assert!("read_products,write-orders".parse::<AuthScopes>().is_err());
    }

    #[test]
    fn test_auth_scopes_handles_implied_scopes() {
        let scopes: AuthScopes = "write_products".parse().unwrap();

        assert!(scopes.iter().any(|s| s == "write_products"));
        // A write scope implies the matching read scope.
        assert!(scopes.iter().any(|s| s == "read_products"));
    }

    #[test]
    fn test_auth_scopes_handles_unauthenticated_implied() {
        let scopes: AuthScopes = "unauthenticated_write_products".parse().unwrap();

        assert!(scopes.iter().any(|s| s == "unauthenticated_write_products"));
        // The unauthenticated prefix is preserved on the implied scope.
        assert!(scopes.iter().any(|s| s == "unauthenticated_read_products"));
        assert!(!scopes.iter().any(|s| s == "read_products"));
    }

    #[test]
    fn test_auth_scopes_covers() {
        let scopes: AuthScopes = "read_products, write_orders".parse().unwrap();
        let required: AuthScopes = "read_products".parse().unwrap();

        assert!(scopes.covers(&required));

        let more_required: AuthScopes = "read_products, read_customers".parse().unwrap();
        assert!(!scopes.covers(&more_required));
    }

    #[test]
    fn test_auth_scopes_covers_via_implied_scope() {
        let granted: AuthScopes = "write_products".parse().unwrap();
        let required: AuthScopes = "read_products".parse().unwrap();

        assert!(granted.covers(&required));
        assert!(!required.covers(&granted));
        assert!(granted.covers(&AuthScopes::new()));
    }

    #[test]
    fn test_auth_scopes_is_empty() {
        let empty = AuthScopes::new();
        assert!(empty.is_empty());

        let scopes: AuthScopes = "read_products".parse().unwrap();
        assert!(!scopes.is_empty());
    }

    #[test]
    fn test_auth_scopes_from_vec() {
        let scopes = AuthScopes::from(vec![
            "read_products".to_string(),
            "write_orders".to_string(),
        ]);
        assert!(scopes.iter().any(|s| s == "read_products"));
        assert!(scopes.iter().any(|s| s == "write_orders"));
        assert!(scopes.iter().any(|s| s == "read_orders"));
    }

    #[test]
    fn test_auth_scopes_from_vec_trims_and_drops_empty() {
        let scopes = AuthScopes::from(vec![
            "  read_products ".to_string(),
            String::new(),
            "   ".to_string(),
        ]);
        assert_eq!(scopes.to_string(), "read_products");
    }

    #[test]
    fn test_auth_scopes_display_is_sorted() {
        let scopes: AuthScopes = "write_orders,read_customers".parse().unwrap();
        assert_eq!(scopes.to_string(), "read_customers,read_orders,write_orders");
    }

    #[test]
    fn test_auth_scopes_serializes_to_comma_separated_string() {
        let scopes: AuthScopes = "read_products,write_orders".parse().unwrap();
        let json = serde_json::to_string(&scopes).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert!(parsed.is_string());
        let scope_str = parsed.as_str().unwrap();
        assert!(scope_str.contains("read_products"));
        assert!(scope_str.contains("write_orders"));
        assert!(scope_str.contains("read_orders"));
    }

    #[test]
    fn test_auth_scopes_deserializes_from_comma_separated_string() {
        let json = r#""read_products,write_orders""#;
        let scopes: AuthScopes = serde_json::from_str(json).unwrap();
        assert!(scopes.iter().any(|s| s == "read_products"));
        assert!(scopes.iter().any(|s| s == "write_orders"));
        assert!(scopes.iter().any(|s| s == "read_orders"));
    }

    #[test]
    fn test_auth_scopes_deserialize_rejects_invalid_scope() {
        let result: Result<AuthScopes, _> = serde_json::from_str(r#""read products""#);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_auth_scopes_serializes_to_empty_string() {
        let scopes = AuthScopes::new();
        let json = serde_json::to_string(&scopes).unwrap();
        assert_eq!(json, r#""""#);
    }

    #[test]
    fn test_auth_scopes_round_trip_serialization() {
        let original: AuthScopes = "read_products,write_orders,read_customers".parse().unwrap();
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: AuthScopes = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }
}