Skip to main content

spikard_cli/codegen/asyncapi/
spec_parser.rs

1//! `AsyncAPI` v3 specification parsing and extraction.
2//!
3//! This module handles parsing `AsyncAPI` v3 specs and extracting structured data
4//! for code generation, including channels, messages, operations, and metadata.
5
6use anyhow::{Context, Result};
7use asyncapiv3::spec::{AsyncApiSpec, AsyncApiV3Spec};
8use serde_json::Value;
9use std::collections::{HashMap, HashSet};
10use std::fs;
11use std::path::Path;
12
/// Message definition with schema and examples
///
/// Built by `build_message_definition`; `extract_message_schemas` returns one per message name.
#[derive(Debug, Clone)]
pub struct MessageDefinition {
    /// JSON Schema of the message payload after `resolve_schema_tree` has expanded local `$ref`s.
    pub schema: Value,
    /// Example payloads: the non-empty ones declared on the message, or, when there are none,
    /// the output of `generate_example_from_schema` for `schema`.
    pub examples: Vec<Value>,
}
19
/// Message operation metadata from `AsyncAPI` spec
///
/// `collect_message_operations` emits one record per (inline operation, referenced message) pair.
#[derive(Debug, Clone)]
pub struct MessageOperationMetadata {
    /// Key of the operation in the spec's `operations` map.
    pub name: String,
    /// Operation action as text, as returned by `operation_action_name` (`"send"` or `"receive"`).
    pub action: String,
    /// Message names resolved from the operation's inline reply; empty when there is no inline reply.
    pub replies: Vec<String>,
}
27
/// Channel operation metadata from `AsyncAPI` spec
///
/// `collect_channel_operations` emits one record per inline operation, grouped by channel path.
#[derive(Debug, Clone)]
// NOTE(review): presumably some fields are not read by every generator yet — confirm the allow is still needed.
#[allow(dead_code)]
pub struct ChannelOperation {
    /// Key of the operation in the spec's `operations` map.
    pub name: String,
    /// Operation action as text, as returned by `operation_action_name` (`"send"` or `"receive"`).
    pub action: String,
    /// Message names resolved from the operation's `messages` references.
    pub messages: Vec<String>,
    /// Message names resolved from the operation's inline reply; empty when there is no inline reply.
    pub replies: Vec<String>,
}
37
38/// Parse an `AsyncAPI` v3 specification file
39///
40/// Supports both JSON and YAML formats
41pub fn parse_asyncapi_schema(path: &Path) -> Result<AsyncApiV3Spec> {
42    let content =
43        fs::read_to_string(path).with_context(|| format!("Failed to read AsyncAPI file: {}", path.display()))?;
44
45    let spec: AsyncApiSpec = if path.extension().and_then(|s| s.to_str()) == Some("json") {
46        serde_json::from_str(&content)
47            .with_context(|| format!("Failed to parse AsyncAPI JSON from {}", path.display()))?
48    } else {
49        serde_saphyr::from_str(&content)
50            .with_context(|| format!("Failed to parse AsyncAPI YAML from {}", path.display()))?
51    };
52
53    match spec {
54        AsyncApiSpec::V3_0_0(v3_spec) => Ok(v3_spec),
55    }
56}
57
58/// Extract message schemas from `AsyncAPI` spec for fixture generation
59///
60/// Returns a map of message name -> JSON Schema for generating test fixtures
61pub fn extract_message_schemas(spec: &AsyncApiV3Spec) -> Result<HashMap<String, MessageDefinition>> {
62    use asyncapiv3::spec::common::Either;
63    use asyncapiv3::spec::{channel::Channel, message::Message};
64
65    let mut schemas = HashMap::new();
66    let spec_doc = serde_json::to_value(spec).context("Failed to serialize AsyncAPI spec for $ref resolution")?;
67
68    for (message_name, message_ref_or) in &spec.components.messages {
69        tracing::debug!("Processing message: {}", message_name);
70
71        match message_ref_or {
72            Either::Right(message) => {
73                if let Some(definition) = build_message_definition(message, message_name, &spec_doc)? {
74                    schemas.insert(message_name.clone(), definition);
75                }
76            }
77            Either::Left(reference) => {
78                if let Some(message) = resolve_ref_as::<Message>(&spec_doc, &reference.reference) {
79                    if let Some(definition) = build_message_definition(&message, message_name, &spec_doc)? {
80                        schemas.insert(message_name.clone(), definition);
81                    }
82                } else {
83                    tracing::debug!(
84                        "Skipping unresolved message reference: {} -> {}",
85                        message_name,
86                        reference.reference
87                    );
88                }
89            }
90        }
91    }
92
93    for (channel_name, channel_ref_or) in &spec.channels {
94        tracing::debug!("Processing channel: {}", channel_name);
95
96        match channel_ref_or {
97            Either::Right(channel) => {
98                process_channel_messages(channel_name, channel, &spec_doc, &mut schemas)?;
99            }
100            Either::Left(reference) => {
101                if let Some(channel) = resolve_ref_as::<Channel>(&spec_doc, &reference.reference) {
102                    process_channel_messages(channel_name, &channel, &spec_doc, &mut schemas)?;
103                } else {
104                    tracing::debug!("Skipping unresolved channel reference: {}", reference.reference);
105                }
106            }
107        }
108    }
109
110    Ok(schemas)
111}
112
113fn process_channel_messages(
114    channel_name: &str,
115    channel: &asyncapiv3::spec::channel::Channel,
116    spec_doc: &Value,
117    schemas: &mut HashMap<String, MessageDefinition>,
118) -> Result<()> {
119    use asyncapiv3::spec::common::Either;
120    use asyncapiv3::spec::message::Message;
121
122    for (msg_name, msg_ref_or) in &channel.messages {
123        let full_name = format!("{}_{}", channel_name.trim_start_matches('/'), msg_name);
124        match msg_ref_or {
125            Either::Right(message) => {
126                if let Some(definition) = build_message_definition(message, &full_name, spec_doc)? {
127                    schemas.insert(full_name, definition);
128                }
129            }
130            Either::Left(reference) => {
131                if let Some(message) = resolve_ref_as::<Message>(spec_doc, &reference.reference) {
132                    if let Some(definition) = build_message_definition(&message, &full_name, spec_doc)? {
133                        schemas.insert(full_name, definition);
134                    }
135                } else {
136                    tracing::debug!(
137                        "Channel {} message {} unresolved reference: {}",
138                        channel_name,
139                        msg_name,
140                        reference.reference
141                    );
142                }
143            }
144        }
145    }
146
147    Ok(())
148}
149
150fn build_message_definition(
151    message: &asyncapiv3::spec::message::Message,
152    message_name: &str,
153    spec_doc: &Value,
154) -> Result<Option<MessageDefinition>> {
155    let schema = match extract_schema_from_message(message, message_name, spec_doc)? {
156        Some(schema) => schema,
157        None => return Ok(None),
158    };
159    let schema = resolve_schema_tree(spec_doc, &schema, 32);
160
161    let mut examples: Vec<Value> = Vec::new();
162    for example in &message.examples {
163        if !example.payload.is_empty() {
164            let value = serde_json::to_value(&example.payload)
165                .context("Failed to serialize AsyncAPI message example payload")?;
166            examples.push(value);
167        }
168    }
169
170    if examples.is_empty() {
171        examples = generate_example_from_schema(&schema)?;
172    }
173
174    Ok(Some(MessageDefinition { schema, examples }))
175}
176
177/// Extract JSON Schema from an `AsyncAPI` Message object
178fn extract_schema_from_message(
179    message: &asyncapiv3::spec::message::Message,
180    message_name: &str,
181    spec_doc: &Value,
182) -> Result<Option<Value>> {
183    use asyncapiv3::spec::common::Either;
184
185    let payload = if let Some(payload_ref_or) = &message.payload {
186        payload_ref_or
187    } else {
188        tracing::debug!("Message {} has no payload", message_name);
189        return Ok(None);
190    };
191
192    match payload {
193        Either::Right(schema_or_multiformat) => match schema_or_multiformat {
194            Either::Left(schema) => {
195                let schema_json =
196                    serde_json::to_value(schema).context("Failed to serialize schemars::Schema to JSON")?;
197                Ok(Some(schema_json))
198            }
199            Either::Right(multi_format) => Ok(Some(multi_format.schema.clone())),
200        },
201        Either::Left(reference) => {
202            if let Some(resolved) = resolve_ref_value(spec_doc, &reference.reference) {
203                Ok(Some(normalize_schema_ref_value(resolved)))
204            } else {
205                tracing::debug!(
206                    "Message {} payload has unresolved reference: {}",
207                    message_name,
208                    reference.reference
209                );
210                Ok(None)
211            }
212        }
213    }
214}
215
216/// Generate example data from JSON Schema
217///
218/// Creates a simple valid example based on the schema properties
219pub fn generate_example_from_schema(schema: &Value) -> Result<Vec<Value>> {
220    let mut examples = Vec::new();
221
222    if let Some(schema_examples) = schema.get("examples").and_then(|e| e.as_array()) {
223        examples.extend(schema_examples.clone());
224    }
225
226    if examples.is_empty()
227        && schema
228            .get("type")
229            .and_then(|value| value.as_str())
230            .is_some_and(|ty| ty.eq_ignore_ascii_case("array"))
231    {
232        if let Some(items) = schema.get("items") {
233            let generated = generate_example_from_schema(items)?;
234            let template = generated
235                .into_iter()
236                .next()
237                .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
238            let min_items = schema.get("minItems").and_then(serde_json::Value::as_u64).unwrap_or(1);
239            let mut target_len = usize::try_from(min_items).unwrap_or(usize::MAX);
240            if target_len == 0 {
241                target_len = 1;
242            }
243            let capped_len = target_len.min(5);
244            let mut array_values = Vec::new();
245            for _ in 0..capped_len {
246                array_values.push(template.clone());
247            }
248            examples.push(Value::Array(array_values));
249        } else {
250            examples.push(Value::Array(vec![]));
251        }
252    }
253
254    if examples.is_empty()
255        && let Some(obj) = schema.get("properties").and_then(|p| p.as_object())
256    {
257        let mut example = serde_json::Map::new();
258
259        for (prop_name, prop_schema) in obj {
260            let example_value = if let Some(const_val) = prop_schema.get("const") {
261                const_val.clone()
262            } else if let Some(type_str) = prop_schema.get("type").and_then(|t| t.as_str()) {
263                match type_str {
264                    "string" => {
265                        if let Some(format) = prop_schema.get("format").and_then(|f| f.as_str()) {
266                            match format {
267                                "date-time" => Value::String("2024-01-15T10:30:00Z".to_string()),
268                                "date" => Value::String("2024-01-15".to_string()),
269                                "time" => Value::String("10:30:00".to_string()),
270                                "email" => Value::String("user@example.com".to_string()),
271                                "uri" => Value::String("https://example.com".to_string()),
272                                "uuid" => Value::String("550e8400-e29b-41d4-a716-446655440000".to_string()),
273                                _ => Value::String(format!("example_{prop_name}")),
274                            }
275                        } else {
276                            Value::String(format!("example_{prop_name}"))
277                        }
278                    }
279                    "number" => Value::Number(
280                        serde_json::Number::from_f64(std::f64::consts::PI)
281                            .unwrap_or_else(|| serde_json::Number::from(314)),
282                    ),
283                    "integer" => Value::Number(serde_json::Number::from(42)),
284                    "boolean" => Value::Bool(true),
285                    _ => Value::Null,
286                }
287            } else {
288                Value::Null
289            };
290
291            example.insert(prop_name.clone(), example_value);
292        }
293
294        examples.push(Value::Object(example));
295    }
296
297    if examples.is_empty() {
298        examples.push(Value::Object(serde_json::Map::new()));
299    }
300
301    Ok(examples)
302}
303
/// Protocol types supported by `AsyncAPI`
///
/// Values are produced by `Protocol::from_protocol_string`; `Protocol::as_str` gives the
/// canonical lower-case name of each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    /// Protocol strings `ws`, `wss`, `websocket`, `websockets`.
    WebSocket,
    /// Protocol strings `sse`, `server-sent-events`.
    Sse,
    /// Protocol strings `http`, `https`.
    Http,
    /// Protocol string `kafka`.
    Kafka,
    /// Protocol string `mqtt`.
    Mqtt,
    /// Protocol string `amqp`.
    Amqp,
    /// Any protocol string that matches none of the above.
    Other,
}
315
316impl Protocol {
317    /// Detect protocol from `AsyncAPI` server definition
318    #[must_use]
319    pub fn from_protocol_string(protocol: &str) -> Self {
320        match protocol.to_lowercase().as_str() {
321            "ws" | "wss" | "websocket" | "websockets" => Self::WebSocket,
322            "sse" | "server-sent-events" => Self::Sse,
323            "http" | "https" => Self::Http,
324            "kafka" => Self::Kafka,
325            "mqtt" => Self::Mqtt,
326            "amqp" => Self::Amqp,
327            _ => Self::Other,
328        }
329    }
330
331    #[must_use]
332    pub const fn as_str(&self) -> &'static str {
333        match self {
334            Self::WebSocket => "websocket",
335            Self::Sse => "sse",
336            Self::Http => "http",
337            Self::Kafka => "kafka",
338            Self::Mqtt => "mqtt",
339            Self::Amqp => "amqp",
340            Self::Other => "other",
341        }
342    }
343}
344
345impl std::fmt::Display for Protocol {
346    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347        write!(f, "{}", self.as_str())
348    }
349}
350
351/// Determine primary protocol from `AsyncAPI` spec
352pub fn detect_primary_protocol(spec: &AsyncApiV3Spec) -> Result<Protocol> {
353    use asyncapiv3::spec::common::Either;
354    use asyncapiv3::spec::server::Server;
355
356    let spec_doc =
357        serde_json::to_value(spec).context("Failed to serialize AsyncAPI spec for server $ref resolution")?;
358
359    for server_or_ref in spec.servers.values() {
360        match server_or_ref {
361            Either::Right(server) => {
362                let protocol = Protocol::from_protocol_string(&server.protocol);
363                tracing::debug!("Detected protocol: {:?} from '{}'", protocol, server.protocol);
364                return Ok(protocol);
365            }
366            Either::Left(reference) => {
367                if let Some(server) = resolve_ref_as::<Server>(&spec_doc, &reference.reference) {
368                    let protocol = Protocol::from_protocol_string(&server.protocol);
369                    tracing::debug!(
370                        "Detected protocol: {:?} from referenced '{}'",
371                        protocol,
372                        server.protocol
373                    );
374                    return Ok(protocol);
375                }
376                tracing::debug!("Skipping unresolved server reference: {}", reference.reference);
377            }
378        }
379    }
380
381    tracing::warn!("Could not determine protocol from spec, defaulting to WebSocket");
382    Ok(Protocol::WebSocket)
383}
384
/// Unescape one JSON Pointer reference token: `~1` becomes `/` and `~0` becomes `~`.
///
/// `~1` is handled first so that the escaped form `~01` decodes to the literal `~1`
/// rather than to `/`.
pub fn decode_pointer_segment(segment: &str) -> String {
    let slashes_restored = segment.replace("~1", "/");
    slashes_restored.replace("~0", "~")
}
389
390fn reference_to_pointer(reference: &str) -> Option<String> {
391    let raw = reference.strip_prefix("#/")?;
392    let mut pointer = String::new();
393    for segment in raw.split('/') {
394        pointer.push('/');
395        pointer.push_str(&decode_pointer_segment(segment));
396    }
397    Some(pointer)
398}
399
400fn resolve_ref_value(document: &Value, reference: &str) -> Option<Value> {
401    let mut current = reference.to_string();
402    let mut visited = HashSet::new();
403
404    for _ in 0..32 {
405        if !visited.insert(current.clone()) {
406            return None;
407        }
408
409        let pointer = reference_to_pointer(&current)?;
410        let value = document.pointer(&pointer)?;
411
412        if let Some(next_ref) = value.get("$ref").and_then(Value::as_str) {
413            current = next_ref.to_string();
414            continue;
415        }
416
417        return Some(value.clone());
418    }
419
420    None
421}
422
423fn resolve_ref_as<T>(document: &Value, reference: &str) -> Option<T>
424where
425    T: serde::de::DeserializeOwned,
426{
427    let value = resolve_ref_value(document, reference)?;
428    serde_json::from_value(value).ok()
429}
430
431fn normalize_schema_ref_value(value: Value) -> Value {
432    if let Some(obj) = value.as_object()
433        && obj.get("schemaFormat").is_some()
434        && let Some(schema) = obj.get("schema")
435    {
436        return schema.clone();
437    }
438    value
439}
440
441fn resolve_schema_tree(document: &Value, schema: &Value, remaining_depth: usize) -> Value {
442    if remaining_depth == 0 {
443        return schema.clone();
444    }
445
446    if let Some(reference) = schema.get("$ref").and_then(Value::as_str)
447        && let Some(resolved) = resolve_ref_value(document, reference)
448    {
449        return resolve_schema_tree(document, &normalize_schema_ref_value(resolved), remaining_depth - 1);
450    }
451
452    match schema {
453        Value::Object(map) => {
454            let mut resolved = serde_json::Map::new();
455            for (key, value) in map {
456                resolved.insert(key.clone(), resolve_schema_tree(document, value, remaining_depth - 1));
457            }
458            Value::Object(resolved)
459        }
460        Value::Array(items) => Value::Array(
461            items
462                .iter()
463                .map(|item| resolve_schema_tree(document, item, remaining_depth - 1))
464                .collect(),
465        ),
466        _ => schema.clone(),
467    }
468}
469
470/// Resolve channel reference to channel path
471pub fn resolve_channel_from_ref(reference: &str) -> Option<String> {
472    let raw = reference.strip_prefix("#/channels/")?;
473    let decoded = raw.split('/').map(decode_pointer_segment).collect::<Vec<_>>().join("/");
474    let normalized = decoded.trim_start_matches('/').to_string();
475    Some(format!("/{normalized}"))
476}
477
478/// Resolve message reference to message name
479pub fn resolve_message_from_ref(reference: &str) -> Option<String> {
480    if let Some(name) = reference.strip_prefix("#/components/messages/") {
481        return Some(name.to_string());
482    }
483
484    if let Some(rest) = reference.strip_prefix("#/channels/") {
485        let mut parts = rest.split('/');
486        let channel = parts.next()?;
487        if parts.next()? != "messages" {
488            return None;
489        }
490        let message = parts.next()?;
491        let channel_name = decode_pointer_segment(channel);
492        let slug = channel_name.trim_start_matches('/').replace('/', "_");
493        return Some(format!("{}_{}", slug, decode_pointer_segment(message)));
494    }
495
496    None
497}
498
499/// Get operation action name as string
500pub const fn operation_action_name(action: &asyncapiv3::spec::operation::OperationAction) -> &'static str {
501    use asyncapiv3::spec::operation::OperationAction;
502    match action {
503        OperationAction::Send => "send",
504        OperationAction::Receive => "receive",
505    }
506}
507
508/// Collect message channel addresses from spec
509pub fn collect_message_channels(spec: &AsyncApiV3Spec) -> (HashMap<String, String>, HashMap<String, String>) {
510    use asyncapiv3::spec::common::Either;
511
512    let mut map = HashMap::new();
513    let mut aliases = HashMap::new();
514
515    for (channel_path, channel_ref_or) in &spec.channels {
516        let address = match channel_ref_or {
517            Either::Right(channel) => channel.address.clone().unwrap_or_else(|| channel_path.clone()),
518            Either::Left(_) => continue,
519        };
520        let normalized_address = if address.starts_with('/') {
521            address.clone()
522        } else {
523            format!("/{address}")
524        };
525
526        if let Either::Right(channel) = channel_ref_or {
527            for (message_name, message_ref) in &channel.messages {
528                let slug = channel_path.trim_start_matches('/').replace('/', "_");
529                let inline_key = format!("{slug}_{message_name}");
530                match message_ref {
531                    Either::Right(_) => {
532                        map.entry(inline_key.clone())
533                            .or_insert_with(|| normalized_address.clone());
534                    }
535                    Either::Left(reference) => {
536                        let target =
537                            resolve_message_from_ref(&reference.reference).unwrap_or_else(|| message_name.clone());
538                        map.entry(target.clone()).or_insert_with(|| normalized_address.clone());
539                        aliases.insert(inline_key, target);
540                    }
541                }
542            }
543        }
544    }
545
546    (map, aliases)
547}
548
549/// Collect message operations from spec
550pub fn collect_message_operations(
551    spec: &AsyncApiV3Spec,
552    aliases: &HashMap<String, String>,
553) -> HashMap<String, Vec<MessageOperationMetadata>> {
554    use asyncapiv3::spec::common::Either;
555
556    let mut map: HashMap<String, Vec<MessageOperationMetadata>> = HashMap::new();
557
558    for (op_name, operation_ref) in &spec.operations {
559        let operation = match operation_ref {
560            Either::Right(op) => op,
561            Either::Left(_) => continue,
562        };
563
564        let replies: Vec<String> = if let Some(Either::Right(reply)) = &operation.reply {
565            reply
566                .messages
567                .iter()
568                .filter_map(|reference| resolve_message_from_ref(&reference.reference))
569                .collect()
570        } else {
571            Vec::new()
572        };
573
574        if let Some(message_refs) = &operation.messages {
575            for reference in message_refs {
576                if let Some(name) = resolve_message_from_ref(&reference.reference) {
577                    let resolved_name = aliases.get(&name).cloned().unwrap_or(name.clone());
578                    map.entry(resolved_name).or_default().push(MessageOperationMetadata {
579                        name: op_name.clone(),
580                        action: operation_action_name(&operation.action).to_string(),
581                        replies: replies.clone(),
582                    });
583                }
584            }
585        }
586    }
587
588    map
589}
590
591/// Collect channel operations from spec
592pub fn collect_channel_operations(spec: &AsyncApiV3Spec) -> HashMap<String, Vec<ChannelOperation>> {
593    use asyncapiv3::spec::common::Either;
594
595    let mut map: HashMap<String, Vec<ChannelOperation>> = HashMap::new();
596
597    for (op_name, operation_ref) in &spec.operations {
598        let operation = match operation_ref {
599            Either::Right(op) => op,
600            Either::Left(_) => continue,
601        };
602
603        let channel_path = match resolve_channel_from_ref(&operation.channel.reference) {
604            Some(path) => path,
605            None => continue,
606        };
607
608        let messages = operation
609            .messages
610            .as_ref()
611            .map(|refs| {
612                refs.iter()
613                    .filter_map(|reference| resolve_message_from_ref(&reference.reference))
614                    .collect::<Vec<_>>()
615            })
616            .unwrap_or_default();
617
618        let replies = if let Some(Either::Right(reply)) = &operation.reply {
619            reply
620                .messages
621                .iter()
622                .filter_map(|reference| resolve_message_from_ref(&reference.reference))
623                .collect::<Vec<_>>()
624        } else {
625            Vec::new()
626        };
627
628        map.entry(channel_path.clone()).or_default().push(ChannelOperation {
629            name: op_name.clone(),
630            action: operation_action_name(&operation.action).to_string(),
631            messages,
632            replies,
633        });
634    }
635
636    map
637}
638
// Unit tests for the pure helpers in this module plus one end-to-end protocol detection case.
#[cfg(test)]
mod tests {
    use super::*;

    // Every alias group maps to its variant; an unknown string falls back to `Other`.
    #[test]
    fn test_protocol_detection() {
        assert_eq!(Protocol::from_protocol_string("ws"), Protocol::WebSocket);
        assert_eq!(Protocol::from_protocol_string("wss"), Protocol::WebSocket);
        assert_eq!(Protocol::from_protocol_string("websocket"), Protocol::WebSocket);
        assert_eq!(Protocol::from_protocol_string("sse"), Protocol::Sse);
        assert_eq!(Protocol::from_protocol_string("server-sent-events"), Protocol::Sse);
        assert_eq!(Protocol::from_protocol_string("http"), Protocol::Http);
        assert_eq!(Protocol::from_protocol_string("https"), Protocol::Http);
        assert_eq!(Protocol::from_protocol_string("kafka"), Protocol::Kafka);
        assert_eq!(Protocol::from_protocol_string("unknown"), Protocol::Other);
    }

    // `~1` decodes to `/` and `~0` decodes to `~`.
    #[test]
    fn test_decode_pointer_segment() {
        assert_eq!(decode_pointer_segment("hello~1world"), "hello/world");
        assert_eq!(decode_pointer_segment("test~0value"), "test~value");
    }

    // A component reference resolves to the bare message name.
    #[test]
    fn test_resolve_message_from_ref_components() {
        let result = resolve_message_from_ref("#/components/messages/UserMessage");
        assert_eq!(result, Some("UserMessage".to_string()));
    }

    // Pins the decoded rendering of a reference.
    // NOTE(review): the decoded string is not a valid RFC 6901 pointer for keys containing `/`
    // (`user~1signedup` becomes two path segments). `serde_json::Value::pointer` unescapes
    // tokens itself, so this string should not be used for lookups — confirm consumers.
    #[test]
    fn test_reference_to_pointer_decodes_json_pointer_segments() {
        let pointer = reference_to_pointer("#/channels/user~1signedup/messages/user~0created");
        assert_eq!(
            pointer,
            Some("/channels/user/signedup/messages/user~created".to_string())
        );
    }

    // A reference whose target is itself a `$ref` is followed to the final schema.
    #[test]
    fn test_resolve_ref_value_follows_nested_local_refs() {
        let doc = serde_json::json!({
            "components": {
                "schemas": {
                    "A": { "$ref": "#/components/schemas/B" },
                    "B": { "type": "object", "properties": { "id": { "type": "string" } } }
                }
            }
        });

        let resolved = resolve_ref_value(&doc, "#/components/schemas/A").expect("resolved schema");
        assert_eq!(resolved["type"], "object");
        assert!(resolved["properties"].get("id").is_some());
    }

    // A server declared as `$ref` into `components.servers` still drives protocol detection.
    #[test]
    fn test_detect_primary_protocol_resolves_server_refs() {
        let spec_value = serde_json::json!({
            "asyncapi": "3.0.0",
            "info": { "title": "Test", "version": "1.0.0" },
            "servers": {
                "default": { "$ref": "#/components/servers/wsServer" }
            },
            "channels": {},
            "operations": {},
            "components": {
                "servers": {
                    "wsServer": {
                        "host": "example.com",
                        "protocol": "wss"
                    }
                }
            }
        });

        let spec = match serde_json::from_value::<AsyncApiSpec>(spec_value).expect("valid asyncapi spec") {
            AsyncApiSpec::V3_0_0(v3) => v3,
        };

        let protocol = detect_primary_protocol(&spec).expect("protocol detection");
        assert_eq!(protocol, Protocol::WebSocket);
    }
}