use faststr::FastStr;
use itertools::Itertools;
use pilota_build::{
codegen::thrift::DecodeHelper, db::RirDatabase, rir, rir::Method, tags::RustWrapperArc,
CodegenBackend, Context, DefId, IdentName, Symbol, ThriftBackend,
};
use quote::format_ident;
/// Volo's Thrift code-generation backend.
///
/// Struct, enum and newtype codegen is delegated to the wrapped pilota-build
/// [`ThriftBackend`]; service codegen is Volo-specific. For services it emits the
/// server type, the client types and the request/response envelope enums.
#[derive(Clone)]
pub struct VoloThriftBackend {
    // Underlying pilota-build Thrift backend; used for non-service items and for
    // rendering Rust types/paths.
    inner: ThriftBackend,
}
impl VoloThriftBackend {
    /// Emits the request/response envelope enums for the service `def_id`, together
    /// with their `::volo_thrift::EntryMessage` impls, into `stream`.
    ///
    /// Four enums are generated: `{Service}RequestRecv`, `{Service}RequestSend`,
    /// `{Service}ResponseRecv` and `{Service}ResponseSend`.
    ///
    /// - **Variants:** each enum has one variant per service method. The variant is
    ///   named after the method's Rust name in UpperCamelCase and wraps the args/result
    ///   type resolved by [`Self::method_args_path`] / [`Self::method_result_path`].
    /// - **Decoding:** the generated `decode`/`decode_async` dispatch on the wire
    ///   method name `msg_ident.name`. A name that matches no method yields a
    ///   `DecodeErrorKind::UnknownMethod` error.
    fn codegen_service_anonymous_type(&self, stream: &mut String, def_id: DefId) {
        let service_name = self.cx().rust_name(def_id);
        let methods = self.cx().service_methods(def_id);
        // The `join_multi_strs!` invocations below refer to these locals by name
        // (e.g. `|methods_names, variant_names| -> ...`), so keep the names in sync.
        // Method names exactly as written in the IDL; matched against `msg_ident.name`.
        let methods_names = methods.iter().map(|m| &**m.name).collect::<Vec<_>>();
        let variant_names = methods
            .iter()
            .map(|m| self.cx().rust_name(m.def_id).0.upper_camel_ident())
            .collect::<Vec<_>>();
        let args_recv_names = methods
            .iter()
            .map(|m| self.method_args_path(&service_name, m, false))
            .collect_vec();
        let args_send_names = methods
            .iter()
            .map(|m| self.method_args_path(&service_name, m, true))
            .collect_vec();
        let result_recv_names = methods
            .iter()
            .map(|m| self.method_result_path(&service_name, m, true))
            .collect_vec();
        let result_send_names = methods
            .iter()
            .map(|m| self.method_result_path(&service_name, m, false))
            .collect_vec();
        let req_recv_name = format!("{service_name}RequestRecv");
        let req_send_name = format!("{service_name}RequestSend");
        let res_recv_name = format!("{service_name}ResponseRecv");
        let res_send_name = format!("{service_name}ResponseSend");

        // The encode/decode/size bodies are identical for all four envelope enums
        // (only the enum name differs), so they are rendered once and reused.
        // `is_async` selects between the body of `decode` and `decode_async`.
        let mk_decode = |is_async: bool| {
            let helper = DecodeHelper::new(is_async);
            let decode_variants = helper.codegen_item_decode();
            let match_methods = crate::join_multi_strs!("", |methods_names, variant_names| -> "\"{methods_names}\" => {{ Self::{variant_names}({decode_variants}) }},");
            format! {
                r#"Ok(match &*msg_ident.name {{
{match_methods}
_ => {{
return Err(::pilota::thrift::DecodeError::new(::pilota::thrift::DecodeErrorKind::UnknownMethod, format!("unknown method {{}}", msg_ident.name)));
}},
}})"#
            }
        };
        let decode = mk_decode(false);
        let decode_async = mk_decode(true);
        let match_encode = crate::join_multi_strs!(",", |variant_names| -> "Self::{variant_names}(value) => {{::pilota::thrift::Message::encode(value, protocol).map_err(|err| err.into())}}");
        let match_size = crate::join_multi_strs!(",", |variant_names| -> "Self::{variant_names}(value) => {{::volo_thrift::Message::size(value, protocol)}}");

        let req_impl = format! {
            r#"#[::async_trait::async_trait]
impl ::volo_thrift::EntryMessage for {req_recv_name} {{
fn encode<T: ::pilota::thrift::TOutputProtocol>(&self, protocol: &mut T) -> ::core::result::Result<(), ::pilota::thrift::EncodeError> {{
match self {{
{match_encode}
}}
}}
fn decode<T: ::pilota::thrift::TInputProtocol>(protocol: &mut T, msg_ident: &::pilota::thrift::TMessageIdentifier) -> ::core::result::Result<Self, ::pilota::thrift::DecodeError> {{
{decode}
}}
async fn decode_async<T: ::pilota::thrift::TAsyncInputProtocol>(
protocol: &mut T,
msg_ident: &::pilota::thrift::TMessageIdentifier
) -> ::core::result::Result<Self, ::pilota::thrift::DecodeError>
{{
{decode_async}
}}
fn size<T: ::pilota::thrift::TLengthProtocol>(&self, protocol: &mut T) -> usize {{
match self {{
{match_size}
}}
}}
}}
#[::async_trait::async_trait]
impl ::volo_thrift::EntryMessage for {req_send_name} {{
fn encode<T: ::pilota::thrift::TOutputProtocol>(&self, protocol: &mut T) -> ::core::result::Result<(), ::pilota::thrift::EncodeError> {{
match self {{
{match_encode}
}}
}}
fn decode<T: ::pilota::thrift::TInputProtocol>(protocol: &mut T, msg_ident: &::pilota::thrift::TMessageIdentifier) -> ::core::result::Result<Self, ::pilota::thrift::DecodeError> {{
{decode}
}}
async fn decode_async<T: ::pilota::thrift::TAsyncInputProtocol>(
protocol: &mut T,
msg_ident: &::pilota::thrift::TMessageIdentifier
) -> ::core::result::Result<Self, ::pilota::thrift::DecodeError>
{{
{decode_async}
}}
fn size<T: ::pilota::thrift::TLengthProtocol>(&self, protocol: &mut T) -> usize {{
match self {{
{match_size}
}}
}}
}}"#
        };
        let res_impl = format! {
            r#"#[::async_trait::async_trait]
impl ::volo_thrift::EntryMessage for {res_recv_name} {{
fn encode<T: ::pilota::thrift::TOutputProtocol>(&self, protocol: &mut T) -> ::core::result::Result<(), ::pilota::thrift::EncodeError> {{
match self {{
{match_encode}
}}
}}
fn decode<T: ::pilota::thrift::TInputProtocol>(protocol: &mut T, msg_ident: &::pilota::thrift::TMessageIdentifier) -> ::core::result::Result<Self, ::pilota::thrift::DecodeError> {{
{decode}
}}
async fn decode_async<T: ::pilota::thrift::TAsyncInputProtocol>(
protocol: &mut T,
msg_ident: &::pilota::thrift::TMessageIdentifier,
) -> ::core::result::Result<Self, ::pilota::thrift::DecodeError>
{{
{decode_async}
}}
fn size<T: ::pilota::thrift::TLengthProtocol>(&self, protocol: &mut T) -> usize {{
match self {{
{match_size}
}}
}}
}}
#[::async_trait::async_trait]
impl ::volo_thrift::EntryMessage for {res_send_name} {{
fn encode<T: ::pilota::thrift::TOutputProtocol>(&self, protocol: &mut T) -> ::core::result::Result<(), ::pilota::thrift::EncodeError> {{
match self {{
{match_encode}
}}
}}
fn decode<T: ::pilota::thrift::TInputProtocol>(protocol: &mut T, msg_ident: &::pilota::thrift::TMessageIdentifier) -> ::core::result::Result<Self, ::pilota::thrift::DecodeError> {{
{decode}
}}
async fn decode_async<T: ::pilota::thrift::TAsyncInputProtocol>(
protocol: &mut T,
msg_ident: &::pilota::thrift::TMessageIdentifier,
) -> ::core::result::Result<Self, ::pilota::thrift::DecodeError>
{{
{decode_async}
}}
fn size<T: ::pilota::thrift::TLengthProtocol>(&self, protocol: &mut T) -> usize {{
match self {{
{match_size}
}}
}}
}}"#
        };

        let req_recv_variants = crate::join_multi_strs!(
            ",",
            |variant_names, args_recv_names| -> "{variant_names}({args_recv_names})"
        );
        let req_send_variants = crate::join_multi_strs!(
            ",",
            |variant_names, args_send_names| -> "{variant_names}({args_send_names})"
        );
        let res_recv_variants = crate::join_multi_strs!(
            ",",
            |variant_names, result_recv_names| -> "{variant_names}({result_recv_names})"
        );
        let res_send_variants = crate::join_multi_strs!(
            ",",
            |variant_names, result_send_names| -> "{variant_names}({result_send_names})"
        );
        stream.push_str(&format! {
            r#"#[derive(Debug, Clone)]
pub enum {req_recv_name} {{
{req_recv_variants}
}}
#[derive(Debug, Clone)]
pub enum {req_send_name} {{
{req_send_variants}
}}
#[derive(Debug, Clone)]
pub enum {res_recv_name} {{
{res_recv_variants}
}}
#[derive(Debug, Clone)]
pub enum {res_send_name} {{
{res_send_variants}
}}
{req_impl}
{res_impl}"#
        });
    }

    /// Returns the path of the per-method anonymous type `{Service}{Method}{suffix}`.
    ///
    /// - **`MethodSource::Own`:** the type lives beside the current service and is
    ///   named with `service_name`.
    /// - **`MethodSource::Extend(def_id)`:** presumably a method inherited through
    ///   Thrift `extends`. The type is named after the target service `def_id`, and
    ///   the path is formed by replacing the last segment of that service's related
    ///   item path.
    ///
    /// # Panics
    /// Panics with "expected service" if an `Extend` source does not refer to a
    /// service item.
    fn method_ty_path(&self, service_name: &Symbol, method: &Method, suffix: &str) -> FastStr {
        let method_ident = self.cx().rust_name(method.def_id).0.upper_camel_ident();
        match method.source {
            rir::MethodSource::Extend(def_id) => {
                let item = self.cx().expect_item(def_id);
                if !matches!(&*item, rir::Item::Service(_)) {
                    panic!("expected service");
                }
                // Use the target service's generated Rust name. This matches the
                // `Own` branch, whose `service_name` comes from `rust_name` at every
                // call site. The raw IDL name may differ from the generated type name.
                let target_service_name = self.cx().rust_name(def_id);
                let ident = format!("{target_service_name}{method_ident}{suffix}");
                let service_path = self.cx().cur_related_item_path(def_id);
                let mut segments = service_path.split("::").collect_vec();
                // Swap the service's own name for the anonymous type's name.
                segments.pop();
                segments.push(ident.as_str());
                segments.join("::").into()
            }
            rir::MethodSource::Own => format!("{service_name}{method_ident}{suffix}").into(),
        }
    }

    /// Path of the method's args type: `…ArgsSend` when `is_client`, else `…ArgsRecv`.
    fn method_args_path(&self, service_name: &Symbol, method: &Method, is_client: bool) -> FastStr {
        if is_client {
            self.method_ty_path(service_name, method, "ArgsSend")
        } else {
            self.method_ty_path(service_name, method, "ArgsRecv")
        }
    }

    /// Path of the method's result type: `…ResultRecv` when `is_client`, else
    /// `…ResultSend`.
    fn method_result_path(
        &self,
        service_name: &Symbol,
        method: &Method,
        is_client: bool,
    ) -> FastStr {
        if is_client {
            self.method_ty_path(service_name, method, "ResultRecv")
        } else {
            self.method_ty_path(service_name, method, "ResultSend")
        }
    }
}
impl pilota_build::CodegenBackend for VoloThriftBackend {
    /// Struct codegen is delegated unchanged to the inner [`ThriftBackend`].
    fn codegen_struct_impl(&self, def_id: DefId, stream: &mut String, s: &rir::Message) {
        self.inner.codegen_struct_impl(def_id, stream, s)
    }

    /// Emits the Volo server and client types for the service `def_id`, followed by
    /// the service's request/response envelope enums.
    ///
    /// Generated items, with `{S}` the service's Rust name:
    /// - **`{S}Server<S>`:** wraps a handler implementing the `{S}` trait. Its
    ///   `::volo::service::Service` impl dispatches each `{S}RequestRecv` variant to
    ///   the matching handler method and wraps the outcome in `{S}ResponseSend`.
    /// - **Client types:** `Mk{S}GenericClient`, `{S}GenericClient<S>`,
    ///   `{S}OneShotClient<S>` and the boxed `{S}Client` alias. Each client type has
    ///   one async method per service method.
    /// - **`{S}ClientBuilder`:** its `new` returns a
    ///   `::volo_thrift::client::ClientBuilder` with the default codec, transport and
    ///   load-balancing type parameters.
    ///
    /// # Panics
    /// Panics if a method's declared exception type does not resolve to an enum item.
    fn codegen_service_impl(&self, def_id: DefId, stream: &mut String, _s: &rir::Service) {
        let service_name = self.cx().rust_name(def_id);
        let server_name = format!("{service_name}Server");
        let generic_client_name = format!("{service_name}GenericClient");
        let client_name = format!("{service_name}Client");
        let oneshot_client_name = format!("{service_name}OneShotClient");
        let client_builder_name = format!("{client_name}Builder");
        let req_send_name = format!("{service_name}RequestSend");
        let req_recv_name = format!("{service_name}RequestRecv");
        let res_send_name = format!("{service_name}ResponseSend");
        let res_recv_name = format!("{service_name}ResponseRecv");
        let all_methods = self.cx().service_methods(def_id);
        let mut client_methods = Vec::new();
        let mut oneshot_client_methods = Vec::new();
        all_methods.iter().for_each(|m| {
            let name = self.cx().rust_name(m.def_id);
            let resp_type = self.cx().codegen_item_ty(m.ret.kind.clone());
            // `, name: Ty` parameters appended after the receiver. Args tagged
            // `RustWrapperArc(true)` are taken as `::std::sync::Arc<Ty>`.
            let req_fields = m.args.iter().map(|a| {
                let name = self.cx().rust_name(a.def_id);
                let ty = self.cx().codegen_item_ty(a.ty.kind.clone());
                let mut ty = format!("{ty}");
                if let Some(RustWrapperArc(true)) = self.cx().tags(a.tags_id).as_ref().and_then(|tags| tags.get::<RustWrapperArc>()) {
                    ty = format!("::std::sync::Arc<{ty}>");
                }
                format!(", {name}: {ty}")
            }).join("");
            let method_name_str = &**m.name;
            let enum_variant = self.cx().rust_name(m.def_id).0.upper_camel_ident();
            let result_path = self.method_result_path(&service_name, m, true);
            let oneway = m.oneway;
            // A `None` response maps to `Ok(())` for oneway methods. Otherwise it is
            // emitted as `unreachable!()`.
            let none = if m.oneway {
                "None => { Ok(()) }"
            } else {
                "None => unreachable!()"
            };
            let req_field_names = m.args.iter().map(|a| self.cx().rust_name(a.def_id)).join(",");
            let anonymous_args_send_name = self.method_args_path(&service_name, m, true);
            let exception = if let Some(p) = &m.exceptions {
                self.cx().cur_related_item_path(p.did)
            } else {
                // Fully qualified, like every other std path in the generated code,
                // so a user item named `std` cannot shadow it.
                "::std::convert::Infallible".into()
            };
            // One match arm per exception variant. Each arm turns the result's
            // exception variant into `ResponseError::UserException`.
            let convert_exceptions = m.exceptions.iter().map(|p| {
                self.cx().expect_item(p.did)
            }).flat_map(|e| {
                match &*e {
                    rir::Item::Enum(e) => e.variants.iter().map(|v| {
                        let name = self.cx().rust_name(v.did);
                        format!("Some({res_recv_name}::{enum_variant}({result_path}::{name}(err))) => Err(::volo_thrift::error::ResponseError::UserException({exception}::{name}(err))),")
                    }).collect::<Vec<_>>(),
                    _ => panic!("expected the exception type of a service method to be an enum"),
                }
            }).join("");
            // Both client flavours return the used context to `CLIENT_CONTEXT_CACHE`
            // while the cache has spare capacity. The `.await?` returns early on error,
            // so the context is recycled only on the non-error path.
            client_methods.push(format! {
                r#"pub async fn {name}(&self {req_fields}) -> ::std::result::Result<{resp_type}, ::volo_thrift::error::ResponseError<{exception}>> {{
let req = {req_send_name}::{enum_variant}({anonymous_args_send_name} {{
{req_field_names}
}});
let mut cx = self.0.make_cx("{method_name_str}", {oneway});
#[allow(unreachable_patterns)]
let resp = match ::volo::service::Service::call(&self.0, &mut cx, req).await? {{
Some({res_recv_name}::{enum_variant}({result_path}::Ok(resp))) => Ok(resp),
{convert_exceptions}
{none},
_ => unreachable!()
}};
::volo_thrift::context::CLIENT_CONTEXT_CACHE.with(|cache| {{
let mut cache = cache.borrow_mut();
if cache.len() < cache.capacity() {{
cache.push(cx);
}}
}});
resp
}}"#
            });
            oneshot_client_methods.push(format! {
                r#"pub async fn {name}(self {req_fields}) -> ::std::result::Result<{resp_type}, ::volo_thrift::error::ResponseError<{exception}>> {{
let req = {req_send_name}::{enum_variant}({anonymous_args_send_name} {{
{req_field_names}
}});
let mut cx = self.0.make_cx("{method_name_str}", {oneway});
#[allow(unreachable_patterns)]
let resp = match ::volo::client::OneShotService::call(self.0, &mut cx, req).await? {{
Some({res_recv_name}::{enum_variant}({result_path}::Ok(resp))) => Ok(resp),
{convert_exceptions}
{none},
_ => unreachable!()
}};
::volo_thrift::context::CLIENT_CONTEXT_CACHE.with(|cache| {{
let mut cache = cache.borrow_mut();
if cache.len() < cache.capacity() {{
cache.push(cx);
}}
}});
resp
}}"#
            });
        });
        // `variants` and `user_handler` are referenced by name in the
        // `join_multi_strs!` template for `handler` below.
        let variants = all_methods
            .iter()
            .map(|m| self.cx().rust_name(m.def_id).0.upper_camel_ident())
            .collect_vec();
        // Server-side per-method body: call the user handler, then map `Ok` and each
        // declared exception into the method's `ResultSend` type. Any other error is
        // returned as the service error.
        let user_handler = all_methods
            .iter()
            .map(|m| {
                let name = self.cx().rust_name(m.def_id);
                let args = m
                    .args
                    .iter()
                    .map(|a| format!("args.{}", self.cx().rust_name(a.def_id)))
                    .join(",");
                let has_exception = m.exceptions.is_some();
                let method_result_path = self.method_result_path(&service_name, m, false);
                let exception: FastStr = if let Some(p) = &m.exceptions {
                    self.cx().cur_related_item_path(p.did)
                } else {
                    "::volo_thrift::error::DummyError".into()
                };
                let convert_exceptions = m
                    .exceptions
                    .iter()
                    .map(|p| self.cx().expect_item(p.did))
                    .flat_map(|e| match &*e {
                        rir::Item::Enum(e) => e
                            .variants
                            .iter()
                            .map(|v| {
                                let name = self.cx().rust_name(v.did);
                                format!(
                                    "Err(::volo_thrift::error::UserError::UserException({exception}::{name}(err))) => {method_result_path}::{name}(err),"
                                )
                            })
                            .collect::<Vec<_>>(),
                        _ => panic!("expected the exception type of a service method to be an enum"),
                    })
                    .join("");
                if has_exception {
                    format! {
                        r#"match self.inner.{name}({args}).await {{
Ok(resp) => {method_result_path}::Ok(resp),
{convert_exceptions}
Err(::volo_thrift::error::UserError::Other(err)) => return Err(err),
}}"#
                    }
                } else {
                    format! {
                        r#"match self.inner.{name}({args}).await {{
Ok(resp) => {method_result_path}::Ok(resp),
Err(err) => return Err(err),
}}"#
                    }
                }
            })
            .collect_vec();
        let mk_client_name = format_ident!("Mk{}", generic_client_name);
        let client_methods = client_methods.join("\n");
        let oneshot_client_methods = oneshot_client_methods.join("\n");
        let handler = crate::join_multi_strs!("", |variants, user_handler| -> r#"{req_recv_name}::{variants}(args) => Ok(
{res_send_name}::{variants}(
{user_handler}
)),"#);
        stream.push_str(&format! {
            r#"pub struct {server_name}<S> {{
inner: S, // handler
}}
pub struct {mk_client_name};
pub type {client_name} = {generic_client_name}<::volo::service::BoxCloneService<::volo_thrift::context::ClientContext, {req_send_name}, ::std::option::Option<{res_recv_name}>, ::volo_thrift::Error>>;
impl<S> ::volo::client::MkClient<::volo_thrift::Client<S>> for {mk_client_name} {{
type Target = {generic_client_name}<S>;
fn mk_client(&self, service: ::volo_thrift::Client<S>) -> Self::Target {{
{generic_client_name}(service)
}}
}}
#[derive(Clone)]
pub struct {generic_client_name}<S>(pub ::volo_thrift::Client<S>);
pub struct {oneshot_client_name}<S>(pub ::volo_thrift::Client<S>);
impl<S: ::volo::service::Service<::volo_thrift::context::ClientContext, {req_send_name}, Response = ::std::option::Option<{res_recv_name}>, Error = ::volo_thrift::Error> + Send + Sync + 'static> {generic_client_name}<S> {{
pub fn with_callopt<Opt: ::volo::client::Apply<::volo_thrift::context::ClientContext>>(self, opt: Opt) -> {oneshot_client_name}<::volo::client::WithOptService<S, Opt>> {{
{oneshot_client_name}(self.0.with_opt(opt))
}}
{client_methods}
}}
impl<S: ::volo::client::OneShotService<::volo_thrift::context::ClientContext, {req_send_name}, Response = ::std::option::Option<{res_recv_name}>, Error = ::volo_thrift::Error> + Send + Sync + 'static> {oneshot_client_name}<S> {{
{oneshot_client_methods}
}}
pub struct {client_builder_name} {{
}}
impl {client_builder_name} {{
pub fn new(service_name: impl AsRef<str>) -> ::volo_thrift::client::ClientBuilder<
::volo::layer::Identity,
::volo::layer::Identity,
{mk_client_name},
{req_send_name},
{res_recv_name},
::volo::net::dial::DefaultMakeTransport,
::volo_thrift::codec::default::DefaultMakeCodec<::volo_thrift::codec::default::ttheader::MakeTTHeaderCodec<::volo_thrift::codec::default::framed::MakeFramedCodec<::volo_thrift::codec::default::thrift::MakeThriftCodec>>>,
::volo::loadbalance::LbConfig<::volo::loadbalance::random::WeightedRandomBalance<()>, ::volo::discovery::DummyDiscover>,
>
{{
::volo_thrift::client::ClientBuilder::new(service_name, {mk_client_name})
}}
}}
impl<S> {server_name}<S> where S: {service_name} + ::core::marker::Send + ::core::marker::Sync + 'static {{
pub fn new(inner: S) -> ::volo_thrift::server::Server<Self, ::volo::layer::Identity, {req_recv_name}, ::volo_thrift::codec::default::DefaultMakeCodec<::volo_thrift::codec::default::ttheader::MakeTTHeaderCodec<::volo_thrift::codec::default::framed::MakeFramedCodec<::volo_thrift::codec::default::thrift::MakeThriftCodec>>>, ::volo_thrift::tracing::DefaultProvider> {{
::volo_thrift::server::Server::new(Self {{
inner,
}})
}}
}}
impl<T> ::volo::service::Service<::volo_thrift::context::ServerContext, {req_recv_name}> for {server_name}<T> where T: {service_name} + Send + Sync + 'static {{
type Response = {res_send_name};
type Error = ::anyhow::Error;
type Future<'cx> = impl ::std::future::Future<Output = ::std::result::Result<Self::Response, Self::Error>> + 'cx;
fn call<'cx, 's>(&'s self, _cx: &'cx mut ::volo_thrift::context::ServerContext, req: {req_recv_name}) -> Self::Future<'cx> where 's:'cx {{
async move {{
match req {{
{handler}
}}
}}
}}
}}"#
        });
        self.codegen_service_anonymous_type(stream, def_id);
    }

    /// Renders the handler-trait method signature for `method` as
    /// `async fn {name}(&self, {args}) -> ::core::result::Result<Ret, Err>;`.
    ///
    /// - **`Ret`:** wrapped in `::std::sync::Arc` when `method.ret` carries the tag
    ///   `RustWrapperArc(true)`.
    /// - **`Err`:** `::volo_thrift::error::UserError<Exc>` when the method declares
    ///   exceptions, and `::volo_thrift::AnyhowError` otherwise.
    fn codegen_service_method(&self, _service_def_id: DefId, method: &Method) -> String {
        let name = self.cx().rust_name(method.def_id);
        let ret_ty = self.inner.codegen_item_ty(method.ret.kind.clone());
        let mut ret_ty = format!("{ret_ty}");
        if let Some(RustWrapperArc(true)) = self
            .cx()
            .tags(method.ret.tags_id)
            .as_ref()
            .and_then(|tags| tags.get::<RustWrapperArc>())
        {
            ret_ty = format!("::std::sync::Arc<{ret_ty}>");
        }
        let args = method
            .args
            .iter()
            .map(|a| {
                let ty = self.inner.codegen_item_ty(a.ty.kind.clone());
                let ident = self.cx().rust_name(a.def_id);
                format!("{ident}: {ty}")
            })
            .join(",");
        let exception: FastStr = if let Some(p) = &method.exceptions {
            let exception = self.inner.cur_related_item_path(p.did);
            format! {"::volo_thrift::error::UserError<{exception}>" }.into()
        } else {
            "::volo_thrift::AnyhowError".into()
        };
        format!("async fn {name}(&self, {args}) -> ::core::result::Result<{ret_ty}, {exception}>;")
    }

    /// Enum codegen is delegated unchanged to the inner [`ThriftBackend`].
    fn codegen_enum_impl(&self, def_id: DefId, stream: &mut String, e: &rir::Enum) {
        self.inner.codegen_enum_impl(def_id, stream, e)
    }

    /// Newtype codegen is delegated unchanged to the inner [`ThriftBackend`].
    fn codegen_newtype_impl(&self, def_id: DefId, stream: &mut String, t: &rir::NewType) {
        self.inner.codegen_newtype_impl(def_id, stream, t)
    }

    /// The codegen context, borrowed from the inner [`ThriftBackend`].
    fn cx(&self) -> &Context {
        self.inner.cx()
    }
}
/// Backend factory handed to pilota-build; produces a [`VoloThriftBackend`].
pub struct MkThriftBackend;

impl pilota_build::MakeBackend for MkThriftBackend {
    type Target = VoloThriftBackend;

    /// Builds a pilota-build [`ThriftBackend`] over `context` and wraps it in a
    /// [`VoloThriftBackend`].
    fn make_backend(self, context: Context) -> Self::Target {
        let inner = ThriftBackend::new(context);
        VoloThriftBackend { inner }
    }
}