volo_thrift/
message_wrapper.rs1use pilota::thrift::{ApplicationException, Message, TAsyncInputProtocol, ThriftException};
2use volo::FastStr;
3
4use crate::{
5 EntryMessage,
6 context::{ClientContext, ServerContext, ThriftContext},
7 protocol::{
8 TInputProtocol, TLengthProtocol, TMessageIdentifier, TMessageType, TOutputProtocol,
9 },
10};
11
/// Header information of a Thrift message: the values that make up its
/// `TMessageIdentifier` when the message is encoded or sized.
#[derive(Debug)]
pub struct MessageMeta {
    /// Thrift message type. Server responses built by
    /// `ThriftMessage::mk_server_resp` use `Reply` or `Exception`.
    pub msg_type: TMessageType,
    /// Method name carried in the message header.
    pub(crate) method: FastStr,
    /// Sequence id carried in the message header.
    pub(crate) seq_id: i32,
}
18
/// A Thrift message: header metadata plus either a payload or an
/// application exception.
#[derive(Debug)]
pub struct ThriftMessage<M> {
    /// `Ok` carries the payload. `Err` carries an `ApplicationException`, which
    /// is what a message of type `Exception` decodes into.
    pub data: Result<M, ApplicationException>,
    /// Header information (message type, method name, sequence id).
    pub meta: MessageMeta,
}
24
25pub(crate) struct DummyMessage;
26
27impl EntryMessage for DummyMessage {
28 #[inline]
29 fn encode<T: TOutputProtocol>(&self, _protocol: &mut T) -> Result<(), ThriftException> {
30 unreachable!()
31 }
32
33 #[inline]
34 fn decode<T: TInputProtocol>(
35 _protocol: &mut T,
36 _msg_ident: &TMessageIdentifier,
37 ) -> Result<Self, ThriftException> {
38 unreachable!()
39 }
40
41 #[inline]
42 async fn decode_async<T: TAsyncInputProtocol>(
43 _protocol: &mut T,
44 _msg_ident: &TMessageIdentifier,
45 ) -> Result<Self, ThriftException> {
46 unreachable!()
47 }
48
49 fn size<T: TLengthProtocol>(&self, _protocol: &mut T) -> usize {
50 unreachable!()
51 }
52}
53
54impl<M> ThriftMessage<M> {
55 #[inline]
56 pub fn mk_client_msg(cx: &ClientContext, msg: M) -> Self {
57 let meta = MessageMeta {
58 msg_type: cx.message_type,
59 method: cx.rpc_info.method().clone(),
60 seq_id: cx.seq_id,
61 };
62 Self {
63 data: Ok(msg),
64 meta,
65 }
66 }
67
68 #[inline]
70 pub fn mk_server_resp(cx: &ServerContext, msg: Result<M, ApplicationException>) -> Self {
71 let meta = MessageMeta {
72 msg_type: match msg {
73 Ok(_) => TMessageType::Reply,
74 Err(_) => TMessageType::Exception,
75 },
76 method: cx.rpc_info.method().clone(),
77 seq_id: cx.seq_id.unwrap_or(0),
78 };
79 Self { data: msg, meta }
80 }
81}
82
83impl<U> ThriftMessage<U>
84where
85 U: EntryMessage,
86{
87 #[inline]
88 pub(crate) fn size<T: TLengthProtocol>(&self, protocol: &mut T) -> usize {
89 let ident = TMessageIdentifier::new(
90 self.meta.method.clone(),
91 self.meta.msg_type,
92 self.meta.seq_id,
93 );
94
95 match &self.data {
96 Ok(inner) => {
97 protocol.message_begin_len(&ident)
98 + inner.size(protocol)
99 + protocol.message_end_len()
100 }
101 Err(inner) => {
102 protocol.message_begin_len(&ident)
103 + inner.size(protocol)
104 + protocol.message_end_len()
105 }
106 }
107 }
108}
109
110impl<U> ThriftMessage<U>
111where
112 U: EntryMessage + Send,
113{
114 #[inline]
115 pub(crate) fn encode<T: TOutputProtocol>(
116 &self,
117 protocol: &mut T,
118 ) -> Result<(), ThriftException> {
119 let ident = TMessageIdentifier::new(
120 self.meta.method.clone(),
121 self.meta.msg_type,
122 self.meta.seq_id,
123 );
124 match &self.data {
125 Ok(v) => {
126 protocol.write_message_begin(&ident)?;
127 v.encode(protocol)?;
128 }
129 Err(e) => {
130 protocol.write_message_begin(&ident)?;
131 e.encode(protocol)?;
132 }
133 }
134 protocol.write_message_end()?;
135 Ok(())
136 }
137
138 #[inline]
139 pub(crate) fn decode<Cx: ThriftContext, T: TInputProtocol>(
140 protocol: &mut T,
141 cx: &mut Cx,
142 ) -> Result<Self, ThriftException> {
143 let msg_ident = protocol.read_message_begin()?;
144
145 cx.handle_decoded_msg_ident(&msg_ident);
146
147 let res = match msg_ident.message_type {
148 TMessageType::Exception => Err(ApplicationException::decode(protocol)?),
149 _ => Ok(U::decode(protocol, &msg_ident)?),
150 };
151 protocol.read_message_end()?;
152 Ok(ThriftMessage {
153 data: res,
154 meta: MessageMeta {
155 msg_type: msg_ident.message_type,
156 method: msg_ident.name,
157 seq_id: msg_ident.sequence_number,
158 },
159 })
160 }
161
162 #[inline]
163 pub(crate) async fn decode_async<Cx: ThriftContext + Send, T: TAsyncInputProtocol>(
164 protocol: &mut T,
165 cx: &mut Cx,
166 ) -> Result<Self, ThriftException> {
167 let msg_ident = protocol.read_message_begin().await?;
168
169 cx.handle_decoded_msg_ident(&msg_ident);
170
171 let res = match msg_ident.message_type {
172 TMessageType::Exception => Err(ApplicationException::decode_async(protocol).await?),
173 _ => Ok(U::decode_async(protocol, &msg_ident).await?),
174 };
175 protocol.read_message_end().await?;
176 Ok(ThriftMessage {
177 data: res,
178 meta: MessageMeta {
179 msg_type: msg_ident.message_type,
180 method: msg_ident.name,
181 seq_id: msg_ident.sequence_number,
182 },
183 })
184 }
185}