1use crate::{
17 Codec, Encoding, Status,
18 frame::writer::encode_frame,
19 server::{
20 bidi::{BidiResponder, BidiUpgrade},
21 body::{CancelSignal, OneShotBody, StreamBody},
22 content_type::{has_te_trailers, parse_grpc_content_type},
23 grpc_conn::GrpcServerConn,
24 },
25 timeout::parse_grpc_timeout,
26};
27use futures_lite::Stream;
28use std::{future::Future, time::Instant};
29use trillium::{Body, Conn, Headers, KnownHeaderName, Status as HttpStatus, Swansong, Upgrade};
30use trillium_http::BodySource;
31use trillium_server_common::Runtime;
32
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because callers cannot
// name or bound (e.g. `Send`) the returned futures; it is silenced here deliberately.
#[allow(async_fn_in_trait)]
/// Entry points for serving the four gRPC call shapes on a trillium [`Conn`].
///
/// Every method is provided and delegates to a private function in this module, and a
/// blanket impl covers every `Sized + 'static` type. The implementing type selects the
/// message codec through its `Codec<…>` impls (see the `Self: Codec<…>` bounds).
///
/// Each method first checks the request's `grpc-encoding` (unsupported → `UNIMPLEMENTED`)
/// and `grpc-timeout` (malformed → `INVALID_ARGUMENT`) headers. It then runs the handler
/// raced against the conn's shutdown signal and the request deadline.
pub trait Server: Sized + 'static {
    /// Serves a unary call: reads exactly one request message, passes it to `f`, and
    /// sends `f`'s response as a single frame followed by an OK status.
    ///
    /// Zero or more than one request message yields `UNIMPLEMENTED`. An `Err` from `f`, or
    /// from encoding the response, is sent as the call's status instead.
    async fn unary<Req, Resp>(
        conn: Conn,
        f: impl AsyncFnOnce(&mut GrpcServerConn<Self>, Req) -> Result<Resp, Status>,
    ) -> Conn
    where
        Self: Codec<Req> + Codec<Resp>,
        Req: Send + 'static,
        Resp: Send + 'static,
    {
        unary_impl::<Self, Req, Resp>(conn, f).await
    }

    /// Serves a client-streaming call: `f` reads the request stream through the
    /// [`GrpcServerConn`] itself and returns the single response message.
    ///
    /// The response is then sent exactly like a unary reply.
    async fn client_streaming<Resp>(
        conn: Conn,
        f: impl AsyncFnOnce(&mut GrpcServerConn<Self>) -> Result<Resp, Status>,
    ) -> Conn
    where
        Self: Codec<Resp>,
        Resp: Send + 'static,
    {
        client_streaming_impl::<Self, Resp>(conn, f).await
    }

    /// Serves a server-streaming call: reads exactly one request message, passes it to
    /// `f`, and streams each item of the returned `S` back as the response body.
    ///
    /// A cancel signal covering shutdown and the deadline is handed to the response body
    /// along with the stream.
    async fn server_streaming<Req, Resp, S>(
        conn: Conn,
        f: impl AsyncFnOnce(&mut GrpcServerConn<Self>, Req) -> Result<S, Status>,
    ) -> Conn
    where
        Self: Codec<Req> + Codec<Resp>,
        Req: Send + 'static,
        Resp: Send + 'static,
        S: Stream<Item = Result<Resp, Status>> + Send + 'static,
    {
        server_streaming_impl::<Self, Req, Resp, S>(conn, f).await
    }

    /// Serves a bidirectional-streaming call in two phases.
    ///
    /// First, `prologue` runs on the [`GrpcServerConn`]. Then the [`BidiResponder`] it
    /// returns is stored in the conn's state for the upgraded connection. An `Err` from
    /// `prologue` is sent as the call's status, and no upgrade is requested.
    async fn bidi<Req, Resp, R>(
        conn: Conn,
        prologue: impl AsyncFnOnce(&mut GrpcServerConn<Self>) -> Result<R, Status>,
    ) -> Conn
    where
        Self: Codec<Req> + Codec<Resp>,
        Req: Send + 'static,
        Resp: Send + 'static,
        R: BidiResponder<Req, Resp>,
    {
        bidi_prologue_impl::<Self, Req, Resp, R>(conn, prologue).await
    }
}
106
107impl<T: Sized + 'static> Server for T {}
108
109async fn unary_impl<C, Req, Resp>(
112 conn: Conn,
113 f: impl AsyncFnOnce(&mut GrpcServerConn<C>, Req) -> Result<Resp, Status>,
114) -> Conn
115where
116 C: Codec<Req> + Codec<Resp>,
117 Req: Send + 'static,
118 Resp: Send + 'static,
119{
120 let request_encoding = match extract_request_encoding(conn.request_headers()) {
121 Ok(e) => e,
122 Err(status) => return error_response(conn, status),
123 };
124 let cancellation = match Cancellation::from_conn(&conn) {
125 Ok(c) => c,
126 Err(status) => return error_response(conn, status),
127 };
128 let response_encoding = negotiate_response_encoding(conn.request_headers());
129 let mut grpc = GrpcServerConn::<C>::new(conn, request_encoding);
130
131 let result = cancellation
132 .race(async {
133 let req = read_one::<C, Req>(&mut grpc).await?;
134 f(&mut grpc, req).await
135 })
136 .await;
137 let (conn, trailers) = grpc.into_parts();
138 finish_unary::<C, Resp>(conn, result, response_encoding, trailers)
139}
140
141async fn client_streaming_impl<C, Resp>(
142 conn: Conn,
143 f: impl AsyncFnOnce(&mut GrpcServerConn<C>) -> Result<Resp, Status>,
144) -> Conn
145where
146 C: Codec<Resp>,
147 Resp: Send + 'static,
148{
149 let request_encoding = match extract_request_encoding(conn.request_headers()) {
150 Ok(e) => e,
151 Err(status) => return error_response(conn, status),
152 };
153 let cancellation = match Cancellation::from_conn(&conn) {
154 Ok(c) => c,
155 Err(status) => return error_response(conn, status),
156 };
157 let response_encoding = negotiate_response_encoding(conn.request_headers());
158 let mut grpc = GrpcServerConn::<C>::new(conn, request_encoding);
159
160 let result = cancellation.race(f(&mut grpc)).await;
161 let (conn, trailers) = grpc.into_parts();
162 finish_unary::<C, Resp>(conn, result, response_encoding, trailers)
163}
164
165async fn server_streaming_impl<C, Req, Resp, S>(
166 conn: Conn,
167 f: impl AsyncFnOnce(&mut GrpcServerConn<C>, Req) -> Result<S, Status>,
168) -> Conn
169where
170 C: Codec<Req> + Codec<Resp>,
171 Req: Send + 'static,
172 Resp: Send + 'static,
173 S: Stream<Item = Result<Resp, Status>> + Send + 'static,
174{
175 let request_encoding = match extract_request_encoding(conn.request_headers()) {
176 Ok(e) => e,
177 Err(status) => return error_response(conn, status),
178 };
179 let cancellation = match Cancellation::from_conn(&conn) {
180 Ok(c) => c,
181 Err(status) => return error_response(conn, status),
182 };
183 let response_encoding = negotiate_response_encoding(conn.request_headers());
184 let mut grpc = GrpcServerConn::<C>::new(conn, request_encoding);
185
186 let result = cancellation
190 .race(async {
191 let req = read_one::<C, Req>(&mut grpc).await?;
192 f(&mut grpc, req).await
193 })
194 .await;
195 let (conn, trailers) = grpc.into_parts();
196 match result {
197 Ok(stream) => respond(
198 conn,
199 StreamBody::new(
200 stream,
201 <C as Codec<Resp>>::encode,
202 response_encoding,
203 trailers,
204 Some(cancellation.signal()),
205 ),
206 ),
207 Err(status) => error_response_with_trailers(conn, status, trailers),
208 }
209}
210
211async fn read_one<C, Req>(grpc: &mut GrpcServerConn<C>) -> Result<Req, Status>
216where
217 C: Codec<Req>,
218 Req: 'static,
219{
220 let mut requests = grpc.requests::<Req>();
221 let Some(req) = requests.recv().await? else {
222 return Err(Status::unimplemented(
223 "expected exactly one request message, but the request stream was empty",
224 ));
225 };
226 if requests.recv().await?.is_some() {
227 return Err(Status::unimplemented(
228 "expected exactly one request message, but the request stream had more than one",
229 ));
230 }
231 Ok(req)
232}
233
234fn finish_unary<C, Resp>(
237 conn: Conn,
238 result: Result<Resp, Status>,
239 response_encoding: Encoding,
240 trailers: Headers,
241) -> Conn
242where
243 C: Codec<Resp>,
244{
245 match result {
246 Ok(resp) => match encode_frame::<C, Resp>(&resp, response_encoding) {
247 Ok(frame) => {
248 let mut trailers = trailers;
249 Status::ok().write_into(&mut trailers);
250 respond(conn, OneShotBody::new(frame, trailers))
251 }
252 Err(status) => error_response_with_trailers(conn, status, trailers),
253 },
254 Err(status) => error_response_with_trailers(conn, status, trailers),
255 }
256}
257
258fn respond(conn: Conn, body: impl BodySource) -> Conn {
261 conn.with_body(Body::new_with_trailers(body, None)).halt()
262}
263
264fn error_response_with_trailers(conn: Conn, status: Status, mut trailers: Headers) -> Conn {
267 status.write_into(&mut trailers);
268 respond(conn, OneShotBody::new(Vec::new(), trailers))
269}
270
271fn error_response(conn: Conn, status: Status) -> Conn {
273 error_response_with_trailers(conn, status, Headers::new())
274}
275
276async fn bidi_prologue_impl<C, Req, Resp, R>(
284 conn: Conn,
285 prologue: impl AsyncFnOnce(&mut GrpcServerConn<C>) -> Result<R, Status>,
286) -> Conn
287where
288 C: Codec<Req> + Codec<Resp>,
289 Req: Send + 'static,
290 Resp: Send + 'static,
291 R: BidiResponder<Req, Resp>,
292{
293 let request_encoding = match extract_request_encoding(conn.request_headers()) {
294 Ok(e) => e,
295 Err(status) => return error_response(conn, status),
296 };
297 let cancellation = match Cancellation::from_conn(&conn) {
298 Ok(c) => c,
299 Err(status) => return error_response(conn, status),
300 };
301 let response_encoding = negotiate_response_encoding(conn.request_headers());
302 let mut grpc = GrpcServerConn::<C>::new(conn, request_encoding);
303
304 let result = cancellation.race(prologue(&mut grpc)).await;
305 let deadline = grpc.deadline();
306 let (conn, trailers) = grpc.into_parts();
307
308 match result {
309 Ok(responder) => {
310 let bidi = BidiUpgrade::new(
311 responder,
312 trailers,
313 <C as Codec<Req>>::decode,
314 <C as Codec<Resp>>::encode,
315 request_encoding,
316 response_encoding,
317 deadline,
318 );
319 conn.with_state(bidi).upgrade().halt()
320 }
321 Err(status) => error_response_with_trailers(conn, status, trailers),
322 }
323}
324
325#[allow(clippy::result_large_err)]
337pub fn prepare_grpc_conn(conn: Conn, codec_suffix: &str) -> Result<Conn, Conn> {
338 let codec_matches = conn
342 .request_headers()
343 .get_str(KnownHeaderName::ContentType)
344 .and_then(parse_grpc_content_type)
345 .is_some_and(|suffix| suffix == codec_suffix);
346 if !codec_matches {
347 return Err(conn.with_status(HttpStatus::UnsupportedMediaType).halt());
348 }
349 if !has_te_trailers(conn.request_headers()) {
350 return Err(conn.with_status(HttpStatus::BadRequest).halt());
351 }
352 let content_type = format!("application/grpc+{codec_suffix}");
353 let response_encoding = negotiate_response_encoding(conn.request_headers());
354 let conn = conn
355 .with_response_header(KnownHeaderName::ContentType, content_type)
356 .with_response_header("grpc-accept-encoding", Encoding::accepted_encodings())
357 .with_status(HttpStatus::Ok);
358 Ok(if matches!(response_encoding, Encoding::Identity) {
359 conn
360 } else {
361 conn.with_response_header("grpc-encoding", response_encoding.as_grpc_encoding())
362 })
363}
364
365fn extract_request_encoding(request_headers: &Headers) -> Result<Encoding, Status> {
368 match request_headers.get_str("grpc-encoding") {
369 None => Ok(Encoding::Identity),
370 Some(s) => Encoding::from_grpc_encoding(s).ok_or_else(|| {
371 Status::unimplemented(format!(
372 "unsupported grpc-encoding {s:?}; accepted: {}",
373 Encoding::accepted_encodings()
374 ))
375 }),
376 }
377}
378
379pub(crate) struct Cancellation {
382 swansong: Swansong,
383 deadline: Option<Deadline>,
384}
385
386#[derive(Clone)]
387struct Deadline {
388 runtime: Runtime,
389 instant: Instant,
390}
391
392impl Cancellation {
393 fn from_conn(conn: &Conn) -> Result<Self, Status> {
397 let swansong = conn.swansong();
398 let deadline = match conn.request_headers().get_str("grpc-timeout") {
399 None => None,
400 Some(header) => {
401 let duration = parse_grpc_timeout(header).ok_or_else(|| {
402 Status::invalid_argument(format!("malformed grpc-timeout {header:?}"))
403 })?;
404 let runtime = conn
405 .shared_state::<Runtime>()
406 .expect("trillium-grpc requires a Runtime in shared state")
407 .clone();
408 Some(Deadline {
409 runtime,
410 instant: Instant::now() + duration,
411 })
412 }
413 };
414 Ok(Self { swansong, deadline })
415 }
416
417 fn signal(&self) -> CancelSignal {
422 let swansong = self.swansong.clone();
423 let deadline = self.deadline.clone();
424 Box::pin(async move {
425 let shutdown = async {
426 swansong.interrupt(std::future::pending::<()>()).await;
427 Status::cancelled("connection shutting down")
428 };
429 match deadline {
430 None => shutdown.await,
431 Some(d) => {
432 let timer = async move {
433 if let Some(remaining) = d.instant.checked_duration_since(Instant::now()) {
434 d.runtime.delay(remaining).await;
435 }
436 Status::deadline_exceeded("deadline elapsed")
437 };
438 futures_lite::future::or(shutdown, timer).await
439 }
440 }
441 })
442 }
443
444 pub(crate) fn for_upgrade(upgrade: &Upgrade, deadline: Option<Instant>) -> Self {
449 let swansong = upgrade.swansong();
450 let deadline = deadline.map(|instant| {
451 let runtime = upgrade
452 .shared_state()
453 .get::<Runtime>()
454 .expect("trillium-grpc requires a Runtime in shared state")
455 .clone();
456 Deadline { runtime, instant }
457 });
458 Self { swansong, deadline }
459 }
460
461 pub(crate) async fn race<T, F>(&self, fut: F) -> Result<T, Status>
462 where
463 F: Future<Output = Result<T, Status>>,
464 {
465 let interruptible = async {
466 match self.swansong.interrupt(fut).await {
467 Some(result) => result,
468 None => Err(Status::cancelled("connection shutting down")),
469 }
470 };
471 let Some(deadline) = self.deadline.as_ref() else {
472 return interruptible.await;
473 };
474 let Some(remaining) = deadline.instant.checked_duration_since(Instant::now()) else {
475 return Err(Status::deadline_exceeded("deadline elapsed"));
476 };
477 let runtime = deadline.runtime.clone();
478 let timer = async move {
479 runtime.delay(remaining).await;
480 Err(Status::deadline_exceeded("deadline elapsed"))
481 };
482 futures_lite::future::or(interruptible, timer).await
483 }
484}
485
486fn negotiate_response_encoding(request_headers: &Headers) -> Encoding {
490 let Some(accepted) = request_headers.get_str("grpc-accept-encoding") else {
491 return Encoding::Identity;
492 };
493 let accepted: Vec<&str> = accepted.split(',').map(str::trim).collect();
494 Encoding::ALL
495 .iter()
496 .copied()
497 .filter(|e| !matches!(e, Encoding::Identity))
498 .find(|e| accepted.contains(&e.as_grpc_encoding()))
499 .unwrap_or(Encoding::Identity)
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs response-encoding negotiation for a request whose `grpc-accept-encoding`
    /// header is `accept`, or which has no such header when `accept` is `None`.
    fn negotiated(accept: Option<&str>) -> Encoding {
        let mut headers = Headers::new();
        if let Some(value) = accept {
            headers.insert("grpc-accept-encoding", value.to_owned());
        }
        negotiate_response_encoding(&headers)
    }

    #[test]
    fn no_accept_header_falls_back_to_identity() {
        assert_eq!(negotiated(None), Encoding::Identity);
    }

    #[test]
    fn identity_only_means_identity() {
        assert_eq!(negotiated(Some("identity")), Encoding::Identity);
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn picks_gzip_when_offered() {
        assert_eq!(negotiated(Some("identity, gzip")), Encoding::Gzip);
    }

    // With both codecs compiled in, the server-side order of `Encoding::ALL` wins even
    // though the client listed zstd first.
    #[cfg(all(feature = "gzip", feature = "zstd"))]
    #[test]
    fn prefers_build_order_over_client_order() {
        assert_eq!(negotiated(Some("zstd, gzip")), Encoding::Gzip);
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn ignores_unknown_codecs() {
        assert_eq!(negotiated(Some("snappy, gzip")), Encoding::Gzip);
        assert_eq!(negotiated(Some("snappy")), Encoding::Identity);
    }
}