1use anyhow::{Context, Result};
7use asyncapiv3::spec::{AsyncApiSpec, AsyncApiV3Spec};
8use serde_json::Value;
9use std::collections::{HashMap, HashSet};
10use std::fs;
11use std::path::Path;
12
13#[derive(Debug, Clone)]
15pub struct MessageDefinition {
16 pub schema: Value,
17 pub examples: Vec<Value>,
18}
19
/// One AsyncAPI operation that references a given message.
#[derive(Clone, Debug)]
pub struct MessageOperationMetadata {
    /// Operation key in the spec's `operations` map.
    pub name: String,
    /// Operation action keyword (`"send"` or `"receive"`).
    pub action: String,
    /// Names of the messages the operation's reply may carry.
    pub replies: Vec<String>,
}
27
/// One AsyncAPI operation grouped under the channel it targets.
// NOTE(review): `dead_code` is allowed presumably because not every field is
// read by current consumers; confirm and narrow or drop the allowance.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct ChannelOperation {
    /// Operation key in the spec's `operations` map.
    pub name: String,
    /// Operation action keyword (`"send"` or `"receive"`).
    pub action: String,
    /// Names of the messages the operation sends or receives.
    pub messages: Vec<String>,
    /// Names of the messages the operation's reply may carry.
    pub replies: Vec<String>,
}
37
38pub fn parse_asyncapi_schema(path: &Path) -> Result<AsyncApiV3Spec> {
42 let content =
43 fs::read_to_string(path).with_context(|| format!("Failed to read AsyncAPI file: {}", path.display()))?;
44
45 let spec: AsyncApiSpec = if path.extension().and_then(|s| s.to_str()) == Some("json") {
46 serde_json::from_str(&content)
47 .with_context(|| format!("Failed to parse AsyncAPI JSON from {}", path.display()))?
48 } else {
49 serde_saphyr::from_str(&content)
50 .with_context(|| format!("Failed to parse AsyncAPI YAML from {}", path.display()))?
51 };
52
53 match spec {
54 AsyncApiSpec::V3_0_0(v3_spec) => Ok(v3_spec),
55 }
56}
57
58pub fn extract_message_schemas(spec: &AsyncApiV3Spec) -> Result<HashMap<String, MessageDefinition>> {
62 use asyncapiv3::spec::common::Either;
63 use asyncapiv3::spec::{channel::Channel, message::Message};
64
65 let mut schemas = HashMap::new();
66 let spec_doc = serde_json::to_value(spec).context("Failed to serialize AsyncAPI spec for $ref resolution")?;
67
68 for (message_name, message_ref_or) in &spec.components.messages {
69 tracing::debug!("Processing message: {}", message_name);
70
71 match message_ref_or {
72 Either::Right(message) => {
73 if let Some(definition) = build_message_definition(message, message_name, &spec_doc)? {
74 schemas.insert(message_name.clone(), definition);
75 }
76 }
77 Either::Left(reference) => {
78 if let Some(message) = resolve_ref_as::<Message>(&spec_doc, &reference.reference) {
79 if let Some(definition) = build_message_definition(&message, message_name, &spec_doc)? {
80 schemas.insert(message_name.clone(), definition);
81 }
82 } else {
83 tracing::debug!(
84 "Skipping unresolved message reference: {} -> {}",
85 message_name,
86 reference.reference
87 );
88 }
89 }
90 }
91 }
92
93 for (channel_name, channel_ref_or) in &spec.channels {
94 tracing::debug!("Processing channel: {}", channel_name);
95
96 match channel_ref_or {
97 Either::Right(channel) => {
98 process_channel_messages(channel_name, channel, &spec_doc, &mut schemas)?;
99 }
100 Either::Left(reference) => {
101 if let Some(channel) = resolve_ref_as::<Channel>(&spec_doc, &reference.reference) {
102 process_channel_messages(channel_name, &channel, &spec_doc, &mut schemas)?;
103 } else {
104 tracing::debug!("Skipping unresolved channel reference: {}", reference.reference);
105 }
106 }
107 }
108 }
109
110 Ok(schemas)
111}
112
113fn process_channel_messages(
114 channel_name: &str,
115 channel: &asyncapiv3::spec::channel::Channel,
116 spec_doc: &Value,
117 schemas: &mut HashMap<String, MessageDefinition>,
118) -> Result<()> {
119 use asyncapiv3::spec::common::Either;
120 use asyncapiv3::spec::message::Message;
121
122 for (msg_name, msg_ref_or) in &channel.messages {
123 let full_name = format!("{}_{}", channel_name.trim_start_matches('/'), msg_name);
124 match msg_ref_or {
125 Either::Right(message) => {
126 if let Some(definition) = build_message_definition(message, &full_name, spec_doc)? {
127 schemas.insert(full_name, definition);
128 }
129 }
130 Either::Left(reference) => {
131 if let Some(message) = resolve_ref_as::<Message>(spec_doc, &reference.reference) {
132 if let Some(definition) = build_message_definition(&message, &full_name, spec_doc)? {
133 schemas.insert(full_name, definition);
134 }
135 } else {
136 tracing::debug!(
137 "Channel {} message {} unresolved reference: {}",
138 channel_name,
139 msg_name,
140 reference.reference
141 );
142 }
143 }
144 }
145 }
146
147 Ok(())
148}
149
150fn build_message_definition(
151 message: &asyncapiv3::spec::message::Message,
152 message_name: &str,
153 spec_doc: &Value,
154) -> Result<Option<MessageDefinition>> {
155 let schema = match extract_schema_from_message(message, message_name, spec_doc)? {
156 Some(schema) => schema,
157 None => return Ok(None),
158 };
159 let schema = resolve_schema_tree(spec_doc, &schema, 32);
160
161 let mut examples: Vec<Value> = Vec::new();
162 for example in &message.examples {
163 if !example.payload.is_empty() {
164 let value = serde_json::to_value(&example.payload)
165 .context("Failed to serialize AsyncAPI message example payload")?;
166 examples.push(value);
167 }
168 }
169
170 if examples.is_empty() {
171 examples = generate_example_from_schema(&schema)?;
172 }
173
174 Ok(Some(MessageDefinition { schema, examples }))
175}
176
177fn extract_schema_from_message(
179 message: &asyncapiv3::spec::message::Message,
180 message_name: &str,
181 spec_doc: &Value,
182) -> Result<Option<Value>> {
183 use asyncapiv3::spec::common::Either;
184
185 let payload = if let Some(payload_ref_or) = &message.payload {
186 payload_ref_or
187 } else {
188 tracing::debug!("Message {} has no payload", message_name);
189 return Ok(None);
190 };
191
192 match payload {
193 Either::Right(schema_or_multiformat) => match schema_or_multiformat {
194 Either::Left(schema) => {
195 let schema_json =
196 serde_json::to_value(schema).context("Failed to serialize schemars::Schema to JSON")?;
197 Ok(Some(schema_json))
198 }
199 Either::Right(multi_format) => Ok(Some(multi_format.schema.clone())),
200 },
201 Either::Left(reference) => {
202 if let Some(resolved) = resolve_ref_value(spec_doc, &reference.reference) {
203 Ok(Some(normalize_schema_ref_value(resolved)))
204 } else {
205 tracing::debug!(
206 "Message {} payload has unresolved reference: {}",
207 message_name,
208 reference.reference
209 );
210 Ok(None)
211 }
212 }
213 }
214}
215
216pub fn generate_example_from_schema(schema: &Value) -> Result<Vec<Value>> {
220 let mut examples = Vec::new();
221
222 if let Some(schema_examples) = schema.get("examples").and_then(|e| e.as_array()) {
223 examples.extend(schema_examples.clone());
224 }
225
226 if examples.is_empty()
227 && schema
228 .get("type")
229 .and_then(|value| value.as_str())
230 .is_some_and(|ty| ty.eq_ignore_ascii_case("array"))
231 {
232 if let Some(items) = schema.get("items") {
233 let generated = generate_example_from_schema(items)?;
234 let template = generated
235 .into_iter()
236 .next()
237 .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
238 let min_items = schema.get("minItems").and_then(serde_json::Value::as_u64).unwrap_or(1);
239 let mut target_len = usize::try_from(min_items).unwrap_or(usize::MAX);
240 if target_len == 0 {
241 target_len = 1;
242 }
243 let capped_len = target_len.min(5);
244 let mut array_values = Vec::new();
245 for _ in 0..capped_len {
246 array_values.push(template.clone());
247 }
248 examples.push(Value::Array(array_values));
249 } else {
250 examples.push(Value::Array(vec![]));
251 }
252 }
253
254 if examples.is_empty()
255 && let Some(obj) = schema.get("properties").and_then(|p| p.as_object())
256 {
257 let mut example = serde_json::Map::new();
258
259 for (prop_name, prop_schema) in obj {
260 let example_value = if let Some(const_val) = prop_schema.get("const") {
261 const_val.clone()
262 } else if let Some(type_str) = prop_schema.get("type").and_then(|t| t.as_str()) {
263 match type_str {
264 "string" => {
265 if let Some(format) = prop_schema.get("format").and_then(|f| f.as_str()) {
266 match format {
267 "date-time" => Value::String("2024-01-15T10:30:00Z".to_string()),
268 "date" => Value::String("2024-01-15".to_string()),
269 "time" => Value::String("10:30:00".to_string()),
270 "email" => Value::String("user@example.com".to_string()),
271 "uri" => Value::String("https://example.com".to_string()),
272 "uuid" => Value::String("550e8400-e29b-41d4-a716-446655440000".to_string()),
273 _ => Value::String(format!("example_{prop_name}")),
274 }
275 } else {
276 Value::String(format!("example_{prop_name}"))
277 }
278 }
279 "number" => Value::Number(
280 serde_json::Number::from_f64(std::f64::consts::PI)
281 .unwrap_or_else(|| serde_json::Number::from(314)),
282 ),
283 "integer" => Value::Number(serde_json::Number::from(42)),
284 "boolean" => Value::Bool(true),
285 _ => Value::Null,
286 }
287 } else {
288 Value::Null
289 };
290
291 example.insert(prop_name.clone(), example_value);
292 }
293
294 examples.push(Value::Object(example));
295 }
296
297 if examples.is_empty() {
298 examples.push(Value::Object(serde_json::Map::new()));
299 }
300
301 Ok(examples)
302}
303
/// Transport protocol family of an AsyncAPI server, normalized from the
/// server's free-form `protocol` string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    /// `ws`, `wss`, `websocket`, `websockets`.
    WebSocket,
    /// `sse`, `server-sent-events`.
    Sse,
    /// `http`, `https`.
    Http,
    /// `kafka`, `kafka-secure`.
    Kafka,
    /// `mqtt`, `mqtts`, `secure-mqtt`.
    Mqtt,
    /// `amqp`, `amqps`.
    Amqp,
    /// Any other protocol string.
    Other,
}

impl Protocol {
    /// Maps an AsyncAPI server `protocol` value to a [`Protocol`], ignoring case.
    ///
    /// Secure variants map to the same family as their plain counterpart. This
    /// covers the TLS names listed by the AsyncAPI spec (`amqps`,
    /// `kafka-secure`, `secure-mqtt`) as well as `wss`/`https`. Unrecognized
    /// values yield [`Protocol::Other`].
    #[must_use]
    pub fn from_protocol_string(protocol: &str) -> Self {
        match protocol.to_lowercase().as_str() {
            "ws" | "wss" | "websocket" | "websockets" => Self::WebSocket,
            "sse" | "server-sent-events" => Self::Sse,
            "http" | "https" => Self::Http,
            "kafka" | "kafka-secure" => Self::Kafka,
            "mqtt" | "mqtts" | "secure-mqtt" => Self::Mqtt,
            "amqp" | "amqps" => Self::Amqp,
            _ => Self::Other,
        }
    }

    /// Canonical lowercase name of the protocol family.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::WebSocket => "websocket",
            Self::Sse => "sse",
            Self::Http => "http",
            Self::Kafka => "kafka",
            Self::Mqtt => "mqtt",
            Self::Amqp => "amqp",
            Self::Other => "other",
        }
    }
}

impl std::fmt::Display for Protocol {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
350
351pub fn detect_primary_protocol(spec: &AsyncApiV3Spec) -> Result<Protocol> {
353 use asyncapiv3::spec::common::Either;
354 use asyncapiv3::spec::server::Server;
355
356 let spec_doc =
357 serde_json::to_value(spec).context("Failed to serialize AsyncAPI spec for server $ref resolution")?;
358
359 for server_or_ref in spec.servers.values() {
360 match server_or_ref {
361 Either::Right(server) => {
362 let protocol = Protocol::from_protocol_string(&server.protocol);
363 tracing::debug!("Detected protocol: {:?} from '{}'", protocol, server.protocol);
364 return Ok(protocol);
365 }
366 Either::Left(reference) => {
367 if let Some(server) = resolve_ref_as::<Server>(&spec_doc, &reference.reference) {
368 let protocol = Protocol::from_protocol_string(&server.protocol);
369 tracing::debug!(
370 "Detected protocol: {:?} from referenced '{}'",
371 protocol,
372 server.protocol
373 );
374 return Ok(protocol);
375 }
376 tracing::debug!("Skipping unresolved server reference: {}", reference.reference);
377 }
378 }
379 }
380
381 tracing::warn!("Could not determine protocol from spec, defaulting to WebSocket");
382 Ok(Protocol::WebSocket)
383}
384
/// Decodes one JSON Pointer reference token (RFC 6901): `~1` becomes `/` and
/// `~0` becomes `~`. Any other `~` sequence is kept verbatim.
pub fn decode_pointer_segment(segment: &str) -> String {
    let mut decoded = String::with_capacity(segment.len());
    let mut chars = segment.chars().peekable();
    while let Some(ch) = chars.next() {
        if ch != '~' {
            decoded.push(ch);
            continue;
        }
        // Single left-to-right pass, so "~01" decodes to "~1", not "/".
        match chars.peek() {
            Some('1') => {
                chars.next();
                decoded.push('/');
            }
            Some('0') => {
                chars.next();
                decoded.push('~');
            }
            _ => decoded.push('~'),
        }
    }
    decoded
}
389
390fn reference_to_pointer(reference: &str) -> Option<String> {
391 let raw = reference.strip_prefix("#/")?;
392 let mut pointer = String::new();
393 for segment in raw.split('/') {
394 pointer.push('/');
395 pointer.push_str(&decode_pointer_segment(segment));
396 }
397 Some(pointer)
398}
399
400fn resolve_ref_value(document: &Value, reference: &str) -> Option<Value> {
401 let mut current = reference.to_string();
402 let mut visited = HashSet::new();
403
404 for _ in 0..32 {
405 if !visited.insert(current.clone()) {
406 return None;
407 }
408
409 let pointer = reference_to_pointer(¤t)?;
410 let value = document.pointer(&pointer)?;
411
412 if let Some(next_ref) = value.get("$ref").and_then(Value::as_str) {
413 current = next_ref.to_string();
414 continue;
415 }
416
417 return Some(value.clone());
418 }
419
420 None
421}
422
423fn resolve_ref_as<T>(document: &Value, reference: &str) -> Option<T>
424where
425 T: serde::de::DeserializeOwned,
426{
427 let value = resolve_ref_value(document, reference)?;
428 serde_json::from_value(value).ok()
429}
430
431fn normalize_schema_ref_value(value: Value) -> Value {
432 if let Some(obj) = value.as_object()
433 && obj.get("schemaFormat").is_some()
434 && let Some(schema) = obj.get("schema")
435 {
436 return schema.clone();
437 }
438 value
439}
440
441fn resolve_schema_tree(document: &Value, schema: &Value, remaining_depth: usize) -> Value {
442 if remaining_depth == 0 {
443 return schema.clone();
444 }
445
446 if let Some(reference) = schema.get("$ref").and_then(Value::as_str)
447 && let Some(resolved) = resolve_ref_value(document, reference)
448 {
449 return resolve_schema_tree(document, &normalize_schema_ref_value(resolved), remaining_depth - 1);
450 }
451
452 match schema {
453 Value::Object(map) => {
454 let mut resolved = serde_json::Map::new();
455 for (key, value) in map {
456 resolved.insert(key.clone(), resolve_schema_tree(document, value, remaining_depth - 1));
457 }
458 Value::Object(resolved)
459 }
460 Value::Array(items) => Value::Array(
461 items
462 .iter()
463 .map(|item| resolve_schema_tree(document, item, remaining_depth - 1))
464 .collect(),
465 ),
466 _ => schema.clone(),
467 }
468}
469
470pub fn resolve_channel_from_ref(reference: &str) -> Option<String> {
472 let raw = reference.strip_prefix("#/channels/")?;
473 let decoded = raw.split('/').map(decode_pointer_segment).collect::<Vec<_>>().join("/");
474 let normalized = decoded.trim_start_matches('/').to_string();
475 Some(format!("/{normalized}"))
476}
477
478pub fn resolve_message_from_ref(reference: &str) -> Option<String> {
480 if let Some(name) = reference.strip_prefix("#/components/messages/") {
481 return Some(name.to_string());
482 }
483
484 if let Some(rest) = reference.strip_prefix("#/channels/") {
485 let mut parts = rest.split('/');
486 let channel = parts.next()?;
487 if parts.next()? != "messages" {
488 return None;
489 }
490 let message = parts.next()?;
491 let channel_name = decode_pointer_segment(channel);
492 let slug = channel_name.trim_start_matches('/').replace('/', "_");
493 return Some(format!("{}_{}", slug, decode_pointer_segment(message)));
494 }
495
496 None
497}
498
499pub const fn operation_action_name(action: &asyncapiv3::spec::operation::OperationAction) -> &'static str {
501 use asyncapiv3::spec::operation::OperationAction;
502 match action {
503 OperationAction::Send => "send",
504 OperationAction::Receive => "receive",
505 }
506}
507
508pub fn collect_message_channels(spec: &AsyncApiV3Spec) -> (HashMap<String, String>, HashMap<String, String>) {
510 use asyncapiv3::spec::common::Either;
511
512 let mut map = HashMap::new();
513 let mut aliases = HashMap::new();
514
515 for (channel_path, channel_ref_or) in &spec.channels {
516 let address = match channel_ref_or {
517 Either::Right(channel) => channel.address.clone().unwrap_or_else(|| channel_path.clone()),
518 Either::Left(_) => continue,
519 };
520 let normalized_address = if address.starts_with('/') {
521 address.clone()
522 } else {
523 format!("/{address}")
524 };
525
526 if let Either::Right(channel) = channel_ref_or {
527 for (message_name, message_ref) in &channel.messages {
528 let slug = channel_path.trim_start_matches('/').replace('/', "_");
529 let inline_key = format!("{slug}_{message_name}");
530 match message_ref {
531 Either::Right(_) => {
532 map.entry(inline_key.clone())
533 .or_insert_with(|| normalized_address.clone());
534 }
535 Either::Left(reference) => {
536 let target =
537 resolve_message_from_ref(&reference.reference).unwrap_or_else(|| message_name.clone());
538 map.entry(target.clone()).or_insert_with(|| normalized_address.clone());
539 aliases.insert(inline_key, target);
540 }
541 }
542 }
543 }
544 }
545
546 (map, aliases)
547}
548
549pub fn collect_message_operations(
551 spec: &AsyncApiV3Spec,
552 aliases: &HashMap<String, String>,
553) -> HashMap<String, Vec<MessageOperationMetadata>> {
554 use asyncapiv3::spec::common::Either;
555
556 let mut map: HashMap<String, Vec<MessageOperationMetadata>> = HashMap::new();
557
558 for (op_name, operation_ref) in &spec.operations {
559 let operation = match operation_ref {
560 Either::Right(op) => op,
561 Either::Left(_) => continue,
562 };
563
564 let replies: Vec<String> = if let Some(Either::Right(reply)) = &operation.reply {
565 reply
566 .messages
567 .iter()
568 .filter_map(|reference| resolve_message_from_ref(&reference.reference))
569 .collect()
570 } else {
571 Vec::new()
572 };
573
574 if let Some(message_refs) = &operation.messages {
575 for reference in message_refs {
576 if let Some(name) = resolve_message_from_ref(&reference.reference) {
577 let resolved_name = aliases.get(&name).cloned().unwrap_or(name.clone());
578 map.entry(resolved_name).or_default().push(MessageOperationMetadata {
579 name: op_name.clone(),
580 action: operation_action_name(&operation.action).to_string(),
581 replies: replies.clone(),
582 });
583 }
584 }
585 }
586 }
587
588 map
589}
590
591pub fn collect_channel_operations(spec: &AsyncApiV3Spec) -> HashMap<String, Vec<ChannelOperation>> {
593 use asyncapiv3::spec::common::Either;
594
595 let mut map: HashMap<String, Vec<ChannelOperation>> = HashMap::new();
596
597 for (op_name, operation_ref) in &spec.operations {
598 let operation = match operation_ref {
599 Either::Right(op) => op,
600 Either::Left(_) => continue,
601 };
602
603 let channel_path = match resolve_channel_from_ref(&operation.channel.reference) {
604 Some(path) => path,
605 None => continue,
606 };
607
608 let messages = operation
609 .messages
610 .as_ref()
611 .map(|refs| {
612 refs.iter()
613 .filter_map(|reference| resolve_message_from_ref(&reference.reference))
614 .collect::<Vec<_>>()
615 })
616 .unwrap_or_default();
617
618 let replies = if let Some(Either::Right(reply)) = &operation.reply {
619 reply
620 .messages
621 .iter()
622 .filter_map(|reference| resolve_message_from_ref(&reference.reference))
623 .collect::<Vec<_>>()
624 } else {
625 Vec::new()
626 };
627
628 map.entry(channel_path.clone()).or_default().push(ChannelOperation {
629 name: op_name.clone(),
630 action: operation_action_name(&operation.action).to_string(),
631 messages,
632 replies,
633 });
634 }
635
636 map
637}
638
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_protocol_detection() {
        let cases = [
            ("ws", Protocol::WebSocket),
            ("wss", Protocol::WebSocket),
            ("websocket", Protocol::WebSocket),
            ("sse", Protocol::Sse),
            ("server-sent-events", Protocol::Sse),
            ("http", Protocol::Http),
            ("https", Protocol::Http),
            ("kafka", Protocol::Kafka),
            ("unknown", Protocol::Other),
        ];
        for (input, expected) in cases {
            assert_eq!(Protocol::from_protocol_string(input), expected, "input: {input}");
        }
    }

    #[test]
    fn test_decode_pointer_segment() {
        for (encoded, decoded) in [("hello~1world", "hello/world"), ("test~0value", "test~value")] {
            assert_eq!(decode_pointer_segment(encoded), decoded);
        }
    }

    #[test]
    fn test_resolve_message_from_ref_components() {
        assert_eq!(
            resolve_message_from_ref("#/components/messages/UserMessage").as_deref(),
            Some("UserMessage")
        );
    }

    #[test]
    fn test_reference_to_pointer_decodes_json_pointer_segments() {
        assert_eq!(
            reference_to_pointer("#/channels/user~1signedup/messages/user~0created").as_deref(),
            Some("/channels/user/signedup/messages/user~created")
        );
    }

    #[test]
    fn test_resolve_ref_value_follows_nested_local_refs() {
        let doc = serde_json::json!({
            "components": {
                "schemas": {
                    "A": { "$ref": "#/components/schemas/B" },
                    "B": { "type": "object", "properties": { "id": { "type": "string" } } }
                }
            }
        });

        let resolved = resolve_ref_value(&doc, "#/components/schemas/A").expect("resolved schema");
        assert_eq!(resolved.get("type").and_then(Value::as_str), Some("object"));
        assert!(resolved.pointer("/properties/id").is_some());
    }

    #[test]
    fn test_detect_primary_protocol_resolves_server_refs() {
        let spec_value = serde_json::json!({
            "asyncapi": "3.0.0",
            "info": { "title": "Test", "version": "1.0.0" },
            "servers": {
                "default": { "$ref": "#/components/servers/wsServer" }
            },
            "channels": {},
            "operations": {},
            "components": {
                "servers": {
                    "wsServer": {
                        "host": "example.com",
                        "protocol": "wss"
                    }
                }
            }
        });

        let parsed: AsyncApiSpec = serde_json::from_value(spec_value).expect("valid asyncapi spec");
        let spec = match parsed {
            AsyncApiSpec::V3_0_0(v3) => v3,
        };

        assert_eq!(
            detect_primary_protocol(&spec).expect("protocol detection"),
            Protocol::WebSocket
        );
    }
}