Skip to main content

shaperail_codegen/
proto.rs

//! Proto file generation from resource definitions (M16).
//!
//! Generates `.proto` files from `ResourceDefinition`s. Each resource produces
//! a gRPC service with one RPC per declared CRUD endpoint, plus a
//! server-streaming RPC for list endpoints.

6use shaperail_core::{FieldType, ResourceDefinition};
7
8/// Maps a Shaperail `FieldType` to a protobuf type string.
9fn field_type_to_proto(ft: &FieldType) -> &'static str {
10    match ft {
11        FieldType::Uuid => "string",
12        FieldType::String => "string",
13        FieldType::Integer => "int32",
14        FieldType::Bigint => "int64",
15        FieldType::Number => "double",
16        FieldType::Boolean => "bool",
17        FieldType::Timestamp => "google.protobuf.Timestamp",
18        FieldType::Date => "string",
19        FieldType::Enum => "string",
20        FieldType::Json => "google.protobuf.Struct",
21        FieldType::Array => "google.protobuf.ListValue",
22        FieldType::File => "string",
23    }
24}
25
26/// Returns true if the field type requires a well-known-types import.
27pub fn needs_wkt_import(ft: &FieldType) -> bool {
28    matches!(
29        ft,
30        FieldType::Timestamp | FieldType::Json | FieldType::Array
31    )
32}
33
/// Converts a snake_case name to PascalCase for protobuf message and service names.
///
/// Every `_`-separated segment contributes its first character upper-cased via
/// `char::to_uppercase` (which may expand to several characters), followed by the rest
/// of the segment unchanged. Empty segments, such as those from leading, trailing or
/// repeated underscores, contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(first) = rest.next() {
            result.extend(first.to_uppercase());
            result.push_str(rest.as_str());
        }
    }
    result
}

/// Converts a plural resource name to a singular form for message naming.
///
/// This is a small suffix heuristic, not a full English inflector. Rules apply in order
/// and the first match wins:
/// - a name ending in a non-plural word that happens to end in `s` (`status`, `bus`,
///   `alias`, `canvas`, `bonus`, `campus`, `virus`, `census`) is returned unchanged;
/// - `-ies` becomes `-y` (`categories` → `category`);
/// - `-es` is dropped after `ss`, `x`, `zz`, `ch` and `sh` (`addresses` → `address`,
///   `boxes` → `box`, `buzzes` → `buzz`, `batches` → `batch`, `dishes` → `dish`) and after
///   the non-plural words above (`statuses` → `status`, `buses` → `bus`);
/// - otherwise a single trailing `s` is dropped (`users` → `user`, `courses` → `course`,
///   `sizes` → `size`), unless the name ends in `ss` or dropping it would leave an
///   empty string.
fn to_singular(s: &str) -> String {
    // Words that end in 's' but are not plural.
    const EXCEPTIONS: &[&str] = &[
        "status", "bus", "alias", "canvas", "bonus", "campus", "virus", "census",
    ];
    // Plural endings formed by appending "es" (not just "s") to the singular.
    const ES_PLURALS: &[&str] = &["sses", "xes", "zzes", "ches", "shes"];

    if EXCEPTIONS.iter().any(|&word| s.ends_with(word)) {
        return s.to_string();
    }

    if let Some(stem) = s.strip_suffix("ies") {
        return format!("{stem}y");
    }

    if let Some(stem) = s.strip_suffix("es") {
        // Only strip "es" where English actually adds it. Names like "courses",
        // "responses" and "sizes" add a plain "s" to a singular that already ends in "e".
        let adds_es = ES_PLURALS.iter().any(|&suffix| s.ends_with(suffix))
            || EXCEPTIONS.iter().any(|&word| stem.ends_with(word));
        if adds_es {
            return stem.to_string();
        }
    }

    match s.strip_suffix('s') {
        // Keep "-ss" words ("class") intact, and never produce an empty message name.
        Some(stem) if !stem.is_empty() && !stem.ends_with('s') => stem.to_string(),
        _ => s.to_string(),
    }
}

75/// Generates a `.proto` file content from a single `ResourceDefinition`.
76///
77/// The generated proto includes:
78/// - A message type for the resource with all schema fields
79/// - Create/Update input messages based on endpoint `input` fields
80/// - Request/Response wrappers for each endpoint
81/// - A gRPC service with RPCs for each declared endpoint
82/// - Server-streaming RPC for list endpoints
83pub fn generate_proto(resource: &ResourceDefinition) -> String {
84    let resource_name = &resource.resource;
85    let singular = to_singular(resource_name);
86    let pascal = to_pascal_case(&singular);
87    let pascal_plural = to_pascal_case(resource_name);
88    let version = resource.version;
89
90    let mut needs_timestamp = false;
91    let mut needs_struct = false;
92
93    // Check if we need WKT imports
94    for field in resource.schema.values() {
95        if matches!(field.field_type, FieldType::Timestamp) {
96            needs_timestamp = true;
97        }
98        if matches!(field.field_type, FieldType::Json | FieldType::Array) {
99            needs_struct = true;
100        }
101    }
102
103    let mut proto = String::new();
104    proto.push_str("syntax = \"proto3\";\n\n");
105    proto.push_str(&format!(
106        "package shaperail.v{version}.{resource_name};\n\n"
107    ));
108
109    if needs_timestamp {
110        proto.push_str("import \"google/protobuf/timestamp.proto\";\n");
111    }
112    if needs_struct {
113        proto.push_str("import \"google/protobuf/struct.proto\";\n");
114    }
115    if needs_timestamp || needs_struct {
116        proto.push('\n');
117    }
118
119    // Resource message
120    proto.push_str(&format!("// {pascal} resource message.\n"));
121    proto.push_str(&format!("message {pascal} {{\n"));
122    for (i, (field_name, field_schema)) in resource.schema.iter().enumerate() {
123        let proto_type = field_type_to_proto(&field_schema.field_type);
124        let field_num = i + 1;
125        if field_schema.field_type == FieldType::Enum {
126            if let Some(ref values) = field_schema.values {
127                proto.push_str(&format!("  // Allowed values: {}\n", values.join(", ")));
128            }
129        }
130        proto.push_str(&format!("  {proto_type} {field_name} = {field_num};\n"));
131    }
132    proto.push_str("}\n\n");
133
134    // Determine endpoints
135    let endpoints = resource.endpoints.as_ref();
136    let has_list = endpoints.and_then(|e| e.get("list")).is_some();
137    let has_get = endpoints.and_then(|e| e.get("get")).is_some();
138    let has_create = endpoints.and_then(|e| e.get("create")).is_some();
139    let has_update = endpoints.and_then(|e| e.get("update")).is_some();
140    let has_delete = endpoints.and_then(|e| e.get("delete")).is_some();
141
142    // List request/response
143    if has_list {
144        proto.push_str(&format!("message List{pascal_plural}Request {{\n"));
145        // Add filter fields from the list endpoint
146        if let Some(ep) = endpoints.and_then(|e| e.get("list")) {
147            let mut field_num = 1;
148            if let Some(ref filters) = ep.filters {
149                for f in filters {
150                    proto.push_str(&format!("  string {f} = {field_num};\n"));
151                    field_num += 1;
152                }
153            }
154            if ep.search.is_some() {
155                proto.push_str(&format!("  string search = {field_num};\n"));
156                field_num += 1;
157            }
158            proto.push_str(&format!("  string cursor = {field_num};\n"));
159            field_num += 1;
160            proto.push_str(&format!("  int32 page_size = {field_num};\n"));
161            field_num += 1;
162            proto.push_str(&format!("  string sort = {field_num};\n"));
163        }
164        proto.push_str("}\n\n");
165
166        proto.push_str(&format!("message List{pascal_plural}Response {{\n"));
167        proto.push_str(&format!("  repeated {pascal} items = 1;\n"));
168        proto.push_str("  string next_cursor = 2;\n");
169        proto.push_str("  bool has_more = 3;\n");
170        proto.push_str("  int64 total = 4;\n");
171        proto.push_str("}\n\n");
172    }
173
174    // Get request/response
175    if has_get {
176        proto.push_str(&format!("message Get{pascal}Request {{\n"));
177        proto.push_str("  string id = 1;\n");
178        proto.push_str("}\n\n");
179
180        proto.push_str(&format!("message Get{pascal}Response {{\n"));
181        proto.push_str(&format!("  {pascal} data = 1;\n"));
182        proto.push_str("}\n\n");
183    }
184
185    // Create request/response
186    if has_create {
187        proto.push_str(&format!("message Create{pascal}Request {{\n"));
188        if let Some(ep) = endpoints.and_then(|e| e.get("create")) {
189            if let Some(ref input) = ep.input {
190                for (i, field_name) in input.iter().enumerate() {
191                    let proto_type = resource
192                        .schema
193                        .get(field_name.as_str())
194                        .map(|f| field_type_to_proto(&f.field_type))
195                        .unwrap_or("string");
196                    proto.push_str(&format!("  {proto_type} {field_name} = {};\n", i + 1));
197                }
198            }
199        }
200        proto.push_str("}\n\n");
201
202        proto.push_str(&format!("message Create{pascal}Response {{\n"));
203        proto.push_str(&format!("  {pascal} data = 1;\n"));
204        proto.push_str("}\n\n");
205    }
206
207    // Update request/response
208    if has_update {
209        proto.push_str(&format!("message Update{pascal}Request {{\n"));
210        proto.push_str("  string id = 1;\n");
211        if let Some(ep) = endpoints.and_then(|e| e.get("update")) {
212            if let Some(ref input) = ep.input {
213                for (i, field_name) in input.iter().enumerate() {
214                    let proto_type = resource
215                        .schema
216                        .get(field_name.as_str())
217                        .map(|f| field_type_to_proto(&f.field_type))
218                        .unwrap_or("string");
219                    proto.push_str(&format!("  {proto_type} {field_name} = {};\n", i + 2));
220                }
221            }
222        }
223        proto.push_str("}\n\n");
224
225        proto.push_str(&format!("message Update{pascal}Response {{\n"));
226        proto.push_str(&format!("  {pascal} data = 1;\n"));
227        proto.push_str("}\n\n");
228    }
229
230    // Delete request/response
231    if has_delete {
232        proto.push_str(&format!("message Delete{pascal}Request {{\n"));
233        proto.push_str("  string id = 1;\n");
234        proto.push_str("}\n\n");
235
236        proto.push_str(&format!("message Delete{pascal}Response {{\n"));
237        proto.push_str("  bool success = 1;\n");
238        proto.push_str("}\n\n");
239    }
240
241    // Service definition
242    proto.push_str(&format!(
243        "// gRPC service for {resource_name} (v{version}).\n"
244    ));
245    proto.push_str(&format!("service {pascal}Service {{\n"));
246
247    if has_list {
248        proto.push_str(&format!(
249            "  // Lists {resource_name} with filters, pagination, and sorting.\n"
250        ));
251        proto.push_str(&format!(
252            "  rpc List{pascal_plural}(List{pascal_plural}Request) returns (List{pascal_plural}Response);\n\n"
253        ));
254        proto.push_str(&format!(
255            "  // Streams {resource_name} matching the request filters.\n"
256        ));
257        proto.push_str(&format!(
258            "  rpc Stream{pascal_plural}(List{pascal_plural}Request) returns (stream {pascal});\n\n"
259        ));
260    }
261
262    if has_get {
263        proto.push_str(&format!("  // Gets a single {singular} by ID.\n"));
264        proto.push_str(&format!(
265            "  rpc Get{pascal}(Get{pascal}Request) returns (Get{pascal}Response);\n\n"
266        ));
267    }
268
269    if has_create {
270        proto.push_str(&format!("  // Creates a new {singular}.\n"));
271        proto.push_str(&format!(
272            "  rpc Create{pascal}(Create{pascal}Request) returns (Create{pascal}Response);\n\n"
273        ));
274    }
275
276    if has_update {
277        proto.push_str(&format!("  // Updates an existing {singular}.\n"));
278        proto.push_str(&format!(
279            "  rpc Update{pascal}(Update{pascal}Request) returns (Update{pascal}Response);\n\n"
280        ));
281    }
282
283    if has_delete {
284        proto.push_str(&format!("  // Deletes a {singular} by ID.\n"));
285        proto.push_str(&format!(
286            "  rpc Delete{pascal}(Delete{pascal}Request) returns (Delete{pascal}Response);\n"
287        ));
288    }
289
290    proto.push_str("}\n");
291
292    proto
293}
294
295/// Generates proto files for all resources, returning `(filename, content)` pairs.
296pub fn generate_all_protos(resources: &[ResourceDefinition]) -> Vec<(String, String)> {
297    resources
298        .iter()
299        .map(|r| {
300            let filename = format!("{}.proto", r.resource);
301            let content = generate_proto(r);
302            (filename, content)
303        })
304        .collect()
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use indexmap::IndexMap;
    use shaperail_core::{EndpointSpec, FieldSchema, HttpMethod};

    /// A `FieldSchema` of type `ft` with every flag `false` and every option `None`.
    fn field(ft: FieldType) -> FieldSchema {
        FieldSchema {
            field_type: ft,
            primary: false,
            generated: false,
            required: false,
            unique: false,
            nullable: false,
            reference: None,
            min: None,
            max: None,
            format: None,
            values: None,
            default: None,
            sensitive: false,
            search: false,
            items: None,
        }
    }

    /// An `EndpointSpec` for `method` and `path` with every optional setting unset.
    fn endpoint(method: HttpMethod, path: &str) -> EndpointSpec {
        EndpointSpec {
            method,
            path: path.to_string(),
            auth: None,
            input: None,
            filters: None,
            search: None,
            pagination: None,
            sort: None,
            cache: None,
            controller: None,
            events: None,
            jobs: None,
            upload: None,
            soft_delete: false,
        }
    }

    #[test]
    fn to_pascal_case_simple() {
        assert_eq!(to_pascal_case("users"), "Users");
        assert_eq!(to_pascal_case("blog_posts"), "BlogPosts");
        assert_eq!(to_pascal_case("api_keys"), "ApiKeys");
    }

    #[test]
    fn to_singular_simple() {
        assert_eq!(to_singular("users"), "user");
        assert_eq!(to_singular("blog_posts"), "blog_post");
        assert_eq!(to_singular("categories"), "category");
        assert_eq!(to_singular("addresses"), "address");
        assert_eq!(to_singular("classes"), "class");
        assert_eq!(to_singular("boxes"), "box");
        assert_eq!(to_singular("buses"), "bus");
        assert_eq!(to_singular("status"), "status");
        assert_eq!(to_singular("class"), "class");
    }

    #[test]
    fn field_type_mapping() {
        assert_eq!(field_type_to_proto(&FieldType::Uuid), "string");
        assert_eq!(field_type_to_proto(&FieldType::String), "string");
        assert_eq!(field_type_to_proto(&FieldType::Integer), "int32");
        assert_eq!(field_type_to_proto(&FieldType::Bigint), "int64");
        assert_eq!(field_type_to_proto(&FieldType::Number), "double");
        assert_eq!(field_type_to_proto(&FieldType::Boolean), "bool");
        assert_eq!(
            field_type_to_proto(&FieldType::Timestamp),
            "google.protobuf.Timestamp"
        );
        assert_eq!(field_type_to_proto(&FieldType::Date), "string");
        assert_eq!(field_type_to_proto(&FieldType::Enum), "string");
        assert_eq!(
            field_type_to_proto(&FieldType::Json),
            "google.protobuf.Struct"
        );
        assert_eq!(
            field_type_to_proto(&FieldType::Array),
            "google.protobuf.ListValue"
        );
        assert_eq!(field_type_to_proto(&FieldType::File), "string");
    }

    #[test]
    fn wkt_import_detection() {
        assert!(needs_wkt_import(&FieldType::Timestamp));
        assert!(needs_wkt_import(&FieldType::Json));
        assert!(needs_wkt_import(&FieldType::Array));
        assert!(!needs_wkt_import(&FieldType::String));
        assert!(!needs_wkt_import(&FieldType::Uuid));
        assert!(!needs_wkt_import(&FieldType::Date));
        assert!(!needs_wkt_import(&FieldType::File));
    }

    #[test]
    fn generate_proto_basic_resource() {
        let mut schema = IndexMap::new();
        schema.insert(
            "id".to_string(),
            FieldSchema {
                primary: true,
                generated: true,
                ..field(FieldType::Uuid)
            },
        );
        schema.insert(
            "name".to_string(),
            FieldSchema {
                required: true,
                ..field(FieldType::String)
            },
        );
        schema.insert("active".to_string(), field(FieldType::Boolean));

        let mut endpoints = IndexMap::new();
        endpoints.insert(
            "list".to_string(),
            EndpointSpec {
                filters: Some(vec!["active".to_string()]),
                search: Some(vec!["name".to_string()]),
                ..endpoint(HttpMethod::Get, "/items")
            },
        );
        endpoints.insert("get".to_string(), endpoint(HttpMethod::Get, "/items/:id"));
        endpoints.insert(
            "create".to_string(),
            EndpointSpec {
                input: Some(vec!["name".to_string(), "active".to_string()]),
                ..endpoint(HttpMethod::Post, "/items")
            },
        );
        endpoints.insert(
            "delete".to_string(),
            endpoint(HttpMethod::Delete, "/items/:id"),
        );

        let resource = ResourceDefinition {
            resource: "items".to_string(),
            version: 1,
            db: None,
            schema,
            endpoints: Some(endpoints),
            relations: None,
            indexes: None,
        };

        let proto = generate_proto(&resource);

        assert!(proto.contains("syntax = \"proto3\";"));
        assert!(proto.contains("package shaperail.v1.items;"));
        assert!(proto.contains("message Item {"));
        assert!(proto.contains("string id = 1;"));
        assert!(proto.contains("string name = 2;"));
        assert!(proto.contains("bool active = 3;"));
        assert!(proto.contains("service ItemService {"));
        assert!(proto.contains("rpc ListItems(ListItemsRequest) returns (ListItemsResponse);"));
        assert!(proto.contains("rpc StreamItems(ListItemsRequest) returns (stream Item);"));
        assert!(proto.contains("rpc GetItem(GetItemRequest) returns (GetItemResponse);"));
        assert!(proto.contains("rpc CreateItem(CreateItemRequest) returns (CreateItemResponse);"));
        assert!(proto.contains("rpc DeleteItem(DeleteItemRequest) returns (DeleteItemResponse);"));

        // List request: filters first, then search, then the pagination fields.
        assert!(proto.contains("string active = 1;"));
        assert!(proto.contains("string search = 2;"));
        assert!(proto.contains("string cursor = 3;"));
        assert!(proto.contains("int32 page_size = 4;"));
        assert!(proto.contains("string sort = 5;"));

        // Create request fields are typed from the schema and numbered from 1.
        assert!(proto.contains(
            "message CreateItemRequest {\n  string name = 1;\n  bool active = 2;\n}"
        ));

        // No update endpoint was declared.
        assert!(!proto.contains("UpdateItem"));
    }

    #[test]
    fn generate_proto_update_inputs_follow_id() {
        let mut schema = IndexMap::new();
        schema.insert(
            "id".to_string(),
            FieldSchema {
                primary: true,
                generated: true,
                ..field(FieldType::Uuid)
            },
        );
        schema.insert("name".to_string(), field(FieldType::String));
        schema.insert("count".to_string(), field(FieldType::Integer));

        let mut endpoints = IndexMap::new();
        endpoints.insert(
            "update".to_string(),
            EndpointSpec {
                input: Some(vec!["name".to_string(), "count".to_string()]),
                // The proto generator never reads the HTTP method.
                ..endpoint(HttpMethod::Post, "/widgets/:id")
            },
        );

        let resource = ResourceDefinition {
            resource: "widgets".to_string(),
            version: 1,
            db: None,
            schema,
            endpoints: Some(endpoints),
            relations: None,
            indexes: None,
        };

        let proto = generate_proto(&resource);

        assert!(proto.contains(
            "message UpdateWidgetRequest {\n  string id = 1;\n  string name = 2;\n  int32 count = 3;\n}"
        ));
        assert!(proto.contains("message UpdateWidgetResponse {\n  Widget data = 1;\n}"));
        assert!(proto.contains(
            "rpc UpdateWidget(UpdateWidgetRequest) returns (UpdateWidgetResponse);"
        ));
        assert!(!proto.contains("rpc ListWidgets"));
    }

    #[test]
    fn generate_proto_with_timestamp() {
        let mut schema = IndexMap::new();
        schema.insert(
            "id".to_string(),
            FieldSchema {
                primary: true,
                generated: true,
                ..field(FieldType::Uuid)
            },
        );
        schema.insert(
            "created_at".to_string(),
            FieldSchema {
                generated: true,
                ..field(FieldType::Timestamp)
            },
        );

        let resource = ResourceDefinition {
            resource: "events".to_string(),
            version: 2,
            db: None,
            schema,
            endpoints: None,
            relations: None,
            indexes: None,
        };

        let proto = generate_proto(&resource);
        assert!(proto.contains("import \"google/protobuf/timestamp.proto\";"));
        assert!(proto.contains("google.protobuf.Timestamp created_at = 2;"));
        assert!(proto.contains("package shaperail.v2.events;"));
    }

    #[test]
    fn generate_proto_with_struct_types() {
        let mut schema = IndexMap::new();
        schema.insert(
            "id".to_string(),
            FieldSchema {
                primary: true,
                ..field(FieldType::Uuid)
            },
        );
        schema.insert("payload".to_string(), field(FieldType::Json));
        schema.insert("tags".to_string(), field(FieldType::Array));

        let resource = ResourceDefinition {
            resource: "webhooks".to_string(),
            version: 3,
            db: None,
            schema,
            endpoints: None,
            relations: None,
            indexes: None,
        };

        let proto = generate_proto(&resource);
        assert!(proto.contains("package shaperail.v3.webhooks;"));
        assert!(proto.contains("import \"google/protobuf/struct.proto\";"));
        assert!(!proto.contains("timestamp.proto"));
        assert!(proto.contains("message Webhook {"));
        assert!(proto.contains("google.protobuf.Struct payload = 2;"));
        assert!(proto.contains("google.protobuf.ListValue tags = 3;"));
        // With no endpoints declared the service is emitted empty.
        assert!(proto.contains("service WebhookService {\n}\n"));
    }

    #[test]
    fn generate_all_protos_multiple() {
        let make_schema = || {
            let mut s = IndexMap::new();
            s.insert(
                "id".to_string(),
                FieldSchema {
                    primary: true,
                    ..field(FieldType::Uuid)
                },
            );
            s
        };

        let resources = vec![
            ResourceDefinition {
                resource: "users".to_string(),
                version: 1,
                db: None,
                schema: make_schema(),
                endpoints: None,
                relations: None,
                indexes: None,
            },
            ResourceDefinition {
                resource: "orders".to_string(),
                version: 1,
                db: None,
                schema: make_schema(),
                endpoints: None,
                relations: None,
                indexes: None,
            },
        ];

        let protos = generate_all_protos(&resources);
        assert_eq!(protos.len(), 2);
        assert_eq!(protos[0].0, "users.proto");
        assert_eq!(protos[1].0, "orders.proto");
        assert!(protos[0].1.contains("message User {"));
        assert!(protos[1].1.contains("message Order {"));
    }
}