ttrpc_compiler/
codegen.rs

1// Copyright (c) 2019 Ant Financial
2//
3// Copyright 2017 PingCAP, Inc.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// Copyright (c) 2016, Stepan Koltsov
17//
18// Permission is hereby granted, free of charge, to any person obtaining
19// a copy of this software and associated documentation files (the
20// "Software"), to deal in the Software without restriction, including
21// without limitation the rights to use, copy, modify, merge, publish,
22// distribute, sublicense, and/or sell copies of the Software, and to
23// permit persons to whom the Software is furnished to do so, subject to
24// the following conditions:
25//
26// The above copyright notice and this permission notice shall be
27// included in all copies or substantial portions of the Software.
28//
29// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
30// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
31// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
32// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
33// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
34// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
35// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
36
37#![allow(dead_code)]
38
39use std::{
40    collections::{HashMap, HashSet},
41    fs,
42    io::BufRead,
43};
44
45use crate::Customize;
46use protobuf::{
47    compiler_plugin::{GenRequest, GenResult},
48    descriptor::*,
49    descriptorx::*,
50    plugin::{
51        CodeGeneratorRequest, CodeGeneratorResponse, CodeGeneratorResponse_Feature,
52        CodeGeneratorResponse_File,
53    },
54    Message,
55};
56use protobuf_codegen::code_writer::CodeWriter;
57use std::fs::File;
58use std::io::{self, stdin, stdout, Write};
59use std::path::Path;
60
61use super::util::{
62    self, async_on, def_async_fn, fq_grpc, pub_async_fn, to_camel_case, to_snake_case, MethodType,
63};
64
65struct MethodGen<'a> {
66    proto: &'a MethodDescriptorProto,
67    package_name: String,
68    service_name: String,
69    root_scope: &'a RootScope<'a>,
70    customize: &'a Customize,
71}
72
73impl<'a> MethodGen<'a> {
74    fn new(
75        proto: &'a MethodDescriptorProto,
76        package_name: String,
77        service_name: String,
78        root_scope: &'a RootScope<'a>,
79        customize: &'a Customize,
80    ) -> MethodGen<'a> {
81        MethodGen {
82            proto,
83            package_name,
84            service_name,
85            root_scope,
86            customize,
87        }
88    }
89
90    fn input(&self) -> String {
91        format!(
92            "super::{}",
93            self.root_scope
94                .find_message(self.proto.get_input_type())
95                .rust_fq_name()
96        )
97    }
98
99    fn output(&self) -> String {
100        format!(
101            "super::{}",
102            self.root_scope
103                .find_message(self.proto.get_output_type())
104                .rust_fq_name()
105        )
106    }
107
108    fn method_type(&self) -> (MethodType, String) {
109        match (
110            self.proto.get_client_streaming(),
111            self.proto.get_server_streaming(),
112        ) {
113            (false, false) => (MethodType::Unary, fq_grpc("MethodType::Unary")),
114            (true, false) => (
115                MethodType::ClientStreaming,
116                fq_grpc("MethodType::ClientStreaming"),
117            ),
118            (false, true) => (
119                MethodType::ServerStreaming,
120                fq_grpc("MethodType::ServerStreaming"),
121            ),
122            (true, true) => (MethodType::Duplex, fq_grpc("MethodType::Duplex")),
123        }
124    }
125
126    fn service_name(&self) -> String {
127        to_snake_case(&self.service_name)
128    }
129
130    fn name(&self) -> String {
131        to_snake_case(self.proto.get_name())
132    }
133
134    fn struct_name(&self) -> String {
135        to_camel_case(self.proto.get_name())
136    }
137
138    fn const_method_name(&self) -> String {
139        format!(
140            "METHOD_{}_{}",
141            self.service_name().to_uppercase(),
142            self.name().to_uppercase()
143        )
144    }
145
146    fn write_handler(&self, w: &mut CodeWriter) {
147        w.block(
148            &format!("struct {}Method {{", self.struct_name()),
149            "}",
150            |w| {
151                w.write_line(format!(
152                    "service: Arc<dyn {} + Send + Sync>,",
153                    self.service_name
154                ));
155            },
156        );
157        w.write_line("");
158        if async_on(self.customize, "server") {
159            self.write_handler_impl_async(w)
160        } else {
161            self.write_handler_impl(w)
162        }
163    }
164
165    fn write_handler_impl(&self, w: &mut CodeWriter) {
166        w.block(&format!("impl ::ttrpc::MethodHandler for {}Method {{", self.struct_name()), "}",
167        |w| {
168            w.block("fn handler(&self, ctx: ::ttrpc::TtrpcContext, req: ::ttrpc::Request) -> ::ttrpc::Result<()> {", "}",
169            |w| {
170                w.write_line(format!("::ttrpc::request_handler!(self, ctx, req, {}, {}, {});",
171                                        proto_path_to_rust_mod(self.root_scope.find_message(self.proto.get_input_type()).get_scope().get_file_descriptor().get_name()),
172                                        self.root_scope.find_message(self.proto.get_input_type()).rust_name(),
173                                        self.name()));
174                w.write_line("Ok(())");
175            });
176        });
177    }
178
179    fn write_handler_impl_async(&self, w: &mut CodeWriter) {
180        w.write_line("#[async_trait]");
181        match self.method_type().0 {
182            MethodType::Unary => {
183                w.block(&format!("impl ::ttrpc::r#async::MethodHandler for {}Method {{", self.struct_name()), "}",
184                |w| {
185                    w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, req: ::ttrpc::Request) -> ::ttrpc::Result<::ttrpc::Response> {", "}",
186                        |w| {
187                            w.write_line(format!("::ttrpc::async_request_handler!(self, ctx, req, {}, {}, {});",
188                                        proto_path_to_rust_mod(self.root_scope.find_message(self.proto.get_input_type()).get_scope().get_file_descriptor().get_name()),
189                                        self.root_scope.find_message(self.proto.get_input_type()).rust_name(),
190                                        self.name()));
191                    });
192            });
193            }
194            // only receive
195            MethodType::ClientStreaming => {
196                w.block(&format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}",
197                |w| {
198                    w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result<Option<::ttrpc::Response>> {", "}",
199                        |w| {
200                            w.write_line(format!("::ttrpc::async_client_streamimg_handler!(self, ctx, inner, {});",
201                                        self.name()));
202                    });
203            });
204            }
205            // only send
206            MethodType::ServerStreaming => {
207                w.block(&format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}",
208                |w| {
209                    w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, mut inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result<Option<::ttrpc::Response>> {", "}",
210                        |w| {
211                            w.write_line(format!("::ttrpc::async_server_streamimg_handler!(self, ctx, inner, {}, {}, {});",
212                                        proto_path_to_rust_mod(self.root_scope.find_message(self.proto.get_input_type()).get_scope().get_file_descriptor().get_name()),
213                                        self.root_scope.find_message(self.proto.get_input_type()).rust_name(),
214                                        self.name()));
215                    });
216            });
217            }
218            // receive and send
219            MethodType::Duplex => {
220                w.block(&format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}",
221                |w| {
222                    w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result<Option<::ttrpc::Response>> {", "}",
223                        |w| {
224                            w.write_line(format!("::ttrpc::async_duplex_streamimg_handler!(self, ctx, inner, {});",
225                                        self.name()));
226                    });
227            });
228            }
229        }
230    }
231
232    // Method signatures
233    fn unary(&self, method_name: &str) -> String {
234        format!(
235            "{}(&self, ctx: ttrpc::context::Context, req: &{}) -> {}<{}>",
236            method_name,
237            self.input(),
238            fq_grpc("Result"),
239            self.output()
240        )
241    }
242
243    fn client_streaming(&self, method_name: &str) -> String {
244        format!(
245            "{}(&self, ctx: ttrpc::context::Context) -> {}<{}<{}, {}>>",
246            method_name,
247            fq_grpc("Result"),
248            fq_grpc("r#async::ClientStreamSender"),
249            self.input(),
250            self.output()
251        )
252    }
253
254    fn server_streaming(&self, method_name: &str) -> String {
255        format!(
256            "{}(&self, ctx: ttrpc::context::Context, req: &{}) -> {}<{}<{}>>",
257            method_name,
258            self.input(),
259            fq_grpc("Result"),
260            fq_grpc("r#async::ClientStreamReceiver"),
261            self.output()
262        )
263    }
264
265    fn duplex_streaming(&self, method_name: &str) -> String {
266        format!(
267            "{}(&self, ctx: ttrpc::context::Context) -> {}<{}<{}, {}>>",
268            method_name,
269            fq_grpc("Result"),
270            fq_grpc("r#async::ClientStream"),
271            self.input(),
272            self.output()
273        )
274    }
275
276    fn write_client(&self, w: &mut CodeWriter) {
277        let method_name = self.name();
278        if let MethodType::Unary = self.method_type().0 {
279            w.pub_fn(&self.unary(&method_name), |w| {
280                w.write_line(format!("let mut cres = {}::new();", self.output()));
281                w.write_line(format!(
282                    "::ttrpc::client_request!(self, ctx, req, \"{}.{}\", \"{}\", cres);",
283                    self.package_name,
284                    self.service_name,
285                    &self.proto.get_name(),
286                ));
287                w.write_line("Ok(cres)");
288            });
289        }
290    }
291
292    fn write_async_client(&self, w: &mut CodeWriter) {
293        let method_name = self.name();
294        match self.method_type().0 {
295            // Unary RPC
296            MethodType::Unary => {
297                pub_async_fn(w, &self.unary(&method_name), |w| {
298                    w.write_line(format!("let mut cres = {}::new();", self.output()));
299                    w.write_line(format!(
300                        "::ttrpc::async_client_request!(self, ctx, req, \"{}.{}\", \"{}\", cres);",
301                        self.package_name,
302                        self.service_name,
303                        &self.proto.get_name(),
304                    ));
305                });
306            }
307            // Client Streaming RPC
308            MethodType::ClientStreaming => {
309                pub_async_fn(w, &self.client_streaming(&method_name), |w| {
310                    w.write_line(format!(
311                        "::ttrpc::async_client_stream_send!(self, ctx, \"{}.{}\", \"{}\");",
312                        self.package_name,
313                        self.service_name,
314                        &self.proto.get_name(),
315                    ));
316                });
317            }
318            // Server Streaming RPC
319            MethodType::ServerStreaming => {
320                pub_async_fn(w, &self.server_streaming(&method_name), |w| {
321                    w.write_line(format!(
322                        "::ttrpc::async_client_stream_receive!(self, ctx, req, \"{}.{}\", \"{}\");",
323                        self.package_name,
324                        self.service_name,
325                        &self.proto.get_name(),
326                    ));
327                });
328            }
329            // Bidirectional streaming RPC
330            MethodType::Duplex => {
331                pub_async_fn(w, &self.duplex_streaming(&method_name), |w| {
332                    w.write_line(format!(
333                        "::ttrpc::async_client_stream!(self, ctx, \"{}.{}\", \"{}\");",
334                        self.package_name,
335                        self.service_name,
336                        &self.proto.get_name(),
337                    ));
338                });
339            }
340        };
341    }
342
343    fn write_service(&self, w: &mut CodeWriter) {
344        let (_req, req_type, resp_type) = match self.method_type().0 {
345            MethodType::Unary => ("req", self.input(), self.output()),
346            MethodType::ClientStreaming => (
347                "stream",
348                format!("::ttrpc::r#async::ServerStreamReceiver<{}>", self.input()),
349                self.output(),
350            ),
351            MethodType::ServerStreaming => (
352                "req",
353                format!(
354                    "{}, _: {}<{}>",
355                    self.input(),
356                    "::ttrpc::r#async::ServerStreamSender",
357                    self.output()
358                ),
359                "()".to_string(),
360            ),
361            MethodType::Duplex => (
362                "stream",
363                format!(
364                    "{}<{}, {}>",
365                    "::ttrpc::r#async::ServerStream",
366                    self.output(),
367                    self.input(),
368                ),
369                "()".to_string(),
370            ),
371        };
372
373        let get_sig = |context_name| {
374            format!(
375                "{}(&self, _ctx: &{}, _: {}) -> ::ttrpc::Result<{}>",
376                self.name(),
377                fq_grpc(context_name),
378                req_type,
379                resp_type,
380            )
381        };
382
383        let cb = |w: &mut CodeWriter| {
384            w.write_line(format!("Err(::ttrpc::Error::RpcStatus(::ttrpc::get_status(::ttrpc::Code::NOT_FOUND, \"/{}.{}/{} is not supported\".to_string())))",
385            self.package_name,
386            self.service_name, self.proto.get_name(),));
387        };
388
389        if async_on(self.customize, "server") {
390            let sig = get_sig("r#async::TtrpcContext");
391            def_async_fn(w, &sig, cb);
392        } else {
393            let sig = get_sig("TtrpcContext");
394            w.def_fn(&sig, cb);
395        }
396    }
397
398    fn write_bind(&self, w: &mut CodeWriter) {
399        let method_handler_name = "::ttrpc::MethodHandler";
400
401        let s = format!(
402            "methods.insert(\"/{}.{}/{}\".to_string(),
403                    Box::new({}Method{{service: service.clone()}}) as Box<dyn {} + Send + Sync>);",
404            self.package_name,
405            self.service_name,
406            self.proto.get_name(),
407            self.struct_name(),
408            method_handler_name,
409        );
410        w.write_line(&s);
411    }
412
413    fn write_async_bind(&self, w: &mut CodeWriter) {
414        let s = if matches!(self.method_type().0, MethodType::Unary) {
415            format!(
416                "methods.insert(\"{}\".to_string(),
417                    Box::new({}Method{{service: service.clone()}}) as {});",
418                self.proto.get_name(),
419                self.struct_name(),
420                "Box<dyn ::ttrpc::r#async::MethodHandler + Send + Sync>"
421            )
422        } else {
423            format!(
424                "streams.insert(\"{}\".to_string(),
425                    Arc::new({}Method{{service: service.clone()}}) as {});",
426                self.proto.get_name(),
427                self.struct_name(),
428                "Arc<dyn ::ttrpc::r#async::StreamHandler + Send + Sync>"
429            )
430        };
431        w.write_line(&s);
432    }
433}
434
435struct ServiceGen<'a> {
436    proto: &'a ServiceDescriptorProto,
437    methods: Vec<MethodGen<'a>>,
438    customize: &'a Customize,
439    package_name: String,
440}
441
442impl<'a> ServiceGen<'a> {
443    fn new(
444        proto: &'a ServiceDescriptorProto,
445        file: &FileDescriptorProto,
446        root_scope: &'a RootScope,
447        customize: &'a Customize,
448    ) -> ServiceGen<'a> {
449        let methods = proto
450            .get_method()
451            .iter()
452            .map(|m| {
453                MethodGen::new(
454                    m,
455                    file.get_package().to_string(),
456                    util::to_camel_case(proto.get_name()),
457                    root_scope,
458                    customize,
459                )
460            })
461            .collect();
462
463        ServiceGen {
464            proto,
465            methods,
466            customize,
467            package_name: file.get_package().to_string(),
468        }
469    }
470
471    fn service_name(&self) -> String {
472        util::to_camel_case(self.proto.get_name())
473    }
474
475    fn service_path(&self) -> String {
476        format!("{}.{}", self.package_name, self.service_name())
477    }
478
479    fn client_name(&self) -> String {
480        format!("{}Client", self.service_name())
481    }
482
483    fn has_stream_method(&self) -> bool {
484        self.methods
485            .iter()
486            .any(|method| !matches!(method.method_type().0, MethodType::Unary))
487    }
488
489    fn write_client(&self, w: &mut CodeWriter) {
490        if async_on(self.customize, "client") {
491            self.write_async_client(w)
492        } else {
493            self.write_sync_client(w)
494        }
495    }
496
497    fn write_sync_client(&self, w: &mut CodeWriter) {
498        w.write_line("#[derive(Clone)]");
499        w.pub_struct(self.client_name(), |w| {
500            w.field_decl("client", "::ttrpc::Client");
501        });
502
503        w.write_line("");
504
505        w.impl_self_block(self.client_name(), |w| {
506            w.pub_fn("new(client: ::ttrpc::Client) -> Self", |w| {
507                w.expr_block(&self.client_name(), |w| {
508                    w.write_line("client,");
509                });
510            });
511
512            for method in &self.methods {
513                w.write_line("");
514                method.write_client(w);
515            }
516        });
517    }
518
519    fn write_async_client(&self, w: &mut CodeWriter) {
520        w.write_line("#[derive(Clone)]");
521        w.pub_struct(self.client_name(), |w| {
522            w.field_decl("client", "::ttrpc::r#async::Client");
523        });
524
525        w.write_line("");
526
527        w.impl_self_block(self.client_name(), |w| {
528            w.pub_fn("new(client: ::ttrpc::r#async::Client) -> Self", |w| {
529                w.expr_block(&self.client_name(), |w| {
530                    w.write_line("client,");
531                });
532            });
533
534            for method in &self.methods {
535                w.write_line("");
536                method.write_async_client(w);
537            }
538        });
539    }
540
541    fn write_server(&self, w: &mut CodeWriter) {
542        let mut trait_name = self.service_name();
543        if async_on(self.customize, "server") {
544            w.write_line("#[async_trait]");
545            trait_name = format!("{}: Sync", &self.service_name());
546        }
547
548        w.pub_trait(&trait_name, |w| {
549            for method in &self.methods {
550                method.write_service(w);
551            }
552        });
553
554        w.write_line("");
555        if async_on(self.customize, "server") {
556            self.write_async_server_create(w);
557        } else {
558            self.write_sync_server_create(w);
559        }
560    }
561
562    fn write_sync_server_create(&self, w: &mut CodeWriter) {
563        let method_handler_name = "::ttrpc::MethodHandler";
564        let s = format!(
565            "create_{}(service: Arc<dyn {} + Send + Sync>) -> HashMap<String, Box<dyn {} + Send + Sync>>",
566            to_snake_case(&self.service_name()),
567            self.service_name(),
568            method_handler_name,
569        );
570
571        w.pub_fn(&s, |w| {
572            w.write_line("let mut methods = HashMap::new();");
573            for method in &self.methods[0..self.methods.len()] {
574                w.write_line("");
575                method.write_bind(w);
576            }
577            w.write_line("");
578            w.write_line("methods");
579        });
580    }
581
582    fn write_async_server_create(&self, w: &mut CodeWriter) {
583        let s = format!(
584            "create_{}(service: Arc<dyn {} + Send + Sync>) -> HashMap<String, {}>",
585            to_snake_case(&self.service_name()),
586            self.service_name(),
587            "::ttrpc::r#async::Service"
588        );
589
590        let has_stream_method = self.has_stream_method();
591        w.pub_fn(&s, |w| {
592            w.write_line("let mut ret = HashMap::new();");
593            w.write_line("let mut methods = HashMap::new();");
594            if has_stream_method {
595                w.write_line("let mut streams = HashMap::new();");
596            } else {
597                w.write_line("let streams = HashMap::new();");
598            }
599            for method in &self.methods[0..self.methods.len()] {
600                w.write_line("");
601                method.write_async_bind(w);
602            }
603            w.write_line("");
604            w.write_line(format!(
605                "ret.insert(\"{}\".to_string(), {});",
606                self.service_path(),
607                "::ttrpc::r#async::Service{ methods, streams }"
608            ));
609            w.write_line("ret");
610        });
611    }
612
613    fn write_method_handlers(&self, w: &mut CodeWriter) {
614        for (i, method) in self.methods.iter().enumerate() {
615            if i != 0 {
616                w.write_line("");
617            }
618
619            method.write_handler(w);
620        }
621    }
622
623    fn write(&self, w: &mut CodeWriter) {
624        self.write_client(w);
625        w.write_line("");
626        self.write_method_handlers(w);
627        w.write_line("");
628        self.write_server(w);
629    }
630}
631
632pub fn write_generated_by(w: &mut CodeWriter, pkg: &str, version: &str) {
633    w.write_line(format!(
634        "// This file is generated by {pkg} {version}. Do not edit",
635        pkg = pkg,
636        version = version
637    ));
638    write_generated_common(w);
639}
640
641fn write_generated_common(w: &mut CodeWriter) {
642    // https://secure.phabricator.com/T784
643    w.write_line("// @generated");
644
645    w.write_line("");
646    w.write_line("#![cfg_attr(rustfmt, rustfmt_skip)]");
647    w.write_line("#![allow(unknown_lints)]");
648    w.write_line("#![allow(clipto_camel_casepy)]");
649    w.write_line("#![allow(dead_code)]");
650    w.write_line("#![allow(missing_docs)]");
651    w.write_line("#![allow(non_camel_case_types)]");
652    w.write_line("#![allow(non_snake_case)]");
653    w.write_line("#![allow(non_upper_case_globals)]");
654    w.write_line("#![allow(trivial_casts)]");
655    w.write_line("#![allow(unsafe_code)]");
656    w.write_line("#![allow(unused_imports)]");
657    w.write_line("#![allow(unused_results)]");
658    w.write_line("#![allow(clippy::all)]");
659}
660
661fn gen_file(
662    file: &FileDescriptorProto,
663    root_scope: &RootScope,
664    customize: &Customize,
665) -> Option<GenResult> {
666    if file.get_service().is_empty() {
667        return None;
668    }
669
670    let base = protobuf::descriptorx::proto_path_to_rust_mod(file.get_name());
671
672    let mut v = Vec::new();
673    {
674        let mut w = CodeWriter::new(&mut v);
675
676        write_generated_by(&mut w, "ttrpc-compiler", env!("CARGO_PKG_VERSION"));
677
678        w.write_line("use protobuf::{CodedInputStream, CodedOutputStream, Message};");
679        w.write_line("use std::collections::HashMap;");
680        w.write_line("use std::sync::Arc;");
681        if customize.async_all || customize.async_client || customize.async_server {
682            w.write_line("use async_trait::async_trait;");
683        }
684
685        for service in file.get_service() {
686            w.write_line("");
687            ServiceGen::new(service, file, root_scope, customize).write(&mut w);
688        }
689    }
690
691    Some(GenResult {
692        name: base + "_ttrpc.rs",
693        content: v,
694    })
695}
696
697pub fn gen(
698    file_descriptors: &[FileDescriptorProto],
699    files_to_generate: &[String],
700    customize: &Customize,
701) -> Vec<GenResult> {
702    let files_map: HashMap<&str, &FileDescriptorProto> =
703        file_descriptors.iter().map(|f| (f.get_name(), f)).collect();
704
705    let root_scope = RootScope { file_descriptors };
706
707    let mut results = Vec::new();
708
709    for file_name in files_to_generate {
710        let file = files_map[&file_name[..]];
711
712        if file.get_service().is_empty() {
713            continue;
714        }
715
716        results.extend(gen_file(file, &root_scope, customize).into_iter());
717    }
718
719    results
720}
721
722pub fn gen_and_write(
723    file_descriptors: &[FileDescriptorProto],
724    files_to_generate: &[String],
725    out_dir: &Path,
726    customize: &Customize,
727) -> io::Result<()> {
728    let results = gen(file_descriptors, files_to_generate, customize);
729
730    if customize.gen_mod {
731        let file_path = out_dir.join("mod.rs");
732        let mut set = HashSet::new();
733        //if mod file exists
734        if let Ok(file) = File::open(&file_path) {
735            let reader = io::BufReader::new(file);
736            reader.lines().for_each(|line| {
737                let _ = line.map(|r| set.insert(r));
738            });
739        }
740        let mut file_write = fs::OpenOptions::new()
741            .create(true)
742            .write(true)
743            .truncate(true)
744            .open(&file_path)?;
745        for r in &results {
746            let prefix_name: Vec<&str> = r.name.split('.').collect();
747            set.insert(format!("pub mod {};", prefix_name[0]));
748        }
749        for item in &set {
750            writeln!(file_write, "{}", item)?;
751        }
752        file_write.flush()?;
753    }
754
755    for r in &results {
756        let mut file_path = out_dir.to_owned();
757        file_path.push(&r.name);
758        let mut file_writer = File::create(&file_path)?;
759        file_writer.write_all(&r.content)?;
760        file_writer.flush()?;
761    }
762
763    Ok(())
764}
765
766pub fn protoc_gen_grpc_rust_main() {
767    plugin_main(|file_descriptors, files_to_generate| {
768        gen(
769            file_descriptors,
770            files_to_generate,
771            &Customize {
772                ..Default::default()
773            },
774        )
775    });
776}
777
778fn plugin_main<F>(gen: F)
779where
780    F: Fn(&[FileDescriptorProto], &[String]) -> Vec<GenResult>,
781{
782    plugin_main_2(|r| gen(r.file_descriptors, r.files_to_generate))
783}
784
785fn plugin_main_2<F>(gen: F)
786where
787    F: Fn(&GenRequest) -> Vec<GenResult>,
788{
789    let req = CodeGeneratorRequest::parse_from_reader(&mut stdin()).unwrap();
790    let result = gen(&GenRequest {
791        file_descriptors: req.get_proto_file(),
792        files_to_generate: req.get_file_to_generate(),
793        parameter: req.get_parameter(),
794    });
795    let mut resp = CodeGeneratorResponse::new();
796    resp.set_supported_features(CodeGeneratorResponse_Feature::FEATURE_PROTO3_OPTIONAL as u64);
797    resp.set_file(
798        result
799            .iter()
800            .map(|file| {
801                let mut r = CodeGeneratorResponse_File::new();
802                r.set_name(file.name.to_string());
803                r.set_content(
804                    std::str::from_utf8(file.content.as_ref())
805                        .unwrap()
806                        .to_string(),
807                );
808                r
809            })
810            .collect(),
811    );
812    resp.write_to_writer(&mut stdout()).unwrap();
813}