xidl_parser/rest_hir/semantics/
cors.rs1use crate::hir;
2use serde::{Deserialize, Serialize};
3
4use super::annotations::{annotation_name, annotation_params};
5
6#[cfg(test)]
7mod tests;
8
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10pub enum HttpCorsProfile {
11 Any,
12 Origins(Vec<String>),
13}
14
15pub fn effective_cors(
16 interface_annotations: &[hir::Annotation],
17 method_annotations: &[hir::Annotation],
18) -> Result<Option<HttpCorsProfile>, String> {
19 collect_cors(method_annotations)?.map_or_else(
20 || collect_cors(interface_annotations),
21 |profile| Ok(Some(profile)),
22 )
23}
24
25pub(crate) fn collect_cors(
26 annotations: &[hir::Annotation],
27) -> Result<Option<HttpCorsProfile>, String> {
28 let mut matches = annotations.iter().filter(|annotation| {
29 annotation_name(annotation)
30 .map(|name| name.eq_ignore_ascii_case("cors"))
31 .unwrap_or(false)
32 });
33 let Some(annotation) = matches.next() else {
34 return Ok(None);
35 };
36 if matches.next().is_some() {
37 return Err("duplicate @cors annotation".to_string());
38 }
39 parse_cors(annotation).map(Some)
40}
41
42fn parse_cors(annotation: &hir::Annotation) -> Result<HttpCorsProfile, String> {
43 let Some(params) = annotation_params(annotation) else {
44 return Ok(HttpCorsProfile::Any);
45 };
46 match params {
47 hir::AnnotationParams::ConstExpr(expr) => {
48 Ok(HttpCorsProfile::Origins(parse_const_expr_origins(expr)?))
49 }
50 hir::AnnotationParams::Raw(_) | hir::AnnotationParams::Params(_) => {
51 Err("@cors only accepts string literals joined by '|'".to_string())
52 }
53 }
54}
55
56fn parse_const_expr_origins(expr: &hir::ConstExpr) -> Result<Vec<String>, String> {
57 match expr {
58 hir::ConstExpr::Literal(hir::Literal::StringLiteral(value)) => {
59 Ok(vec![parse_origin_literal(value)?])
60 }
61 hir::ConstExpr::BinaryExpr(hir::BinaryOperator::Or, left, right) => {
62 let mut origins = parse_const_expr_origins(left)?;
63 origins.extend(parse_const_expr_origins(right)?);
64 Ok(origins)
65 }
66 _ => Err("@cors only accepts string literals joined by '|'".to_string()),
67 }
68}
69
70fn parse_origin_literal(value: &str) -> Result<String, String> {
71 let Some(value) = trim_string_literal(value) else {
72 return Err("@cors only accepts string literals joined by '|'".to_string());
73 };
74 if value.is_empty() {
75 return Err("@cors origins must not be empty".to_string());
76 }
77 if !is_valid_origin(&value) {
78 return Err(format!("invalid @cors origin '{value}'"));
79 }
80 Ok(value)
81}
82
/// Reports whether `value` is acceptable as a single CORS origin entry.
///
/// `"*"` is always accepted. Any other value must be non-empty and consist
/// solely of visible ASCII characters (`'!'..='~'`): no non-ASCII bytes, no
/// control characters, and no spaces.
///
/// Spaces are rejected because a serialized origin (RFC 6454 §6.2, e.g.
/// `https://example.com:8443`) never contains whitespace. An entry such as
/// `"https://a.com https://b.com"` or `" https://a.com"` could therefore
/// never match a browser's `Origin` header. Failing here surfaces the
/// mistake instead of silently producing a rule that never fires.
fn is_valid_origin(value: &str) -> bool {
    if value == "*" {
        return true;
    }
    // `is_ascii_graphic` == ASCII && !control && != b' ', covering every
    // byte class the check must exclude in one predicate.
    !value.is_empty() && value.bytes().all(|byte| byte.is_ascii_graphic())
}
89
/// Strips the surrounding double quotes from a string-literal token.
///
/// Whitespace around the token is ignored first. Returns `None` unless the
/// trimmed text both starts and ends with `"`; a lone `"` is not a literal.
/// The content between the quotes is returned verbatim — escape sequences
/// are not decoded.
fn trim_string_literal(value: &str) -> Option<String> {
    value
        .trim()
        .strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
        .map(str::to_owned)
}