Skip to main content

xidl_parser/rest_hir/semantics/
stream.rs

1use crate::hir;
2use serde::{Deserialize, Serialize};
3
4use super::annotations::{annotation_name, annotation_params, normalize_annotation_params};
5
6#[cfg(test)]
7mod tests;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10pub enum HttpStreamKind {
11    Server,
12    Client,
13    Bidi,
14}
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17pub enum HttpStreamCodec {
18    Sse,
19    Ndjson,
20}
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23pub struct HttpStreamConfig {
24    pub kind: Option<HttpStreamKind>,
25    pub codec: HttpStreamCodec,
26}
27
28#[derive(Debug, Clone, Copy)]
29pub struct HttpStreamTargetSupport<'a> {
30    pub target: &'a str,
31    pub supports_bidi: bool,
32    pub server_codec: HttpStreamCodec,
33    pub client_codec: HttpStreamCodec,
34    pub server_method: &'a str,
35    pub client_method: &'a str,
36    pub bidi_method: &'a str,
37}
38
39pub fn http_stream_config(annotations: &[hir::Annotation]) -> Result<HttpStreamConfig, String> {
40    let kind = stream_kind(annotations)?;
41    let mut codec = match kind {
42        Some(HttpStreamKind::Server) => HttpStreamCodec::Sse,
43        Some(HttpStreamKind::Client | HttpStreamKind::Bidi) | None => HttpStreamCodec::Ndjson,
44    };
45    for annotation in annotations {
46        let Some(name) = annotation_name(annotation) else {
47            continue;
48        };
49        if !name.eq_ignore_ascii_case("stream_codec") {
50            continue;
51        }
52        let value = annotation_params(annotation)
53            .map(normalize_annotation_params)
54            .and_then(|params| params.get("value").cloned())
55            .unwrap_or_else(|| "sse".to_string());
56        codec = match value.to_ascii_lowercase().as_str() {
57            "sse" => HttpStreamCodec::Sse,
58            "ndjson" => HttpStreamCodec::Ndjson,
59            other => {
60                return Err(format!(
61                    "unsupported @stream_codec value '{other}', expected 'sse' or 'ndjson'"
62                ));
63            }
64        };
65    }
66    Ok(HttpStreamConfig { kind, codec })
67}
68
69pub fn validate_http_stream_target(
70    op_name: &str,
71    config: HttpStreamConfig,
72    support: HttpStreamTargetSupport<'_>,
73) -> Result<(), String> {
74    match config.kind {
75        Some(HttpStreamKind::Server) if config.codec != support.server_codec => Err(format!(
76            "{} currently supports only {} for @server_stream methods: '{}'",
77            support.target,
78            stream_codec_name(support.server_codec),
79            op_name
80        )),
81        Some(HttpStreamKind::Client) if config.codec != support.client_codec => Err(format!(
82            "{} currently supports only {} for @client_stream methods: '{}'",
83            support.target,
84            stream_codec_name(support.client_codec),
85            op_name
86        )),
87        Some(HttpStreamKind::Bidi) if !support.supports_bidi => Err(format!(
88            "{} currently does not support @bidi_stream methods: '{}'",
89            support.target, op_name
90        )),
91        Some(HttpStreamKind::Client | HttpStreamKind::Bidi)
92            if config.codec == HttpStreamCodec::Sse =>
93        {
94            Err(format!(
95                "@stream_codec(\"sse\") requires @server_stream on method '{}'",
96                op_name
97            ))
98        }
99        _ => Ok(()),
100    }
101}
102
103pub fn validate_http_stream_method(
104    op_name: &str,
105    kind: Option<HttpStreamKind>,
106    method: &str,
107    support: HttpStreamTargetSupport<'_>,
108) -> Result<(), String> {
109    let method = method.to_ascii_uppercase();
110    match kind {
111        Some(HttpStreamKind::Server) if method != support.server_method => Err(format!(
112            "@server_stream method '{}' must use {}",
113            op_name, support.server_method
114        )),
115        Some(HttpStreamKind::Client) if method != support.client_method => Err(format!(
116            "@client_stream method '{}' must use {}",
117            op_name, support.client_method
118        )),
119        Some(HttpStreamKind::Bidi) if method != support.bidi_method => Err(format!(
120            "@bidi_stream method '{}' must use {}",
121            op_name, support.bidi_method
122        )),
123        _ => Ok(()),
124    }
125}
126
127fn stream_kind(annotations: &[hir::Annotation]) -> Result<Option<HttpStreamKind>, String> {
128    let mut kind = None;
129    for annotation in annotations {
130        let Some(name) = annotation_name(annotation) else {
131            continue;
132        };
133        let current = if name.eq_ignore_ascii_case("server_stream") {
134            Some(HttpStreamKind::Server)
135        } else if name.eq_ignore_ascii_case("client_stream") {
136            Some(HttpStreamKind::Client)
137        } else if name.eq_ignore_ascii_case("bidi_stream") {
138            Some(HttpStreamKind::Bidi)
139        } else {
140            None
141        };
142        let Some(current) = current else {
143            continue;
144        };
145        match kind {
146            None => kind = Some(current),
147            Some(prev) if prev == current => {}
148            Some(_) => {
149                return Err(
150                    "@server_stream/@client_stream/@bidi_stream are mutually exclusive".to_string(),
151                );
152            }
153        }
154    }
155    Ok(kind)
156}
157
158fn stream_codec_name(codec: HttpStreamCodec) -> &'static str {
159    match codec {
160        HttpStreamCodec::Sse => "SSE",
161        HttpStreamCodec::Ndjson => "NDJSON",
162    }
163}