tonic_arcanyx_fork/
request.rs

1use crate::metadata::{MetadataMap, MetadataValue};
2#[cfg(all(feature = "transport", feature = "tls"))]
3use crate::transport::server::TlsConnectInfo;
4#[cfg(feature = "transport")]
5use crate::transport::{server::TcpConnectInfo, Certificate};
6use crate::Extensions;
7use futures_core::Stream;
8#[cfg(feature = "transport")]
9use std::sync::Arc;
10use std::{net::SocketAddr, time::Duration};
11
/// A gRPC request and metadata from an RPC call.
#[derive(Debug)]
pub struct Request<T> {
    /// Custom request metadata; converted to/from HTTP headers by `into_http` /
    /// `from_http_parts`.
    metadata: MetadataMap,
    /// The request payload: a single message, or a stream of messages for streaming calls.
    message: T,
    /// Typed extensions attached to the request (on the server side this is where
    /// connection info such as `TcpConnectInfo` is looked up).
    extensions: Extensions,
}
19
/// Trait implemented by RPC request types.
///
/// Types implementing this trait can be used as arguments to client RPC
/// methods without explicitly wrapping them into `tonic::Request`s. The purpose
/// is to make client calls slightly more convenient to write.
///
/// Tonic's code generation and blanket implementations handle this for you,
/// so it is not necessary to implement this trait directly.
///
/// # Example
///
/// Given the following gRPC method definition:
/// ```proto
/// rpc GetFeature(Point) returns (Feature) {}
/// ```
///
/// we can call `get_feature` in two equivalent ways:
/// ```rust
/// # pub struct Point {}
/// # pub struct Client {}
/// # impl Client {
/// #   fn get_feature(&self, r: impl tonic::IntoRequest<Point>) {}
/// # }
/// # let client = Client {};
/// use tonic::Request;
///
/// client.get_feature(Point {});
/// client.get_feature(Request::new(Point {}));
/// ```
pub trait IntoRequest<T>: sealed::Sealed {
    /// Wrap the input message `T` in a `tonic::Request`.
    ///
    /// A bare message gets fresh, empty metadata and extensions; an existing
    /// `Request<T>` is returned unchanged.
    fn into_request(self) -> Request<T>;
}
53
/// Trait implemented by RPC streaming request types.
///
/// Types implementing this trait can be used as arguments to client streaming
/// RPC methods without explicitly wrapping them into `tonic::Request`s. The
/// purpose is to make client calls slightly more convenient to write.
///
/// Tonic's code generation and blanket implementations handle this for you,
/// so it is not necessary to implement this trait directly.
///
/// # Example
///
/// Given the following gRPC service method definition:
/// ```proto
/// rpc RecordRoute(stream Point) returns (RouteSummary) {}
/// ```
///
/// we can call `record_route` in two equivalent ways:
///
/// ```rust
/// # #[derive(Clone)]
/// # pub struct Point {};
/// # pub struct Client {};
/// # impl Client {
/// #   fn record_route(&self, r: impl tonic::IntoStreamingRequest<Message = Point>) {}
/// # }
/// # let client = Client {};
/// use tonic::Request;
/// use futures_util::stream;
///
/// let messages = vec![Point {}, Point {}];
///
/// client.record_route(Request::new(stream::iter(messages.clone())));
/// client.record_route(stream::iter(messages));
/// ```
pub trait IntoStreamingRequest: sealed::Sealed {
    /// The stream of request messages; it must be `Send + 'static`.
    type Stream: Stream<Item = Self::Message> + Send + 'static;

    /// The type of each message yielded by `Self::Stream`.
    type Message;

    /// Wrap the stream of messages in a `tonic::Request`.
    fn into_streaming_request(self) -> Request<Self::Stream>;
}
97
98impl<T> Request<T> {
99    /// Create a new gRPC request.
100    ///
101    /// ```rust
102    /// # use tonic::Request;
103    /// # pub struct HelloRequest {
104    /// #   pub name: String,
105    /// # }
106    /// Request::new(HelloRequest {
107    ///    name: "Bob".into(),
108    /// });
109    /// ```
110    pub fn new(message: T) -> Self {
111        Request {
112            metadata: MetadataMap::new(),
113            message,
114            extensions: Extensions::new(),
115        }
116    }
117
118    /// Get a reference to the message
119    pub fn get_ref(&self) -> &T {
120        &self.message
121    }
122
123    /// Get a mutable reference to the message
124    pub fn get_mut(&mut self) -> &mut T {
125        &mut self.message
126    }
127
128    /// Get a reference to the custom request metadata.
129    pub fn metadata(&self) -> &MetadataMap {
130        &self.metadata
131    }
132
133    /// Get a mutable reference to the request metadata.
134    pub fn metadata_mut(&mut self) -> &mut MetadataMap {
135        &mut self.metadata
136    }
137
138    /// Consumes `self`, returning the message
139    pub fn into_inner(self) -> T {
140        self.message
141    }
142
143    pub(crate) fn into_parts(self) -> (MetadataMap, Extensions, T) {
144        (self.metadata, self.extensions, self.message)
145    }
146
147    pub(crate) fn from_parts(metadata: MetadataMap, extensions: Extensions, message: T) -> Self {
148        Self {
149            metadata,
150            extensions,
151            message,
152        }
153    }
154
155    pub(crate) fn from_http_parts(parts: http::request::Parts, message: T) -> Self {
156        Request {
157            metadata: MetadataMap::from_headers(parts.headers),
158            message,
159            extensions: Extensions::from_http(parts.extensions),
160        }
161    }
162
163    /// Convert an HTTP request to a gRPC request
164    pub fn from_http(http: http::Request<T>) -> Self {
165        let (parts, message) = http.into_parts();
166        Request::from_http_parts(parts, message)
167    }
168
169    pub(crate) fn into_http(
170        self,
171        uri: http::Uri,
172        method: http::Method,
173        version: http::Version,
174        sanitize_headers: SanitizeHeaders,
175    ) -> http::Request<T> {
176        let mut request = http::Request::new(self.message);
177
178        *request.version_mut() = version;
179        *request.method_mut() = method;
180        *request.uri_mut() = uri;
181        *request.headers_mut() = match sanitize_headers {
182            SanitizeHeaders::Yes => self.metadata.into_sanitized_headers(),
183            SanitizeHeaders::No => self.metadata.into_headers(),
184        };
185        *request.extensions_mut() = self.extensions.into_http();
186
187        request
188    }
189
190    #[doc(hidden)]
191    pub fn map<F, U>(self, f: F) -> Request<U>
192    where
193        F: FnOnce(T) -> U,
194    {
195        let message = f(self.message);
196
197        Request {
198            metadata: self.metadata,
199            message,
200            extensions: self.extensions,
201        }
202    }
203
204    /// Get the remote address of this connection.
205    ///
206    /// This will return `None` if the `IO` type used
207    /// does not implement `Connected` or when using a unix domain socket.
208    /// This currently only works on the server side.
209    pub fn remote_addr(&self) -> Option<SocketAddr> {
210        #[cfg(feature = "transport")]
211        {
212            #[cfg(feature = "tls")]
213            {
214                self.extensions()
215                    .get::<TcpConnectInfo>()
216                    .and_then(|i| i.remote_addr())
217                    .or_else(|| {
218                        self.extensions()
219                            .get::<TlsConnectInfo<TcpConnectInfo>>()
220                            .and_then(|i| i.get_ref().remote_addr())
221                    })
222            }
223
224            #[cfg(not(feature = "tls"))]
225            {
226                self.extensions()
227                    .get::<TcpConnectInfo>()
228                    .and_then(|i| i.remote_addr())
229            }
230        }
231
232        #[cfg(not(feature = "transport"))]
233        {
234            None
235        }
236    }
237
238    /// Get the peer certificates of the connected client.
239    ///
240    /// This is used to fetch the certificates from the TLS session
241    /// and is mostly used for mTLS. This currently only returns
242    /// `Some` on the server side of the `transport` server with
243    /// TLS enabled connections.
244    #[cfg(feature = "transport")]
245    #[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
246    pub fn peer_certs(&self) -> Option<Arc<Vec<Certificate>>> {
247        #[cfg(feature = "tls")]
248        {
249            self.extensions()
250                .get::<TlsConnectInfo<TcpConnectInfo>>()
251                .and_then(|i| i.peer_certs())
252        }
253
254        #[cfg(not(feature = "tls"))]
255        {
256            None
257        }
258    }
259
260    /// Set the max duration the request is allowed to take.
261    ///
262    /// Requires the server to support the `grpc-timeout` metadata, which Tonic does.
263    ///
264    /// The duration will be formatted according to [the spec] and use the most precise unit
265    /// possible.
266    ///
267    /// Example:
268    ///
269    /// ```rust
270    /// use std::time::Duration;
271    /// use tonic::Request;
272    ///
273    /// let mut request = Request::new(());
274    ///
275    /// request.set_timeout(Duration::from_secs(30));
276    ///
277    /// let value = request.metadata().get("grpc-timeout").unwrap();
278    ///
279    /// assert_eq!(
280    ///     value,
281    ///     // equivalent to 30 seconds
282    ///     "30000000u"
283    /// );
284    /// ```
285    ///
286    /// [the spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
287    pub fn set_timeout(&mut self, deadline: Duration) {
288        let value: MetadataValue<_> = duration_to_grpc_timeout(deadline).parse().unwrap();
289        self.metadata_mut()
290            .insert(crate::metadata::GRPC_TIMEOUT_HEADER, value);
291    }
292
293    /// Returns a reference to the associated extensions.
294    pub fn extensions(&self) -> &Extensions {
295        &self.extensions
296    }
297
298    /// Returns a mutable reference to the associated extensions.
299    ///
300    /// # Example
301    ///
302    /// Extensions can be set in interceptors:
303    ///
304    /// ```no_run
305    /// use tonic::{Request, service::interceptor};
306    ///
307    /// struct MyExtension {
308    ///     some_piece_of_data: String,
309    /// }
310    ///
311    /// interceptor(|mut request: Request<()>| {
312    ///     request.extensions_mut().insert(MyExtension {
313    ///         some_piece_of_data: "foo".to_string(),
314    ///     });
315    ///
316    ///     Ok(request)
317    /// });
318    /// ```
319    ///
320    /// And picked up by RPCs:
321    ///
322    /// ```no_run
323    /// use tonic::{async_trait, Status, Request, Response};
324    /// #
325    /// # struct Output {}
326    /// # struct Input;
327    /// # struct MyService;
328    /// # struct MyExtension;
329    /// # #[async_trait]
330    /// # trait TestService {
331    /// #     async fn handler(&self, req: Request<Input>) -> Result<Response<Output>, Status>;
332    /// # }
333    ///
334    /// #[async_trait]
335    /// impl TestService for MyService {
336    ///     async fn handler(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
337    ///         let value: &MyExtension = req.extensions().get::<MyExtension>().unwrap();
338    ///
339    ///         Ok(Response::new(Output {}))
340    ///     }
341    /// }
342    /// ```
343    pub fn extensions_mut(&mut self) -> &mut Extensions {
344        &mut self.extensions
345    }
346}
347
348impl<T> IntoRequest<T> for T {
349    fn into_request(self) -> Request<Self> {
350        Request::new(self)
351    }
352}
353
354impl<T> IntoRequest<T> for Request<T> {
355    fn into_request(self) -> Request<T> {
356        self
357    }
358}
359
360impl<T> IntoStreamingRequest for T
361where
362    T: Stream + Send + 'static,
363{
364    type Stream = T;
365    type Message = T::Item;
366
367    fn into_streaming_request(self) -> Request<Self> {
368        Request::new(self)
369    }
370}
371
372impl<T> IntoStreamingRequest for Request<T>
373where
374    T: Stream + Send + 'static,
375{
376    type Stream = T;
377    type Message = T::Item;
378
379    fn into_streaming_request(self) -> Self {
380        self
381    }
382}
383
// Every type implements `Sealed`, so the supertrait bound on `IntoRequest` /
// `IntoStreamingRequest` never rules out a type on its own.
// NOTE(review): because of this blanket impl, the supertrait alone does not stop other crates
// from implementing those traits for their own types; confirm whether that is intended.
impl<T> sealed::Sealed for T {}

// Private module: `Sealed` can only be named from this module and its descendants.
mod sealed {
    /// Marker supertrait of `IntoRequest` and `IntoStreamingRequest`.
    pub trait Sealed {}
}
389
/// Encode `duration` as a `grpc-timeout` header value.
///
/// The gRPC spec limits the value to at most 8 ASCII digits followed by one unit
/// suffix (`n`, `u`, `m`, `S`, `M`, `H`). The most precise unit whose value fits
/// in 8 digits is chosen, and the value is truncated (rounded down) to that unit.
///
/// Durations too large to express even in hours (more than ~11,415 years, e.g.
/// `Duration::MAX`) are saturated to the largest encodable timeout, `99999999H`,
/// rather than panicking. A timeout that long is effectively unbounded.
fn duration_to_grpc_timeout(duration: Duration) -> String {
    // The gRPC spec specifies that the timeout must be at most 8 digits, so this is the
    // largest value a unit can carry before the next, coarser unit must be used.
    const MAX_VALUE: u128 = 99_999_999; // exactly 8 digits
    const NANOS_PER_SECOND: u128 = 1_000_000_000;
    // Unit suffixes paired with their length in nanoseconds, most precise first.
    const UNITS: [(char, u128); 6] = [
        ('n', 1),
        ('u', 1_000),
        ('m', 1_000_000),
        ('S', NANOS_PER_SECOND),
        ('M', 60 * NANOS_PER_SECOND),
        ('H', 60 * 60 * NANOS_PER_SECOND),
    ];

    // Floor-dividing the total nanoseconds matches `as_micros`/`as_millis`/`as_secs`
    // (and `as_secs() / 60`, `/ 60 / 60`) exactly, since nested floor divisions compose.
    let nanos = duration.as_nanos();
    UNITS
        .iter()
        .map(|&(unit, nanos_per_unit)| (unit, nanos / nanos_per_unit))
        .find(|&(_, value)| value <= MAX_VALUE)
        .map(|(unit, value)| format!("{}{}", value, unit))
        // Saturate instead of panicking: `set_timeout` is public and takes any `Duration`.
        .unwrap_or_else(|| format!("{}H", MAX_VALUE))
}
423
/// Whether reserved headers should be removed when converting a `tonic::Request`
/// into an `http::Request` (see `Request::into_http`).
pub(crate) enum SanitizeHeaders {
    /// Remove reserved headers; the metadata is converted via `into_sanitized_headers`.
    Yes,
    /// Keep all metadata as-is; the metadata is converted via `into_headers`.
    No,
}
430
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::MetadataValue;
    use http::Uri;

    /// Every reserved gRPC header must be dropped when sanitizing.
    #[test]
    fn reserved_headers_are_excluded() {
        let mut request = Request::new(1);

        for &name in MetadataMap::GRPC_RESERVED_HEADERS.iter() {
            request
                .metadata_mut()
                .insert(name, MetadataValue::from_static("invalid"));
        }

        let converted = request.into_http(
            Uri::default(),
            http::Method::POST,
            http::Version::HTTP_2,
            SanitizeHeaders::Yes,
        );

        assert!(converted.headers().is_empty());
    }

    /// Half a second overflows 8 digits in nanoseconds, so microseconds are used.
    #[test]
    fn duration_to_grpc_timeout_less_than_second() {
        let half_second = Duration::from_millis(500);
        assert_eq!(duration_to_grpc_timeout(half_second), "500000u");
    }

    /// 30 seconds is still 8 digits in microseconds.
    #[test]
    fn duration_to_grpc_timeout_more_than_second() {
        let thirty_seconds = Duration::from_secs(30);
        assert_eq!(duration_to_grpc_timeout(thirty_seconds), "30000000u");
    }

    /// An hour needs milliseconds to stay within 8 digits.
    #[test]
    fn duration_to_grpc_timeout_a_very_long_time() {
        let hour = Duration::from_secs(60 * 60);
        assert_eq!(duration_to_grpc_timeout(hour), "3600000m");
    }
}