Skip to main content

spikard_cli/codegen/protobuf/
spec_parser.rs

1//! Protobuf (.proto) specification parsing and extraction.
2//!
3//! This module handles parsing Protocol Buffer specifications (proto3 syntax only)
4//! and extracting structured data for code generation, including messages, services,
5//! enums, and field definitions.
6
7use anyhow::{Context, Result, anyhow, bail};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::fs;
11use std::path::{Path, PathBuf};
12
/// Parsed Protobuf schema representation.
///
/// Built by `parse_proto_schema_string` for a single file. When produced through
/// `parse_proto_schema_with_includes`, the messages, services, enums and import lists of
/// every resolved import are merged into the same value; `package` stays that of the root file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtobufSchema {
    /// Package name (e.g., "com.example.service"); `None` when no `package` statement was found
    pub package: Option<String>,
    /// Map of top-level message names to their definitions
    pub messages: HashMap<String, MessageDef>,
    /// Map of service names to their definitions
    pub services: HashMap<String, ServiceDef>,
    /// Map of top-level enum names to their definitions
    pub enums: HashMap<String, EnumDef>,
    /// List of imported proto files, as written in the `import` statements
    pub imports: Vec<String>,
    /// Proto file syntax version (enforced to be "proto3"; also "proto3" when no declaration is present)
    pub syntax: String,
    /// Schema description/comments
    pub description: Option<String>,
}
31
/// Protobuf message definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageDef {
    /// Message name
    pub name: String,
    /// Message fields declared directly in the message body, in source order
    pub fields: Vec<FieldDef>,
    /// Nested message definitions, keyed by name
    // NOTE(review): the line-based parser in this file creates this map empty and never inserts into it.
    pub nested_messages: HashMap<String, Self>,
    /// Nested enum definitions, keyed by name
    // NOTE(review): the line-based parser in this file creates this map empty and never inserts into it.
    pub nested_enums: HashMap<String, EnumDef>,
    /// Message description from comments
    pub description: Option<String>,
}
46
/// Protobuf service definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceDef {
    /// Service name
    pub name: String,
    /// Service methods/RPCs, in source order
    pub methods: Vec<MethodDef>,
    /// Service description from comments
    pub description: Option<String>,
}
57
/// Protobuf RPC method definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MethodDef {
    /// Method name
    pub name: String,
    /// Input message type name, as written in the `rpc` declaration (without the `stream` keyword)
    pub input_type: String,
    /// Output message type name, as written in the `rpc` declaration (without the `stream` keyword)
    pub output_type: String,
    /// Whether input is a stream
    pub input_streaming: bool,
    /// Whether output is a stream
    pub output_streaming: bool,
    /// Method description from comments
    pub description: Option<String>,
}
74
/// Protobuf field definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldDef {
    /// Field name
    pub name: String,
    /// Field number (1-536870911)
    // NOTE(review): the parser in this file only checks that the number parses as `u32`; the range is not enforced.
    pub number: u32,
    /// Field type
    pub field_type: ProtoType,
    /// Field label (`optional`, `repeated`, or none; proto3 has no `required` label)
    pub label: FieldLabel,
    /// Default value (if applicable)
    // NOTE(review): `parse_field` in this file always sets this to `None`.
    pub default_value: Option<String>,
    /// Field description from comments
    pub description: Option<String>,
}
91
/// Protocol Buffer field label
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FieldLabel {
    /// No label: a plain singular proto3 field (proto3 has no `required` label)
    None,
    /// Repeated field (becomes a list)
    Repeated,
    /// Field declared with the explicit `optional` keyword (may be unset)
    Optional,
}
102
/// Protobuf enum definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnumDef {
    /// Enum name
    pub name: String,
    /// Enum values, in source order
    pub values: Vec<EnumValue>,
    /// Enum description from comments
    pub description: Option<String>,
}
113
/// Protobuf enum value
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnumValue {
    /// Value name
    pub name: String,
    /// Numeric value (signed, so negative enum numbers are representable)
    pub number: i32,
    /// Value description from comments
    pub description: Option<String>,
}
124
/// Protocol Buffer type enumeration
///
/// Scalar variants correspond one-to-one to the proto scalar keywords; the two
/// payload-carrying variants hold a type name exactly as it appeared in the source.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProtoType {
    // Scalar types
    Double,
    Float,
    Int32,
    Int64,
    Uint32,
    Uint64,
    Sint32,
    Sint64,
    Fixed32,
    Fixed64,
    Sfixed32,
    Sfixed64,
    Bool,
    String,
    Bytes,
    // Complex types (resolved by name in the schema)
    // NOTE(review): `parse_proto_type` maps every non-scalar name to `Message`; it never produces `Enum`,
    // so consumers presumably have to consult the schema's enum map to tell the two apart.
    Message(String),
    Enum(String),
}
148
149impl ProtoType {
150    /// Get the string representation of a proto type
151    #[must_use]
152    pub fn as_str(&self) -> String {
153        match self {
154            Self::Double => "double".to_string(),
155            Self::Float => "float".to_string(),
156            Self::Int32 => "int32".to_string(),
157            Self::Int64 => "int64".to_string(),
158            Self::Uint32 => "uint32".to_string(),
159            Self::Uint64 => "uint64".to_string(),
160            Self::Sint32 => "sint32".to_string(),
161            Self::Sint64 => "sint64".to_string(),
162            Self::Fixed32 => "fixed32".to_string(),
163            Self::Fixed64 => "fixed64".to_string(),
164            Self::Sfixed32 => "sfixed32".to_string(),
165            Self::Sfixed64 => "sfixed64".to_string(),
166            Self::Bool => "bool".to_string(),
167            Self::String => "string".to_string(),
168            Self::Bytes => "bytes".to_string(),
169            Self::Message(name) => name.clone(),
170            Self::Enum(name) => name.clone(),
171        }
172    }
173}
174
175/// Parse a Protobuf schema from a .proto file
176///
177/// # Arguments
178/// * `path` - Path to .proto file
179///
180/// # Returns
181/// Parsed `ProtobufSchema` or error (rejects proto2 syntax)
182pub fn parse_proto_schema(path: &Path) -> Result<ProtobufSchema> {
183    let content = fs::read_to_string(path).with_context(|| format!("Failed to read proto file: {}", path.display()))?;
184
185    parse_proto_schema_string(&content).with_context(|| format!("Failed to parse proto schema from {}", path.display()))
186}
187
188/// Parse a Protobuf schema from a .proto file and recursively merge import dependencies.
189///
190/// Imported files are resolved relative to the source file first and then against
191/// any additional include paths supplied by the caller.
192pub fn parse_proto_schema_with_includes(path: &Path, include_paths: &[PathBuf]) -> Result<ProtobufSchema> {
193    let mut visited = HashSet::new();
194    parse_proto_schema_recursive(path, include_paths, &mut visited)
195}
196
197/// Parse a Protobuf schema from a string
198pub fn parse_proto_schema_string(content: &str) -> Result<ProtobufSchema> {
199    // For Phase 1, we implement a basic parser that validates proto3 syntax
200    // and extracts the essential structure.
201    // A production implementation would use protox for full parsing with
202    // complete support for nested structures and all field details.
203
204    let mut schema = ProtobufSchema {
205        package: None,
206        messages: HashMap::new(),
207        services: HashMap::new(),
208        enums: HashMap::new(),
209        imports: Vec::new(),
210        syntax: String::new(),
211        description: None,
212    };
213
214    // Extract syntax declaration
215    schema.syntax = extract_syntax_declaration(content).unwrap_or_else(|| "proto3".to_string());
216
217    // Validate proto3 syntax
218    if schema.syntax != "proto3" {
219        return Err(anyhow!(
220            "Only proto3 syntax is supported. Found: {}\n\
221             Please convert your proto file to proto3 syntax or use proto3-compatible definitions.\n\
222             See: https://developers.google.com/protocol-buffers/docs/proto3",
223            schema.syntax
224        ));
225    }
226
227    // Extract package name
228    schema.package = extract_package_name(content);
229
230    // Extract imports
231    schema.imports = extract_imports(content);
232
233    // Extract top-level definitions
234    parse_top_level_definitions(content, &mut schema)?;
235
236    Ok(schema)
237}
238
239fn parse_proto_schema_recursive(
240    path: &Path,
241    include_paths: &[PathBuf],
242    visited: &mut HashSet<PathBuf>,
243) -> Result<ProtobufSchema> {
244    let visit_key = canonical_or_original(path);
245    if !visited.insert(visit_key) {
246        return Ok(ProtobufSchema {
247            package: None,
248            messages: HashMap::new(),
249            services: HashMap::new(),
250            enums: HashMap::new(),
251            imports: Vec::new(),
252            syntax: "proto3".to_string(),
253            description: None,
254        });
255    }
256
257    let content = fs::read_to_string(path).with_context(|| format!("Failed to read proto file: {}", path.display()))?;
258    let mut schema = parse_proto_schema_string(&content)
259        .with_context(|| format!("Failed to parse proto schema from {}", path.display()))?;
260
261    for import in schema.imports.clone() {
262        let Some(import_path) = resolve_import_path(path, &import, include_paths) else {
263            continue;
264        };
265        let imported_schema = parse_proto_schema_recursive(&import_path, include_paths, visited)?;
266        merge_schema(&mut schema, imported_schema)?;
267    }
268
269    Ok(schema)
270}
271
272fn resolve_import_path(path: &Path, import: &str, include_paths: &[PathBuf]) -> Option<PathBuf> {
273    let mut relative_candidates = path
274        .parent()
275        .into_iter()
276        .map(|parent| parent.join(import))
277        .chain(include_paths.iter().map(|include| include.join(import)));
278
279    relative_candidates
280        .find(|candidate| candidate.is_file())
281        .map(|candidate| canonical_or_original(&candidate))
282}
283
/// Returns the canonical form of `path`, or an owned copy of `path` itself when
/// canonicalisation fails (for instance because the path does not exist).
fn canonical_or_original(path: &Path) -> PathBuf {
    match fs::canonicalize(path) {
        Ok(resolved) => resolved,
        Err(_) => path.to_path_buf(),
    }
}
287
288fn merge_schema(target: &mut ProtobufSchema, imported: ProtobufSchema) -> Result<()> {
289    merge_named_defs("message", &mut target.messages, imported.messages)?;
290    merge_named_defs("enum", &mut target.enums, imported.enums)?;
291    merge_named_defs("service", &mut target.services, imported.services)?;
292
293    for import in imported.imports {
294        if !target.imports.contains(&import) {
295            target.imports.push(import);
296        }
297    }
298
299    Ok(())
300}
301
302fn merge_named_defs<T>(kind: &str, target: &mut HashMap<String, T>, source: HashMap<String, T>) -> Result<()> {
303    for (name, def) in source {
304        if target.contains_key(&name) {
305            bail!("Duplicate {kind} definition found while resolving imports: {name}");
306        }
307        target.insert(name, def);
308    }
309    Ok(())
310}
311
/// Extracts the value of the `syntax = "...";` declaration from proto source text.
///
/// The first line whose first token is `syntax` followed by `=` is used. The value may
/// be quoted with either double or single quotes, both of which the protobuf grammar
/// allows. Lines that merely start with the letters "syntax" (for example an identifier
/// such as `syntax_tree`) are ignored.
///
/// # Returns
/// The text between the quotes, or `None` when there is no declaration or the
/// declaration found has no properly quoted value.
fn extract_syntax_declaration(content: &str) -> Option<String> {
    for line in content.lines() {
        let Some(after_keyword) = line.trim().strip_prefix("syntax") else {
            continue;
        };
        // Require the `=` so that identifiers beginning with "syntax" are not mistaken
        // for the declaration.
        let Some(value) = after_keyword.trim_start().strip_prefix('=') else {
            continue;
        };

        let mut chars = value.trim_start().chars();
        // The closing quote must be the same character as the opening one.
        let quote = chars.next().filter(|c| matches!(c, '"' | '\''))?;
        let literal = chars.as_str();
        let end = literal.find(quote)?;
        return Some(literal[..end].to_string());
    }
    None
}
326
/// Extracts the package name from a `package com.example.service;` statement.
///
/// The first line that starts with the `package` keyword followed by any whitespace
/// (space or tab) is used. Identifiers that merely begin with "package" are ignored.
///
/// # Returns
/// The trimmed text between the keyword and the first `;`, or `None` when there is no
/// package statement or the statement found has no terminating `;` on the same line.
fn extract_package_name(content: &str) -> Option<String> {
    for line in content.lines() {
        let Some(rest) = line.trim().strip_prefix("package") else {
            continue;
        };
        // The keyword must be followed by whitespace; otherwise this is some other
        // identifier such as `package_info`.
        if !rest.starts_with(char::is_whitespace) {
            continue;
        }

        let semicolon_pos = rest.find(';')?;
        return Some(rest[..semicolon_pos].trim().to_string());
    }
    None
}
343
/// Collects the paths of all `import "...";` statements, in source order.
///
/// A line counts as an import when it starts with the `import` keyword followed by
/// whitespace or directly by the opening quote. Modifiers between the keyword and the
/// path (such as `public`) are skipped. Paths may be quoted with double or single
/// quotes; a line without a properly closed quoted path is ignored.
fn extract_imports(content: &str) -> Vec<String> {
    let is_quote = |c: char| matches!(c, '"' | '\'');

    content
        .lines()
        .filter_map(|line| {
            let rest = line.trim().strip_prefix("import")?;
            // Reject identifiers that merely begin with "import" (e.g. `important`).
            if !rest.starts_with(|c: char| c.is_whitespace() || is_quote(c)) {
                return None;
            }

            let open = rest.find(is_quote)?;
            // Both quote characters are one byte wide, so `open + 1` is a char boundary.
            let quote = rest[open..].chars().next()?;
            let literal = &rest[open + 1..];
            let close = literal.find(quote)?;
            Some(literal[..close].to_string())
        })
        .collect()
}
361
362fn parse_top_level_definitions(content: &str, schema: &mut ProtobufSchema) -> Result<()> {
363    let lines: Vec<&str> = content.lines().collect();
364    let mut index = 0;
365    let mut pending_comment: Vec<String> = Vec::new();
366
367    while index < lines.len() {
368        let trimmed = strip_inline_comment(lines[index]).trim();
369
370        if trimmed.is_empty() {
371            if !pending_comment.is_empty() {
372                pending_comment.clear();
373            }
374            index += 1;
375            continue;
376        }
377
378        if let Some(comment) = lines[index].trim().strip_prefix("//") {
379            pending_comment.push(comment.trim().to_string());
380            index += 1;
381            continue;
382        }
383
384        if trimmed.starts_with("message ") {
385            let (message, next_index) = parse_message_block(&lines, index, take_comment(&mut pending_comment))?;
386            schema.messages.insert(message.name.clone(), message);
387            index = next_index;
388            continue;
389        }
390
391        if trimmed.starts_with("enum ") {
392            let (enum_def, next_index) = parse_enum_block(&lines, index, take_comment(&mut pending_comment))?;
393            schema.enums.insert(enum_def.name.clone(), enum_def);
394            index = next_index;
395            continue;
396        }
397
398        if trimmed.starts_with("service ") {
399            let (service, next_index) = parse_service_block(&lines, index, take_comment(&mut pending_comment))?;
400            schema.services.insert(service.name.clone(), service);
401            index = next_index;
402            continue;
403        }
404
405        pending_comment.clear();
406        index += 1;
407    }
408
409    Ok(())
410}
411
412fn parse_message_block(lines: &[&str], start: usize, description: Option<String>) -> Result<(MessageDef, usize)> {
413    let header = strip_inline_comment(lines[start]).trim();
414    let name = extract_block_name(header, "message")
415        .ok_or_else(|| anyhow!("Invalid message declaration: {}", lines[start].trim()))?;
416
417    let mut message = MessageDef {
418        name,
419        fields: Vec::new(),
420        nested_messages: HashMap::new(),
421        nested_enums: HashMap::new(),
422        description,
423    };
424
425    let mut index = start + 1;
426    let mut depth = usize::from(header.contains('{'));
427    let mut pending_comment: Vec<String> = Vec::new();
428
429    while index < lines.len() {
430        let raw_line = lines[index];
431        let line = strip_inline_comment(raw_line);
432        let trimmed = line.trim();
433
434        if trimmed.starts_with("//") {
435            if let Some(comment) = raw_line.trim().strip_prefix("//") {
436                pending_comment.push(comment.trim().to_string());
437            }
438            index += 1;
439            continue;
440        }
441
442        let opens = trimmed.matches('{').count();
443        let closes = trimmed.matches('}').count();
444
445        if depth == 1 && !trimmed.is_empty() && !trimmed.starts_with("message ") && !trimmed.starts_with("enum ") {
446            if let Some(field) = parse_field(trimmed, take_comment(&mut pending_comment))? {
447                message.fields.push(field);
448            }
449        }
450
451        depth += opens;
452        depth = depth.saturating_sub(closes);
453        index += 1;
454
455        if depth == 0 {
456            break;
457        }
458    }
459
460    Ok((message, index))
461}
462
463fn parse_enum_block(lines: &[&str], start: usize, description: Option<String>) -> Result<(EnumDef, usize)> {
464    let header = strip_inline_comment(lines[start]).trim();
465    let name = extract_block_name(header, "enum")
466        .ok_or_else(|| anyhow!("Invalid enum declaration: {}", lines[start].trim()))?;
467
468    let mut enum_def = EnumDef {
469        name,
470        values: Vec::new(),
471        description,
472    };
473
474    let mut index = start + 1;
475    let mut depth = usize::from(header.contains('{'));
476    let mut pending_comment: Vec<String> = Vec::new();
477
478    while index < lines.len() {
479        let raw_line = lines[index];
480        let line = strip_inline_comment(raw_line);
481        let trimmed = line.trim();
482
483        if trimmed.starts_with("//") {
484            if let Some(comment) = raw_line.trim().strip_prefix("//") {
485                pending_comment.push(comment.trim().to_string());
486            }
487            index += 1;
488            continue;
489        }
490
491        let opens = trimmed.matches('{').count();
492        let closes = trimmed.matches('}').count();
493
494        if depth == 1 && trimmed.contains('=') && trimmed.ends_with(';') {
495            if let Some(value) = parse_enum_value(trimmed, take_comment(&mut pending_comment))? {
496                enum_def.values.push(value);
497            }
498        }
499
500        depth += opens;
501        depth = depth.saturating_sub(closes);
502        index += 1;
503
504        if depth == 0 {
505            break;
506        }
507    }
508
509    Ok((enum_def, index))
510}
511
512fn parse_service_block(lines: &[&str], start: usize, description: Option<String>) -> Result<(ServiceDef, usize)> {
513    let header = strip_inline_comment(lines[start]).trim();
514    let name = extract_block_name(header, "service")
515        .ok_or_else(|| anyhow!("Invalid service declaration: {}", lines[start].trim()))?;
516
517    let mut service = ServiceDef {
518        name,
519        methods: Vec::new(),
520        description,
521    };
522
523    let mut index = start + 1;
524    let mut depth = usize::from(header.contains('{'));
525    let mut pending_comment: Vec<String> = Vec::new();
526
527    while index < lines.len() {
528        let raw_line = lines[index];
529        let line = strip_inline_comment(raw_line);
530        let trimmed = line.trim();
531
532        if trimmed.starts_with("//") {
533            if let Some(comment) = raw_line.trim().strip_prefix("//") {
534                pending_comment.push(comment.trim().to_string());
535            }
536            index += 1;
537            continue;
538        }
539
540        let opens = trimmed.matches('{').count();
541        let closes = trimmed.matches('}').count();
542
543        if depth == 1 && trimmed.starts_with("rpc ") {
544            if let Some(method) = parse_rpc_method(trimmed, take_comment(&mut pending_comment))? {
545                service.methods.push(method);
546            }
547        }
548
549        depth += opens;
550        depth = depth.saturating_sub(closes);
551        index += 1;
552
553        if depth == 0 {
554            break;
555        }
556    }
557
558    Ok((service, index))
559}
560
561fn parse_field(line: &str, description: Option<String>) -> Result<Option<FieldDef>> {
562    if !line.ends_with(';') || line.starts_with("option ") || line.starts_with("reserved ") {
563        return Ok(None);
564    }
565
566    let without_semicolon = line.trim_end_matches(';');
567    let declaration = without_semicolon.split('[').next().unwrap_or(without_semicolon).trim();
568    let parts: Vec<&str> = declaration.split_whitespace().collect();
569
570    if parts.len() < 4 {
571        return Ok(None);
572    }
573
574    let (label, type_index) = match parts[0] {
575        "repeated" => (FieldLabel::Repeated, 1),
576        "optional" => (FieldLabel::Optional, 1),
577        _ => (FieldLabel::None, 0),
578    };
579
580    if parts.len() <= type_index + 2 {
581        return Ok(None);
582    }
583
584    let field_type = parse_proto_type(parts[type_index]);
585    let field_name = parts[type_index + 1].to_string();
586    let number = parts[type_index + 3]
587        .parse::<u32>()
588        .with_context(|| format!("Invalid field number in line: {line}"))?;
589
590    Ok(Some(FieldDef {
591        name: field_name,
592        number,
593        field_type,
594        label,
595        default_value: None,
596        description,
597    }))
598}
599
600fn parse_enum_value(line: &str, description: Option<String>) -> Result<Option<EnumValue>> {
601    let without_semicolon = line.trim_end_matches(';').trim();
602    let (name, number) = without_semicolon
603        .split_once('=')
604        .ok_or_else(|| anyhow!("Invalid enum value declaration: {line}"))?;
605
606    Ok(Some(EnumValue {
607        name: name.trim().to_string(),
608        number: number
609            .trim()
610            .parse::<i32>()
611            .with_context(|| format!("Invalid enum value number in line: {line}"))?,
612        description,
613    }))
614}
615
616fn parse_rpc_method(line: &str, description: Option<String>) -> Result<Option<MethodDef>> {
617    let without_semicolon = line.trim_end_matches(';').trim();
618    let after_rpc = without_semicolon
619        .strip_prefix("rpc ")
620        .ok_or_else(|| anyhow!("Invalid RPC declaration: {line}"))?;
621    let method_name_end = after_rpc
622        .find('(')
623        .ok_or_else(|| anyhow!("Invalid RPC declaration: {line}"))?;
624    let method_name = after_rpc[..method_name_end].trim().to_string();
625    let rest = &after_rpc[method_name_end + 1..];
626    let request_end = rest
627        .find(')')
628        .ok_or_else(|| anyhow!("Invalid RPC request declaration: {line}"))?;
629    let request_decl = rest[..request_end].trim();
630    let after_request = rest[request_end + 1..].trim();
631    let returns_decl = after_request
632        .strip_prefix("returns")
633        .ok_or_else(|| anyhow!("Invalid RPC returns declaration: {line}"))?
634        .trim();
635    let returns_decl = returns_decl
636        .strip_prefix('(')
637        .ok_or_else(|| anyhow!("Invalid RPC returns declaration: {line}"))?;
638    let response_end = returns_decl
639        .find(')')
640        .ok_or_else(|| anyhow!("Invalid RPC returns declaration: {line}"))?;
641    let response_decl = returns_decl[..response_end].trim();
642
643    let (input_streaming, input_type) = parse_streaming_type(request_decl);
644    let (output_streaming, output_type) = parse_streaming_type(response_decl);
645
646    Ok(Some(MethodDef {
647        name: method_name,
648        input_type,
649        output_type,
650        input_streaming,
651        output_streaming,
652        description,
653    }))
654}
655
/// Splits an rpc type declaration into its streaming flag and type name.
///
/// A declaration beginning with the exact prefix `stream ` is a streaming type; the
/// remainder, trimmed, is the type name. Anything else is returned trimmed with the
/// flag unset.
fn parse_streaming_type(declaration: &str) -> (bool, String) {
    let after_stream_keyword = declaration.strip_prefix("stream ");
    let is_streaming = after_stream_keyword.is_some();
    let type_name = after_stream_keyword.unwrap_or(declaration).trim();
    (is_streaming, type_name.to_string())
}
663
/// Extracts the declared name from a block header such as `message User {`.
///
/// `header` must start with `keyword`. The name is the first run of characters after
/// the keyword that contains neither whitespace nor `{`, so an opening brace written
/// directly after the name (`message Empty{}`) is not included in it.
///
/// # Returns
/// The name, or `None` when `header` does not start with `keyword` or no name follows.
fn extract_block_name(header: &str, keyword: &str) -> Option<String> {
    let after_keyword = header.strip_prefix(keyword)?.trim_start();
    let name = after_keyword
        .split(|c: char| c.is_whitespace() || c == '{')
        .next()
        .filter(|candidate| !candidate.is_empty())?;
    Some(name.to_string())
}
674
/// Returns `line` with a trailing `// ...` comment removed.
///
/// A `//` that appears inside a single- or double-quoted string literal (for example
/// a URL in an option value) is not treated as a comment start. Backslash escapes
/// inside literals are honoured, so an escaped quote does not end the literal. Text
/// before the comment is returned unchanged, including trailing whitespace.
fn strip_inline_comment(line: &str) -> &str {
    let mut open_quote: Option<char> = None;
    let mut escaped = false;
    let mut previous_was_slash = false;

    for (pos, ch) in line.char_indices() {
        if let Some(quote) = open_quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == quote {
                open_quote = None;
            }
            continue;
        }

        match ch {
            '"' | '\'' => {
                open_quote = Some(ch);
                previous_was_slash = false;
            }
            // The previous character was a one-byte `/`, so the comment starts at `pos - 1`.
            '/' if previous_was_slash => return &line[..pos - 1],
            '/' => previous_was_slash = true,
            _ => previous_was_slash = false,
        }
    }

    line
}
682
/// Drains the accumulated comment lines into a single description string.
///
/// Returns `None` when nothing has been collected. Otherwise the lines are joined with
/// single spaces and `pending_comment` is left empty.
fn take_comment(pending_comment: &mut Vec<String>) -> Option<String> {
    let has_lines = !pending_comment.is_empty();
    has_lines.then(|| {
        let collected: Vec<String> = pending_comment.drain(..).collect();
        collected.join(" ")
    })
}
692
693/// Helper function to parse type name from proto syntax
694#[allow(dead_code)]
695fn parse_proto_type(type_str: &str) -> ProtoType {
696    match type_str {
697        "double" => ProtoType::Double,
698        "float" => ProtoType::Float,
699        "int32" => ProtoType::Int32,
700        "int64" => ProtoType::Int64,
701        "uint32" => ProtoType::Uint32,
702        "uint64" => ProtoType::Uint64,
703        "sint32" => ProtoType::Sint32,
704        "sint64" => ProtoType::Sint64,
705        "fixed32" => ProtoType::Fixed32,
706        "fixed64" => ProtoType::Fixed64,
707        "sfixed32" => ProtoType::Sfixed32,
708        "sfixed64" => ProtoType::Sfixed64,
709        "bool" => ProtoType::Bool,
710        "string" => ProtoType::String,
711        "bytes" => ProtoType::Bytes,
712        _ => {
713            // Assume it's a message or enum type
714            ProtoType::Message(type_str.to_string())
715        }
716    }
717}
718
// Unit tests for the line-based proto3 parser. Most tests feed inline `.proto`
// source to `parse_proto_schema_string`; the import test writes real files into a
// temporary directory.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // A minimal proto3 file: syntax, package, and one message with three scalar fields.
    #[test]
    fn test_parse_simple_proto3_schema() {
        let proto = r#"syntax = "proto3";

package example;

message User {
  string id = 1;
  string name = 2;
  string email = 3;
}
"#;

        let schema = parse_proto_schema_string(proto).expect("Failed to parse proto");
        assert_eq!(schema.syntax, "proto3");
        assert_eq!(schema.package, Some("example".to_string()));
        let user = schema.messages.get("User").expect("message should be parsed");
        assert_eq!(user.fields.len(), 3);
        assert_eq!(user.fields[0].name, "id");
    }

    // Import statements are recorded verbatim; nothing is read from disk here.
    #[test]
    fn test_parse_proto_with_imports() {
        let proto = r#"syntax = "proto3";

import "google/protobuf/timestamp.proto";
import "other.proto";

package example;
"#;

        let schema = parse_proto_schema_string(proto).expect("Failed to parse proto");
        assert_eq!(schema.imports.len(), 2);
        assert!(schema.imports.contains(&"google/protobuf/timestamp.proto".to_string()));
        assert!(schema.imports.contains(&"other.proto".to_string()));
    }

    // The root file imports `common/types.proto`; the temp directory is passed as an
    // include path, and the imported message must show up in the merged schema.
    #[test]
    fn test_parse_proto_schema_with_includes_merges_imported_messages() {
        let temp_dir = tempdir().expect("temp dir");
        let shared_dir = temp_dir.path().join("common");
        fs::create_dir_all(&shared_dir).expect("create include dir");

        let shared_proto = shared_dir.join("types.proto");
        fs::write(
            &shared_proto,
            r#"syntax = "proto3";

package common;

message SharedType {
  string id = 1;
}
"#,
        )
        .expect("write shared proto");

        let root_proto = temp_dir.path().join("service.proto");
        fs::write(
            &root_proto,
            r#"syntax = "proto3";

import "common/types.proto";

package example;

message UsesShared {
  SharedType shared = 1;
}
"#,
        )
        .expect("write root proto");

        let schema = parse_proto_schema_with_includes(&root_proto, &[temp_dir.path().to_path_buf()])
            .expect("schema should resolve imports");

        assert!(schema.messages.contains_key("UsesShared"));
        assert!(schema.messages.contains_key("SharedType"));
        assert!(schema.imports.contains(&"common/types.proto".to_string()));
    }

    // proto2 files are rejected, and the error names both the requirement and the offending syntax.
    #[test]
    fn test_reject_proto2_syntax() {
        let proto = r#"syntax = "proto2";

package example;

message User {
  required string id = 1;
}
"#;

        let result = parse_proto_schema_string(proto);
        assert!(result.is_err());
        let error_msg = format!("{}", result.unwrap_err());
        assert!(error_msg.contains("Only proto3 syntax is supported"));
        assert!(error_msg.contains("proto2"));
    }

    // Scalar keywords map to their dedicated variants.
    #[test]
    fn test_parse_proto_type_scalars() {
        assert_eq!(parse_proto_type("double"), ProtoType::Double);
        assert_eq!(parse_proto_type("float"), ProtoType::Float);
        assert_eq!(parse_proto_type("int32"), ProtoType::Int32);
        assert_eq!(parse_proto_type("int64"), ProtoType::Int64);
        assert_eq!(parse_proto_type("bool"), ProtoType::Bool);
        assert_eq!(parse_proto_type("string"), ProtoType::String);
        assert_eq!(parse_proto_type("bytes"), ProtoType::Bytes);
    }

    // Any non-scalar name is treated as a message reference.
    #[test]
    fn test_parse_proto_type_message() {
        match parse_proto_type("User") {
            ProtoType::Message(name) => assert_eq!(name, "User"),
            _ => panic!("Expected Message type"),
        }
    }

    // One enum and one service with a unary and a server-streaming rpc.
    #[test]
    fn test_parse_service_and_enum() {
        let proto = r#"syntax = "proto3";

package example;

enum Status {
  STATUS_UNKNOWN = 0;
  STATUS_ACTIVE = 1;
}

service UserService {
  rpc GetUser (GetUserRequest) returns (User);
  rpc ListUsers (ListUsersRequest) returns (stream User);
}
"#;

        let schema = parse_proto_schema_string(proto).expect("Failed to parse proto");
        let status = schema.enums.get("Status").expect("enum should be parsed");
        assert_eq!(status.values.len(), 2);

        let service = schema.services.get("UserService").expect("service should be parsed");
        assert_eq!(service.methods.len(), 2);
        assert_eq!(service.methods[0].name, "GetUser");
        assert!(service.methods[1].output_streaming);
    }

    // `as_str` returns the proto keyword for scalars and the carried name for messages.
    #[test]
    fn test_proto_type_as_str() {
        assert_eq!(ProtoType::Double.as_str(), "double");
        assert_eq!(ProtoType::String.as_str(), "string");
        assert_eq!(ProtoType::Message("User".to_string()).as_str(), "User");
    }
}