wireman_core/descriptor/request.rs

1use super::{metadata::Metadata, DynamicMessage};
2use crate::client::codec::DynamicCodec;
3use crate::{
4    error::{Error, FROM_UTF8},
5    Result,
6};
7use http::{uri::PathAndQuery, Uri};
8use prost_reflect::{MessageDescriptor, MethodDescriptor};
9use serde::{ser::SerializeStruct, Serialize, Serializer};
10use std::str::FromStr;
11use tonic::{
12    metadata::{Ascii, MetadataKey, MetadataValue},
13    Request,
14};
15
16/// Holds all the necessary data for a `gRPC` request, including
17/// the message, method descriptor, and optional metadata.
18#[derive(Debug, Clone)]
19pub struct RequestMessage {
20    /// The `gRPC` message.
21    message: DynamicMessage,
22    /// The `gRPC` method
23    method_desc: MethodDescriptor,
24    /// The requests metadata.
25    metadata: Option<Metadata>,
26    /// The host address.
27    address: String,
28}
29
30impl RequestMessage {
31    /// Create a new `RequestMessage` with the provided message
32    /// descriptor and method descriptor.
33    #[must_use]
34    pub fn new(message_desc: MessageDescriptor, method_desc: MethodDescriptor) -> Self {
35        let message = DynamicMessage::new(message_desc);
36        Self {
37            message,
38            method_desc,
39            metadata: None,
40            address: String::new(),
41        }
42    }
43
44    /// Get the name of the message.
45    #[must_use]
46    pub fn message_name(&self) -> String {
47        self.message_descriptor().name().to_string()
48    }
49
50    /// Get the message descriptor associated with the `RequestMessage`.
51    #[must_use]
52    pub fn message_descriptor(&self) -> MessageDescriptor {
53        self.message.descriptor()
54    }
55
56    /// Get the method descriptor associated with the `RequestMessage`.
57    #[must_use]
58    pub fn method_descriptor(&self) -> MethodDescriptor {
59        self.method_desc.clone()
60    }
61
62    /// Gets a reference to the message.
63    #[must_use]
64    pub fn message(&self) -> &DynamicMessage {
65        &self.message
66    }
67
68    /// Gets a mutable reference to the message.
69    #[must_use]
70    pub fn message_mut(&mut self) -> &mut DynamicMessage {
71        &mut self.message
72    }
73
74    /// Set a new message for the request.
75    pub fn set_message(&mut self, message: DynamicMessage) {
76        self.message = message;
77    }
78
79    /// Get the host address.
80    #[must_use]
81    pub fn address(&self) -> &str {
82        &self.address
83    }
84
85    /// Get the host address as uri.
86    ///
87    /// # Errors
88    /// - Failed to parse address to uri.
89    pub fn uri(&self) -> Result<Uri> {
90        Uri::try_from(self.address())
91            .map_err(|_| Error::Internal(String::from("Failed to parse address")))
92    }
93
94    /// Sets the host address.
95    pub fn set_address(&mut self, address: &str) {
96        self.address = address.to_string();
97    }
98
99    /// Get the metadata associated with the request.
100    #[must_use]
101    pub fn metadata(&self) -> &Option<Metadata> {
102        &self.metadata
103    }
104
105    /// Insert metadata into the request.
106    ///
107    /// # Errors
108    ///
109    /// - Failed to parse metadata value/key to ascii
110    pub fn insert_metadata(&mut self, key: &str, val: &str) -> Result<()> {
111        let key: MetadataKey<Ascii> = key.parse().map_err(|_| Error::ParseToAsciiError)?;
112        let val: MetadataValue<Ascii> = val.parse().map_err(|_| Error::ParseToAsciiError)?;
113        let map = self.metadata.get_or_insert(Metadata::new());
114        map.insert(key, val);
115        Ok(())
116    }
117
118    /// Get the URI path for `gRPC` calls based on the method descriptor.
119    ///
120    /// # Panics
121    ///
122    /// Panics if constructing the path and query from a string fails.
123    #[must_use]
124    pub fn path(&self) -> PathAndQuery {
125        let path = format!(
126            "/{}/{}",
127            self.method_desc.parent_service().full_name(),
128            self.method_desc.name()
129        );
130        PathAndQuery::from_str(&path).unwrap()
131    }
132
133    /// Return the dynamic codec based on the method descriptor.
134    #[must_use]
135    pub fn codec(&self) -> DynamicCodec {
136        DynamicCodec::new(self.method_descriptor())
137    }
138
139    /// Serialize the `RequestMessage` to a JSON string.
140    ///
141    /// # Errors
142    ///
143    /// Returns an `Error` if serialization to a JSON string fails.
144    pub fn to_json(&self) -> Result<String> {
145        let mut s = serde_json::Serializer::new(Vec::new());
146        self.serialize(&mut s)
147            .map_err(|_| Error::Internal(String::from("failed to serialize message")))?;
148        String::from_utf8(s.into_inner()).map_err(|_| Error::Internal(FROM_UTF8.to_string()))
149    }
150}
151
152impl From<RequestMessage> for Request<RequestMessage> {
153    fn from(value: RequestMessage) -> Self {
154        let metadata = value.metadata().clone();
155        let mut req = Request::new(value);
156        if let Some(meta) = metadata {
157            *req.metadata_mut() = meta.inner;
158        }
159        req
160    }
161}
162
163impl Serialize for RequestMessage {
164    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
165    where
166        S: Serializer,
167    {
168        let mut state = serializer.serialize_struct("RequestMessage", 3)?;
169        state.serialize_field("message", &self.message)?;
170        if let Some(metadata) = &self.metadata {
171            state.serialize_field("metadata", &metadata)?;
172        }
173        state.serialize_field("address", &self.address)?;
174        state.end()
175    }
176}
177
#[cfg(test)]
mod test {
    use crate::ProtoDescriptor;

    use super::*;

    /// Builds a `RequestMessage` for `method` of `proto.TestService`, using
    /// descriptors compiled from `test_files/test.proto`.
    fn load_test_message(method: &str) -> RequestMessage {
        // The test files
        let files = vec!["test_files/test.proto"];
        let includes = vec!["."];

        // Generate the descriptor
        let desc = ProtoDescriptor::new(includes, files).unwrap();

        // Get the method and message
        let method = desc
            .get_method_by_name("proto.TestService", method)
            .unwrap();
        let request = method.input();
        RequestMessage::new(request, method)
    }

    #[test]
    fn test_into_request() {
        // given
        let mut given_message = load_test_message("Simple");
        given_message
            .insert_metadata("metadata-key", "metadata-value")
            .unwrap();
        // Both getters already return owned descriptors; no extra clone needed.
        let method_descriptor = given_message.method_descriptor();
        let message_descriptor = given_message.message_descriptor();

        // when
        let given_req: Request<RequestMessage> = given_message.into();

        // then
        let metadata = given_req.metadata();
        assert!(metadata.contains_key("metadata-key"));
        assert_eq!(metadata.get("metadata-key").unwrap(), "metadata-value");
        assert_eq!(given_req.get_ref().method_descriptor(), method_descriptor);
        assert_eq!(given_req.get_ref().message_descriptor(), message_descriptor);
    }

    #[test]
    fn test_to_json() {
        // given
        let mut given_message = load_test_message("Simple");
        given_message.insert_metadata("key", "value").unwrap();
        given_message.set_address("localhost:50051");

        // when
        let json = given_message.to_json().unwrap();

        // then
        let expected_json = "{\"message\":{\"number\":0},\"metadata\":{\"key\":\"value\"},\"address\":\"localhost:50051\"}";
        assert_eq!(json, expected_json);
    }

    #[test]
    fn test_to_json_without_metadata() {
        // given
        let given_message = load_test_message("Simple");

        // when
        let json = given_message.to_json().unwrap();

        // then: the absent metadata field is omitted entirely
        assert_eq!(json, "{\"message\":{\"number\":0},\"address\":\"\"}");
    }

    #[test]
    fn test_path() {
        // given
        let given_message = load_test_message("Simple");

        // when / then
        assert_eq!(given_message.path().as_str(), "/proto.TestService/Simple");
    }

    #[test]
    fn test_insert_metadata_rejects_invalid_key() {
        // given
        let mut given_message = load_test_message("Simple");

        // when
        let result = given_message.insert_metadata("invalid key", "value");

        // then: the error is reported and no metadata map is created
        assert!(result.is_err());
        assert!(given_message.metadata().is_none());
    }
}