// volo_thrift/message_wrapper.rs

1use pilota::thrift::{ApplicationException, Message, TAsyncInputProtocol, ThriftException};
2use volo::FastStr;
3
4use crate::{
5    EntryMessage,
6    context::{ClientContext, ServerContext, ThriftContext},
7    protocol::{
8        TInputProtocol, TLengthProtocol, TMessageIdentifier, TMessageType, TOutputProtocol,
9    },
10};
11
12#[derive(Debug)]
13pub struct MessageMeta {
14    pub msg_type: TMessageType,
15    pub(crate) method: FastStr,
16    pub(crate) seq_id: i32,
17}
18
19#[derive(Debug)]
20pub struct ThriftMessage<M> {
21    pub data: Result<M, ApplicationException>,
22    pub meta: MessageMeta,
23}
24
25pub(crate) struct DummyMessage;
26
27impl EntryMessage for DummyMessage {
28    #[inline]
29    fn encode<T: TOutputProtocol>(&self, _protocol: &mut T) -> Result<(), ThriftException> {
30        unreachable!()
31    }
32
33    #[inline]
34    fn decode<T: TInputProtocol>(
35        _protocol: &mut T,
36        _msg_ident: &TMessageIdentifier,
37    ) -> Result<Self, ThriftException> {
38        unreachable!()
39    }
40
41    #[inline]
42    async fn decode_async<T: TAsyncInputProtocol>(
43        _protocol: &mut T,
44        _msg_ident: &TMessageIdentifier,
45    ) -> Result<Self, ThriftException> {
46        unreachable!()
47    }
48
49    fn size<T: TLengthProtocol>(&self, _protocol: &mut T) -> usize {
50        unreachable!()
51    }
52}
53
54impl<M> ThriftMessage<M> {
55    #[inline]
56    pub fn mk_client_msg(cx: &ClientContext, msg: M) -> Self {
57        let meta = MessageMeta {
58            msg_type: cx.message_type,
59            method: cx.rpc_info.method().clone(),
60            seq_id: cx.seq_id,
61        };
62        Self {
63            data: Ok(msg),
64            meta,
65        }
66    }
67
68    /// Server response message can only be an Ok(msg) or Err(ApplicationException).
69    #[inline]
70    pub fn mk_server_resp(cx: &ServerContext, msg: Result<M, ApplicationException>) -> Self {
71        let meta = MessageMeta {
72            msg_type: match msg {
73                Ok(_) => TMessageType::Reply,
74                Err(_) => TMessageType::Exception,
75            },
76            method: cx.rpc_info.method().clone(),
77            seq_id: cx.seq_id.unwrap_or(0),
78        };
79        Self { data: msg, meta }
80    }
81}
82
83impl<U> ThriftMessage<U>
84where
85    U: EntryMessage,
86{
87    #[inline]
88    pub(crate) fn size<T: TLengthProtocol>(&self, protocol: &mut T) -> usize {
89        let ident = TMessageIdentifier::new(
90            self.meta.method.clone(),
91            self.meta.msg_type,
92            self.meta.seq_id,
93        );
94
95        match &self.data {
96            Ok(inner) => {
97                protocol.message_begin_len(&ident)
98                    + inner.size(protocol)
99                    + protocol.message_end_len()
100            }
101            Err(inner) => {
102                protocol.message_begin_len(&ident)
103                    + inner.size(protocol)
104                    + protocol.message_end_len()
105            }
106        }
107    }
108}
109
110impl<U> ThriftMessage<U>
111where
112    U: EntryMessage + Send,
113{
114    #[inline]
115    pub(crate) fn encode<T: TOutputProtocol>(
116        &self,
117        protocol: &mut T,
118    ) -> Result<(), ThriftException> {
119        let ident = TMessageIdentifier::new(
120            self.meta.method.clone(),
121            self.meta.msg_type,
122            self.meta.seq_id,
123        );
124        match &self.data {
125            Ok(v) => {
126                protocol.write_message_begin(&ident)?;
127                v.encode(protocol)?;
128            }
129            Err(e) => {
130                protocol.write_message_begin(&ident)?;
131                e.encode(protocol)?;
132            }
133        }
134        protocol.write_message_end()?;
135        Ok(())
136    }
137
138    #[inline]
139    pub(crate) fn decode<Cx: ThriftContext, T: TInputProtocol>(
140        protocol: &mut T,
141        cx: &mut Cx,
142    ) -> Result<Self, ThriftException> {
143        let msg_ident = protocol.read_message_begin()?;
144
145        cx.handle_decoded_msg_ident(&msg_ident);
146
147        let res = match msg_ident.message_type {
148            TMessageType::Exception => Err(ApplicationException::decode(protocol)?),
149            _ => Ok(U::decode(protocol, &msg_ident)?),
150        };
151        protocol.read_message_end()?;
152        Ok(ThriftMessage {
153            data: res,
154            meta: MessageMeta {
155                msg_type: msg_ident.message_type,
156                method: msg_ident.name,
157                seq_id: msg_ident.sequence_number,
158            },
159        })
160    }
161
162    #[inline]
163    pub(crate) async fn decode_async<Cx: ThriftContext + Send, T: TAsyncInputProtocol>(
164        protocol: &mut T,
165        cx: &mut Cx,
166    ) -> Result<Self, ThriftException> {
167        let msg_ident = protocol.read_message_begin().await?;
168
169        cx.handle_decoded_msg_ident(&msg_ident);
170
171        let res = match msg_ident.message_type {
172            TMessageType::Exception => Err(ApplicationException::decode_async(protocol).await?),
173            _ => Ok(U::decode_async(protocol, &msg_ident).await?),
174        };
175        protocol.read_message_end().await?;
176        Ok(ThriftMessage {
177            data: res,
178            meta: MessageMeta {
179                msg_type: msg_ident.message_type,
180                method: msg_ident.name,
181                seq_id: msg_ident.sequence_number,
182            },
183        })
184    }
185}