1use crate::{Error, Result, Value};
7use std::collections::HashMap;
8use std::fmt;
9
/// Decodes `%XX` percent-escapes in `s`.
///
/// A `%` that is not followed by two hexadecimal digits is copied through
/// unchanged. The decoded bytes are interpreted as UTF-8; any invalid sequence
/// is replaced with U+FFFD rather than causing an error.
fn percent_decode(s: &str) -> String {
    /// Numeric value of an ASCII hex digit, or `None` for any other byte.
    fn hex_value(byte: u8) -> Option<u8> {
        (byte as char).to_digit(16).map(|digit| digit as u8)
    }

    // Fast path: nothing to decode.
    if !s.contains('%') {
        return s.to_string();
    }

    let raw = s.as_bytes();
    let mut decoded: Vec<u8> = Vec::with_capacity(raw.len());
    let mut pos = 0;

    while let Some(&current) = raw.get(pos) {
        if current == b'%' {
            let high = raw.get(pos + 1).copied().and_then(hex_value);
            let low = raw.get(pos + 2).copied().and_then(hex_value);
            if let (Some(high), Some(low)) = (high, low) {
                decoded.push((high << 4) | low);
                pos += 3;
                continue;
            }
        }
        // Ordinary byte, or a malformed escape kept literally.
        decoded.push(current);
        pos += 1;
    }

    String::from_utf8_lossy(&decoded).into_owned()
}
37
/// The handle part of a tag, i.e. the leading `!…` marker.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TagHandle {
    /// The primary handle, displayed as `!`.
    Primary,
    /// The secondary handle, displayed as `!!`.
    Secondary,
    /// A named handle, displayed as `!name!`; the payload is the bare name.
    Named(String),
    /// A verbatim tag marker, displayed as `!<>`.
    Verbatim,
}

impl fmt::Display for TagHandle {
    /// Writes the textual form of the handle (`!`, `!!`, `!name!` or `!<>`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Named(name) => {
                f.write_str("!")?;
                f.write_str(name)?;
                f.write_str("!")
            }
            Self::Primary => f.write_str("!"),
            Self::Secondary => f.write_str("!!"),
            Self::Verbatim => f.write_str("!<>"),
        }
    }
}
61
/// A tag after resolution: its expanded URI, the text it came from, and its
/// classified kind.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Tag {
    /// Fully expanded tag URI (for example `tag:yaml.org,2002:str`).
    pub uri: String,
    /// The tag text this value was resolved from.
    pub original: String,
    /// Classification of `uri` into a known core type or a custom tag.
    pub kind: TagKind,
}

/// Classification of a resolved tag URI.
///
/// Every variant except [`TagKind::Custom`] corresponds to one URI in the
/// `tag:yaml.org,2002:` namespace.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TagKind {
    /// `tag:yaml.org,2002:null`
    Null,
    /// `tag:yaml.org,2002:bool`
    Bool,
    /// `tag:yaml.org,2002:int`
    Int,
    /// `tag:yaml.org,2002:float`
    Float,
    /// `tag:yaml.org,2002:str`
    Str,
    /// `tag:yaml.org,2002:seq`
    Seq,
    /// `tag:yaml.org,2002:map`
    Map,
    /// `tag:yaml.org,2002:binary`
    Binary,
    /// `tag:yaml.org,2002:timestamp`
    Timestamp,
    /// `tag:yaml.org,2002:set`
    Set,
    /// `tag:yaml.org,2002:omap`
    Omap,
    /// `tag:yaml.org,2002:pairs`
    Pairs,
    /// Any other tag; the payload is the full tag URI.
    Custom(String),
}
95
96pub struct TagResolver {
98 directives: HashMap<String, String>,
100 handlers: HashMap<String, Box<dyn TagHandler>>,
102 schema: Schema,
104}
105
106impl fmt::Debug for TagResolver {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 f.debug_struct("TagResolver")
109 .field("directives", &self.directives)
110 .field("handlers_count", &self.handlers.len())
111 .field("schema", &self.schema)
112 .finish()
113 }
114}
115
116impl TagResolver {
117 pub fn new() -> Self {
119 Self::with_schema(Schema::Core)
120 }
121
122 pub fn with_schema(schema: Schema) -> Self {
124 let mut resolver = Self {
125 directives: HashMap::new(),
126 handlers: HashMap::new(),
127 schema,
128 };
129
130 resolver.directives.insert("!".to_string(), "!".to_string());
132 resolver
133 .directives
134 .insert("!!".to_string(), "tag:yaml.org,2002:".to_string());
135
136 resolver
137 }
138
139 pub fn add_directive(&mut self, handle: String, prefix: String) {
141 self.directives.insert(handle, prefix);
142 }
143
144 pub fn clear_directives(&mut self) {
146 self.directives.clear();
147 self.directives.insert("!".to_string(), "!".to_string());
149 self.directives
150 .insert("!!".to_string(), "tag:yaml.org,2002:".to_string());
151 }
152
153 pub fn register_handler(&mut self, tag_uri: String, handler: Box<dyn TagHandler>) {
155 self.handlers.insert(tag_uri, handler);
156 }
157
158 pub fn resolve(&self, tag_str: &str) -> Result<Tag> {
160 let (uri, original) = if tag_str.starts_with("tag:") {
161 (tag_str.to_string(), tag_str.to_string())
163 } else if tag_str.starts_with("!<") && tag_str.ends_with('>') {
164 let uri = tag_str[2..tag_str.len() - 1].to_string();
166 (uri, tag_str.to_string())
167 } else if tag_str.starts_with("!!") {
168 let suffix = &tag_str[2..];
170 let prefix = self
171 .directives
172 .get("!!")
173 .cloned()
174 .unwrap_or_else(|| "tag:yaml.org,2002:".to_string());
175 (
176 format!("{}{}", prefix, percent_decode(suffix)),
177 tag_str.to_string(),
178 )
179 } else if tag_str.starts_with('!') {
180 if let Some(end) = tag_str[1..].find('!') {
182 let handle_name = &tag_str[1..end + 1];
183 let handle = format!("!{}!", handle_name);
184 let suffix = &tag_str[end + 2..];
185
186 if let Some(prefix) = self.directives.get(&handle) {
187 (
188 format!("{}{}", prefix, percent_decode(suffix)),
189 tag_str.to_string(),
190 )
191 } else {
192 return Err(crate::Error::parse(
197 crate::Position::start(),
198 format!("Undefined tag handle `{handle}`"),
199 ));
200 }
201 } else {
202 let suffix = &tag_str[1..];
204 let prefix = self
205 .directives
206 .get("!")
207 .cloned()
208 .unwrap_or_else(|| "!".to_string());
209 (
210 format!("{}{}", prefix, percent_decode(suffix)),
211 tag_str.to_string(),
212 )
213 }
214 } else {
215 (
217 self.schema.default_tag_for(tag_str),
218 format!("!{}", tag_str),
219 )
220 };
221
222 let kind = Self::identify_tag_kind(&uri);
223
224 Ok(Tag {
225 uri,
226 original,
227 kind,
228 })
229 }
230
231 fn identify_tag_kind(uri: &str) -> TagKind {
233 match uri {
234 "tag:yaml.org,2002:null" => TagKind::Null,
235 "tag:yaml.org,2002:bool" => TagKind::Bool,
236 "tag:yaml.org,2002:int" => TagKind::Int,
237 "tag:yaml.org,2002:float" => TagKind::Float,
238 "tag:yaml.org,2002:str" => TagKind::Str,
239 "tag:yaml.org,2002:seq" => TagKind::Seq,
240 "tag:yaml.org,2002:map" => TagKind::Map,
241 "tag:yaml.org,2002:binary" => TagKind::Binary,
242 "tag:yaml.org,2002:timestamp" => TagKind::Timestamp,
243 "tag:yaml.org,2002:set" => TagKind::Set,
244 "tag:yaml.org,2002:omap" => TagKind::Omap,
245 "tag:yaml.org,2002:pairs" => TagKind::Pairs,
246 _ => TagKind::Custom(uri.to_string()),
247 }
248 }
249
250 pub fn apply_tag(&self, tag: &Tag, value: &str) -> Result<Value> {
252 if let Some(handler) = self.handlers.get(&tag.uri) {
254 return handler.construct(value);
255 }
256
257 match &tag.kind {
259 TagKind::Null => Ok(Value::Null),
260 TagKind::Bool => self.construct_bool(value),
261 TagKind::Int => self.construct_int(value),
262 TagKind::Float => self.construct_float(value),
263 TagKind::Str => Ok(Value::String(value.to_string())),
264 TagKind::Binary => self.construct_binary(value),
265 TagKind::Timestamp => self.construct_timestamp(value),
266 _ => Ok(Value::String(value.to_string())), }
268 }
269
270 fn construct_bool(&self, value: &str) -> Result<Value> {
272 match value.to_lowercase().as_str() {
273 "true" | "yes" | "on" => Ok(Value::Bool(true)),
274 "false" | "no" | "off" => Ok(Value::Bool(false)),
275 _ => Err(Error::Type {
276 expected: "boolean".to_string(),
277 found: format!("'{}'", value),
278 position: crate::Position::start(),
279 context: None,
280 }),
281 }
282 }
283
284 fn construct_int(&self, value: &str) -> Result<Value> {
286 let parsed = if value.starts_with("0x") || value.starts_with("0X") {
288 i64::from_str_radix(&value[2..], 16)
290 } else if value.starts_with("0o") || value.starts_with("0O") {
291 i64::from_str_radix(&value[2..], 8)
293 } else if value.starts_with("0b") || value.starts_with("0B") {
294 i64::from_str_radix(&value[2..], 2)
296 } else {
297 value.replace('_', "").parse::<i64>()
299 };
300
301 parsed.map(Value::Int).map_err(|_| Error::Type {
302 expected: "integer".to_string(),
303 found: format!("'{}'", value),
304 position: crate::Position::start(),
305 context: None,
306 })
307 }
308
309 fn construct_float(&self, value: &str) -> Result<Value> {
311 match value.to_lowercase().as_str() {
312 ".inf" | "+.inf" => Ok(Value::Float(f64::INFINITY)),
313 "-.inf" => Ok(Value::Float(f64::NEG_INFINITY)),
314 ".nan" => Ok(Value::Float(f64::NAN)),
315 _ => value
316 .replace('_', "")
317 .parse::<f64>()
318 .map(Value::Float)
319 .map_err(|_| Error::Type {
320 expected: "float".to_string(),
321 found: format!("'{}'", value),
322 position: crate::Position::start(),
323 context: None,
324 }),
325 }
326 }
327
328 fn construct_binary(&self, value: &str) -> Result<Value> {
330 use base64::{Engine as _, engine::general_purpose::STANDARD};
331
332 let clean = value
334 .chars()
335 .filter(|c| !c.is_whitespace())
336 .collect::<String>();
337
338 match STANDARD.decode(&clean) {
339 Ok(bytes) => {
340 match String::from_utf8(bytes) {
342 Ok(s) => Ok(Value::String(s)),
343 Err(_) => Ok(Value::String(format!(
344 "[binary data: {} bytes]",
345 clean.len() / 4 * 3
346 ))),
347 }
348 }
349 Err(_) => Err(Error::Type {
350 expected: "base64-encoded binary".to_string(),
351 found: format!("invalid base64: '{}'", value),
352 position: crate::Position::start(),
353 context: None,
354 }),
355 }
356 }
357
358 fn construct_timestamp(&self, value: &str) -> Result<Value> {
360 Ok(Value::String(format!("timestamp:{}", value)))
363 }
364}
365
366impl Default for TagResolver {
367 fn default() -> Self {
368 Self::new()
369 }
370}
371
/// The schema a resolver operates under.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Schema {
    /// Core schema; implicit typing is allowed.
    Core,
    /// JSON schema; implicit typing is allowed.
    Json,
    /// Failsafe schema; implicit typing is not allowed.
    Failsafe,
}

impl Schema {
    /// Returns the tag URI assigned to untagged input.
    ///
    /// Every schema currently answers with the string tag, and `_value` is
    /// not inspected.
    pub fn default_tag_for(&self, _value: &str) -> String {
        let uri = match self {
            Self::Core | Self::Json | Self::Failsafe => "tag:yaml.org,2002:str",
        };
        uri.to_string()
    }

    /// Reports whether untagged scalars may be typed implicitly under this
    /// schema.
    pub fn allows_implicit_typing(&self) -> bool {
        match self {
            Self::Failsafe => false,
            Self::Core | Self::Json => true,
        }
    }
}
402
/// A user-supplied constructor and representer for values carrying a
/// particular tag.
///
/// Implementations are stored as boxed trait objects in `TagResolver` and must
/// be `Send + Sync`.
pub trait TagHandler: Send + Sync {
    /// Builds a [`Value`] from the scalar text of a tagged node.
    ///
    /// `TagResolver::apply_tag` calls this when a handler is registered for
    /// the tag's URI.
    fn construct(&self, value: &str) -> Result<Value>;

    /// Converts `value` to its textual form.
    ///
    /// NOTE(review): presumably the inverse of `construct`; nothing in this
    /// file calls it, so confirm against the serializer.
    fn represent(&self, value: &Value) -> Result<String>;
}
411
412pub struct PointTagHandler;
414
415impl TagHandler for PointTagHandler {
416 fn construct(&self, value: &str) -> Result<Value> {
417 let parts: Vec<&str> = value.split(',').collect();
419 if parts.len() != 2 {
420 return Err(Error::Type {
421 expected: "point (x,y)".to_string(),
422 found: value.to_string(),
423 position: crate::Position::start(),
424 context: None,
425 });
426 }
427
428 let x = parts[0].trim().parse::<f64>().map_err(|_| Error::Type {
429 expected: "number".to_string(),
430 found: parts[0].to_string(),
431 position: crate::Position::start(),
432 context: None,
433 })?;
434
435 let y = parts[1].trim().parse::<f64>().map_err(|_| Error::Type {
436 expected: "number".to_string(),
437 found: parts[1].to_string(),
438 position: crate::Position::start(),
439 context: None,
440 })?;
441
442 Ok(Value::Sequence(vec![Value::Float(x), Value::Float(y)]))
444 }
445
446 fn represent(&self, value: &Value) -> Result<String> {
447 if let Value::Sequence(seq) = value {
448 if seq.len() == 2 {
449 if let (Some(Value::Float(x)), Some(Value::Float(y))) = (seq.get(0), seq.get(1)) {
450 return Ok(format!("{},{}", x, y));
451 }
452 }
453 }
454 Err(Error::Type {
455 expected: "point sequence".to_string(),
456 found: format!("{:?}", value),
457 position: crate::Position::start(),
458 context: None,
459 })
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `tag:yaml.org,2002:<name>` tag by hand, bypassing the resolver.
    fn core_tag(name: &str, kind: TagKind) -> Tag {
        Tag {
            uri: format!("tag:yaml.org,2002:{}", name),
            original: format!("!!{}", name),
            kind,
        }
    }

    /// Secondary, primary, named and verbatim tags all expand to the right URI.
    #[test]
    fn test_tag_resolution() {
        let mut resolver = TagResolver::new();

        // Secondary handle with the built-in prefix.
        let str_tag = resolver.resolve("!!str").unwrap();
        assert_eq!(str_tag.uri, "tag:yaml.org,2002:str");
        assert_eq!(str_tag.kind, TagKind::Str);

        let int_tag = resolver.resolve("!!int").unwrap();
        assert_eq!(int_tag.uri, "tag:yaml.org,2002:int");
        assert_eq!(int_tag.kind, TagKind::Int);

        // Primary handle after overriding its prefix.
        resolver.add_directive("!".to_string(), "tag:example.com,2024:".to_string());
        let custom = resolver.resolve("!custom").unwrap();
        assert_eq!(custom.uri, "tag:example.com,2024:custom");

        // Named handle.
        resolver.add_directive("!e!".to_string(), "tag:example.com,2024:".to_string());
        let widget = resolver.resolve("!e!widget").unwrap();
        assert_eq!(widget.uri, "tag:example.com,2024:widget");

        // Verbatim tag.
        let verbatim = resolver.resolve("!<tag:explicit.com,2024:type>").unwrap();
        assert_eq!(verbatim.uri, "tag:explicit.com,2024:type");
    }

    /// The built-in bool, int and float constructors accept their documented
    /// spellings.
    #[test]
    fn test_tag_construction() {
        let resolver = TagResolver::new();

        let bool_tag = core_tag("bool", TagKind::Bool);
        for (input, expected) in [("true", true), ("false", false), ("yes", true), ("no", false)] {
            assert_eq!(
                resolver.apply_tag(&bool_tag, input).unwrap(),
                Value::Bool(expected)
            );
        }

        let int_tag = core_tag("int", TagKind::Int);
        for (input, expected) in [
            ("42", 42),
            ("0x2A", 42),
            ("0o52", 42),
            ("0b101010", 42),
            ("1_234", 1234),
        ] {
            assert_eq!(
                resolver.apply_tag(&int_tag, input).unwrap(),
                Value::Int(expected)
            );
        }

        let float_tag = core_tag("float", TagKind::Float);
        for (input, expected) in [
            ("3.14", 3.14),
            (".inf", f64::INFINITY),
            ("-.inf", f64::NEG_INFINITY),
        ] {
            assert_eq!(
                resolver.apply_tag(&float_tag, input).unwrap(),
                Value::Float(expected)
            );
        }
        // NaN never compares equal to itself, so it is checked by predicate.
        assert!(
            matches!(resolver.apply_tag(&float_tag, ".nan").unwrap(), Value::Float(f) if f.is_nan())
        );
    }

    /// A registered handler takes over construction for its tag URI.
    #[test]
    fn test_custom_tag_handler() {
        let mut resolver = TagResolver::new();
        resolver.register_handler(
            "tag:example.com,2024:point".to_string(),
            Box::new(PointTagHandler),
        );
        resolver.add_directive("!".to_string(), "tag:example.com,2024:".to_string());

        let point_tag = resolver.resolve("!point").unwrap();
        let point = resolver.apply_tag(&point_tag, "3.5, 7.2").unwrap();

        assert_eq!(
            point,
            Value::Sequence(vec![Value::Float(3.5), Value::Float(7.2)])
        );
    }
}