Skip to main content

trillium_grpc/server/
dispatch.rs

1//! Server-side request dispatch.
2//!
3//! The [`Server`] trait's methods are what generated code calls per RPC,
4//! resolved through a blanket impl so `Prost::unary(conn, ...)` works without a
5//! turbofish. The three half-duplex shapes (unary, client-streaming,
6//! server-streaming) run entirely in `Handler::run`: they take the `Conn`,
7//! drive the codec/framing around your method, and return a normal trillium
8//! response whose body carries `grpc-status` in HTTP trailers — no `Upgrade`.
9//! Bidi is the one shape that upgrades, because read-while-write requires the
10//! response head already flushed.
11//!
12//! [`prepare_grpc_conn`] is the shared preflight (content-type / `te: trailers`
13//! validation, response content-type / `grpc-accept-encoding`), called from
14//! generated `Handler::run` after path matching.
15
16use crate::{
17    Codec, Encoding, Status,
18    frame::writer::encode_frame,
19    server::{
20        bidi::{BidiResponder, BidiUpgrade},
21        body::{CancelSignal, OneShotBody, StreamBody},
22        content_type::{has_te_trailers, parse_grpc_content_type},
23        grpc_conn::GrpcServerConn,
24    },
25    timeout::parse_grpc_timeout,
26};
27use futures_lite::Stream;
28use std::{future::Future, time::Instant};
29use trillium::{Body, Conn, Headers, KnownHeaderName, Status as HttpStatus, Swansong, Upgrade};
30use trillium_http::BodySource;
31use trillium_server_common::Runtime;
32
33/// Server-side dispatch methods, available on any codec type via a blanket
34/// impl. Generated code calls these as `Prost::unary(conn, ...)` etc.
35///
36/// The three half-duplex methods take the `Conn` by value and return the
37/// finished `Conn`; the user closure is handed a [`GrpcServerConn`] control surface
38/// (and, for unary / server-streaming, the decoded request). [`bidi`](Self::bidi)
39/// still takes a [`trillium::Upgrade`] — it is driven from `Handler::upgrade`
40/// after the head is flushed.
41#[allow(async_fn_in_trait)]
42pub trait Server: Sized + 'static {
43    /// Unary RPC: read exactly one request, await the user function, emit one
44    /// response frame followed by `grpc-status` trailers.
45    async fn unary<Req, Resp>(
46        conn: Conn,
47        f: impl AsyncFnOnce(&mut GrpcServerConn<Self>, Req) -> Result<Resp, Status>,
48    ) -> Conn
49    where
50        Self: Codec<Req> + Codec<Resp>,
51        Req: Send + 'static,
52        Resp: Send + 'static,
53    {
54        unary_impl::<Self, Req, Resp>(conn, f).await
55    }
56
57    /// Client-streaming RPC: hand the user a [`GrpcServerConn`] from which they read
58    /// the request stream (`conn.requests::<Req>()`); emit the single response
59    /// frame and `grpc-status` trailers.
60    async fn client_streaming<Resp>(
61        conn: Conn,
62        f: impl AsyncFnOnce(&mut GrpcServerConn<Self>) -> Result<Resp, Status>,
63    ) -> Conn
64    where
65        Self: Codec<Resp>,
66        Resp: Send + 'static,
67    {
68        client_streaming_impl::<Self, Resp>(conn, f).await
69    }
70
71    /// Server-streaming RPC: read one request, await the user function for a
72    /// response [`Stream`], then frame each item lazily into the response body
73    /// with `grpc-status` trailers derived from how the stream ended.
74    async fn server_streaming<Req, Resp, S>(
75        conn: Conn,
76        f: impl AsyncFnOnce(&mut GrpcServerConn<Self>, Req) -> Result<S, Status>,
77    ) -> Conn
78    where
79        Self: Codec<Req> + Codec<Resp>,
80        Req: Send + 'static,
81        Resp: Send + 'static,
82        S: Stream<Item = Result<Resp, Status>> + Send + 'static,
83    {
84        server_streaming_impl::<Self, Req, Resp, S>(conn, f).await
85    }
86
87    /// Bidirectional-streaming RPC — the run-phase prologue. Hand the user a
88    /// [`GrpcServerConn`] from which they may read early request messages (to decide
89    /// response headers) and set initial metadata, then return a
90    /// [`BidiResponder`] that drives the read-while-write loop after the head is
91    /// flushed. Returning `Err(Status)` rejects before the flush (trailers-only,
92    /// no upgrade). See [`crate::server::bidi`] for the seam mechanics.
93    async fn bidi<Req, Resp, R>(
94        conn: Conn,
95        prologue: impl AsyncFnOnce(&mut GrpcServerConn<Self>) -> Result<R, Status>,
96    ) -> Conn
97    where
98        Self: Codec<Req> + Codec<Resp>,
99        Req: Send + 'static,
100        Resp: Send + 'static,
101        R: BidiResponder<Req, Resp>,
102    {
103        bidi_prologue_impl::<Self, Req, Resp, R>(conn, prologue).await
104    }
105}
106
107impl<T: Sized + 'static> Server for T {}
108
109// ── half-duplex (run-phase) ────────────────────────────────────────────────
110
111async fn unary_impl<C, Req, Resp>(
112    conn: Conn,
113    f: impl AsyncFnOnce(&mut GrpcServerConn<C>, Req) -> Result<Resp, Status>,
114) -> Conn
115where
116    C: Codec<Req> + Codec<Resp>,
117    Req: Send + 'static,
118    Resp: Send + 'static,
119{
120    let request_encoding = match extract_request_encoding(conn.request_headers()) {
121        Ok(e) => e,
122        Err(status) => return error_response(conn, status),
123    };
124    let cancellation = match Cancellation::from_conn(&conn) {
125        Ok(c) => c,
126        Err(status) => return error_response(conn, status),
127    };
128    let response_encoding = negotiate_response_encoding(conn.request_headers());
129    let mut grpc = GrpcServerConn::<C>::new(conn, request_encoding);
130
131    let result = cancellation
132        .race(async {
133            let req = read_one::<C, Req>(&mut grpc).await?;
134            f(&mut grpc, req).await
135        })
136        .await;
137    let (conn, trailers) = grpc.into_parts();
138    finish_unary::<C, Resp>(conn, result, response_encoding, trailers)
139}
140
141async fn client_streaming_impl<C, Resp>(
142    conn: Conn,
143    f: impl AsyncFnOnce(&mut GrpcServerConn<C>) -> Result<Resp, Status>,
144) -> Conn
145where
146    C: Codec<Resp>,
147    Resp: Send + 'static,
148{
149    let request_encoding = match extract_request_encoding(conn.request_headers()) {
150        Ok(e) => e,
151        Err(status) => return error_response(conn, status),
152    };
153    let cancellation = match Cancellation::from_conn(&conn) {
154        Ok(c) => c,
155        Err(status) => return error_response(conn, status),
156    };
157    let response_encoding = negotiate_response_encoding(conn.request_headers());
158    let mut grpc = GrpcServerConn::<C>::new(conn, request_encoding);
159
160    let result = cancellation.race(f(&mut grpc)).await;
161    let (conn, trailers) = grpc.into_parts();
162    finish_unary::<C, Resp>(conn, result, response_encoding, trailers)
163}
164
165async fn server_streaming_impl<C, Req, Resp, S>(
166    conn: Conn,
167    f: impl AsyncFnOnce(&mut GrpcServerConn<C>, Req) -> Result<S, Status>,
168) -> Conn
169where
170    C: Codec<Req> + Codec<Resp>,
171    Req: Send + 'static,
172    Resp: Send + 'static,
173    S: Stream<Item = Result<Resp, Status>> + Send + 'static,
174{
175    let request_encoding = match extract_request_encoding(conn.request_headers()) {
176        Ok(e) => e,
177        Err(status) => return error_response(conn, status),
178    };
179    let cancellation = match Cancellation::from_conn(&conn) {
180        Ok(c) => c,
181        Err(status) => return error_response(conn, status),
182    };
183    let response_encoding = negotiate_response_encoding(conn.request_headers());
184    let mut grpc = GrpcServerConn::<C>::new(conn, request_encoding);
185
186    // The setup (read one request + the user fn that produces the stream) is
187    // raced against cancellation; the resulting stream is then itself made
188    // cancellable so an in-flight deadline / shutdown cuts it between frames.
189    let result = cancellation
190        .race(async {
191            let req = read_one::<C, Req>(&mut grpc).await?;
192            f(&mut grpc, req).await
193        })
194        .await;
195    let (conn, trailers) = grpc.into_parts();
196    match result {
197        Ok(stream) => respond(
198            conn,
199            StreamBody::new(
200                stream,
201                <C as Codec<Resp>>::encode,
202                response_encoding,
203                trailers,
204                Some(cancellation.signal()),
205            ),
206        ),
207        Err(status) => error_response_with_trailers(conn, status, trailers),
208    }
209}
210
211/// Read exactly one request message from the conn's request body, enforcing
212/// the unary / server-streaming cardinality of *exactly one* request. Zero or
213/// more-than-one is a cardinality violation, which the gRPC status-code
214/// guidance maps to `unimplemented`.
215async fn read_one<C, Req>(grpc: &mut GrpcServerConn<C>) -> Result<Req, Status>
216where
217    C: Codec<Req>,
218    Req: 'static,
219{
220    let mut requests = grpc.requests::<Req>();
221    let Some(req) = requests.recv().await? else {
222        return Err(Status::unimplemented(
223            "expected exactly one request message, but the request stream was empty",
224        ));
225    };
226    if requests.recv().await?.is_some() {
227        return Err(Status::unimplemented(
228            "expected exactly one request message, but the request stream had more than one",
229        ));
230    }
231    Ok(req)
232}
233
234/// Encode the single-response result (unary / client-streaming) into a
235/// `OneShotBody` + trailers, or an error response.
236fn finish_unary<C, Resp>(
237    conn: Conn,
238    result: Result<Resp, Status>,
239    response_encoding: Encoding,
240    trailers: Headers,
241) -> Conn
242where
243    C: Codec<Resp>,
244{
245    match result {
246        Ok(resp) => match encode_frame::<C, Resp>(&resp, response_encoding) {
247            Ok(frame) => {
248                let mut trailers = trailers;
249                Status::ok().write_into(&mut trailers);
250                respond(conn, OneShotBody::new(frame, trailers))
251            }
252            Err(status) => error_response_with_trailers(conn, status, trailers),
253        },
254        Err(status) => error_response_with_trailers(conn, status, trailers),
255    }
256}
257
258/// Set a half-duplex response body (carrying `grpc-status` in trailers) and
259/// halt — this handler has fully produced the response.
260fn respond(conn: Conn, body: impl BodySource) -> Conn {
261    conn.with_body(Body::new_with_trailers(body, None)).halt()
262}
263
264/// An error response: response headers (already set by preflight) + empty body
265/// + `grpc-status` trailers merged onto any trailing metadata the handler set.
266fn error_response_with_trailers(conn: Conn, status: Status, mut trailers: Headers) -> Conn {
267    status.write_into(&mut trailers);
268    respond(conn, OneShotBody::new(Vec::new(), trailers))
269}
270
271/// An error response with no handler-set trailing metadata.
272fn error_response(conn: Conn, status: Status) -> Conn {
273    error_response_with_trailers(conn, status, Headers::new())
274}
275
276// ── bidi (run-phase prologue) ────────────────────────────────────────────────
277
278/// Run the bidi prologue in `run()`: build the [`GrpcServerConn`], let the user read
279/// early requests and set initial metadata, then either stash a [`BidiUpgrade`]
280/// and mark the conn for upgrade (the responder drives the loop in `upgrade()`)
281/// or, on `Err`, emit a trailers-only error without upgrading. The
282/// read-while-write loop itself lives in [`crate::server::bidi`].
283async fn bidi_prologue_impl<C, Req, Resp, R>(
284    conn: Conn,
285    prologue: impl AsyncFnOnce(&mut GrpcServerConn<C>) -> Result<R, Status>,
286) -> Conn
287where
288    C: Codec<Req> + Codec<Resp>,
289    Req: Send + 'static,
290    Resp: Send + 'static,
291    R: BidiResponder<Req, Resp>,
292{
293    let request_encoding = match extract_request_encoding(conn.request_headers()) {
294        Ok(e) => e,
295        Err(status) => return error_response(conn, status),
296    };
297    let cancellation = match Cancellation::from_conn(&conn) {
298        Ok(c) => c,
299        Err(status) => return error_response(conn, status),
300    };
301    let response_encoding = negotiate_response_encoding(conn.request_headers());
302    let mut grpc = GrpcServerConn::<C>::new(conn, request_encoding);
303
304    let result = cancellation.race(prologue(&mut grpc)).await;
305    let deadline = grpc.deadline();
306    let (conn, trailers) = grpc.into_parts();
307
308    match result {
309        Ok(responder) => {
310            let bidi = BidiUpgrade::new(
311                responder,
312                trailers,
313                <C as Codec<Req>>::decode,
314                <C as Codec<Resp>>::encode,
315                request_encoding,
316                response_encoding,
317                deadline,
318            );
319            conn.with_state(bidi).upgrade().halt()
320        }
321        Err(status) => error_response_with_trailers(conn, status, trailers),
322    }
323}
324
325// ── preflight + shared helpers ───────────────────────────────────────────────
326
327/// Validate request preflight (content-type, te:trailers) and set the gRPC
328/// response headers (content-type, grpc-accept-encoding). Returns the conn with
329/// those headers set, or an error-shaped conn if preflight failed.
330///
331/// Called from generated `Handler::run` *after* path matching has confirmed
332/// this request belongs to the service.
333// Both arms are the same `trillium::Conn`, so boxing only the `Err` side would be
334// pointless asymmetry — the equally-large `Ok` conn moves by value regardless, and
335// this runs once per request, not on a hot path.
336#[allow(clippy::result_large_err)]
337pub fn prepare_grpc_conn(conn: Conn, codec_suffix: &str) -> Result<Conn, Conn> {
338    // The request's content-type must name a gRPC codec we actually serve;
339    // `application/grpc+foo` for an unknown `foo` is rejected here rather than
340    // mis-decoded as the service's codec.
341    let codec_matches = conn
342        .request_headers()
343        .get_str(KnownHeaderName::ContentType)
344        .and_then(parse_grpc_content_type)
345        .is_some_and(|suffix| suffix == codec_suffix);
346    if !codec_matches {
347        return Err(conn.with_status(HttpStatus::UnsupportedMediaType).halt());
348    }
349    if !has_te_trailers(conn.request_headers()) {
350        return Err(conn.with_status(HttpStatus::BadRequest).halt());
351    }
352    let content_type = format!("application/grpc+{codec_suffix}");
353    let response_encoding = negotiate_response_encoding(conn.request_headers());
354    let conn = conn
355        .with_response_header(KnownHeaderName::ContentType, content_type)
356        .with_response_header("grpc-accept-encoding", Encoding::accepted_encodings())
357        .with_status(HttpStatus::Ok);
358    Ok(if matches!(response_encoding, Encoding::Identity) {
359        conn
360    } else {
361        conn.with_response_header("grpc-encoding", response_encoding.as_grpc_encoding())
362    })
363}
364
365/// Resolve the inbound message encoding from `grpc-encoding`. Missing →
366/// `Identity` (per spec). Unknown → `Unimplemented`.
367fn extract_request_encoding(request_headers: &Headers) -> Result<Encoding, Status> {
368    match request_headers.get_str("grpc-encoding") {
369        None => Ok(Encoding::Identity),
370        Some(s) => Encoding::from_grpc_encoding(s).ok_or_else(|| {
371            Status::unimplemented(format!(
372                "unsupported grpc-encoding {s:?}; accepted: {}",
373                Encoding::accepted_encodings()
374            ))
375        }),
376    }
377}
378
379/// Per-request cancellation handle, combining connection shutdown
380/// (`Conn`/`Upgrade::swansong()`) and an optional `grpc-timeout` deadline.
381pub(crate) struct Cancellation {
382    swansong: Swansong,
383    deadline: Option<Deadline>,
384}
385
386#[derive(Clone)]
387struct Deadline {
388    runtime: Runtime,
389    instant: Instant,
390}
391
392impl Cancellation {
393    /// Build from a `Conn` (the half-duplex run-phase path). Mirrors
394    /// [`from_upgrade`](Self::from_upgrade): connection shutdown via
395    /// `Conn::swansong()` and an optional `grpc-timeout` deadline.
396    fn from_conn(conn: &Conn) -> Result<Self, Status> {
397        let swansong = conn.swansong();
398        let deadline = match conn.request_headers().get_str("grpc-timeout") {
399            None => None,
400            Some(header) => {
401                let duration = parse_grpc_timeout(header).ok_or_else(|| {
402                    Status::invalid_argument(format!("malformed grpc-timeout {header:?}"))
403                })?;
404                let runtime = conn
405                    .shared_state::<Runtime>()
406                    .expect("trillium-grpc requires a Runtime in shared state")
407                    .clone();
408                Some(Deadline {
409                    runtime,
410                    instant: Instant::now() + duration,
411                })
412            }
413        };
414        Ok(Self { swansong, deadline })
415    }
416
417    /// A future that resolves to the terminal [`Status`] when this request is
418    /// cancelled (shutdown → `cancelled`, deadline → `deadline_exceeded`) and
419    /// stays `Pending` otherwise. Handed to [`StreamBody`] so a server-streaming
420    /// response can be cut between frames.
421    fn signal(&self) -> CancelSignal {
422        let swansong = self.swansong.clone();
423        let deadline = self.deadline.clone();
424        Box::pin(async move {
425            let shutdown = async {
426                swansong.interrupt(std::future::pending::<()>()).await;
427                Status::cancelled("connection shutting down")
428            };
429            match deadline {
430                None => shutdown.await,
431                Some(d) => {
432                    let timer = async move {
433                        if let Some(remaining) = d.instant.checked_duration_since(Instant::now()) {
434                            d.runtime.delay(remaining).await;
435                        }
436                        Status::deadline_exceeded("deadline elapsed")
437                    };
438                    futures_lite::future::or(shutdown, timer).await
439                }
440            }
441        })
442    }
443
444    /// Build for the upgrade (bidi loop) phase from a deadline already computed
445    /// in the prologue (so the clock isn't restarted by however long the
446    /// prologue ran). Shutdown comes from the upgrade's swansong; the deadline
447    /// timer reuses the request's `Runtime` from shared state.
448    pub(crate) fn for_upgrade(upgrade: &Upgrade, deadline: Option<Instant>) -> Self {
449        let swansong = upgrade.swansong();
450        let deadline = deadline.map(|instant| {
451            let runtime = upgrade
452                .shared_state()
453                .get::<Runtime>()
454                .expect("trillium-grpc requires a Runtime in shared state")
455                .clone();
456            Deadline { runtime, instant }
457        });
458        Self { swansong, deadline }
459    }
460
461    pub(crate) async fn race<T, F>(&self, fut: F) -> Result<T, Status>
462    where
463        F: Future<Output = Result<T, Status>>,
464    {
465        let interruptible = async {
466            match self.swansong.interrupt(fut).await {
467                Some(result) => result,
468                None => Err(Status::cancelled("connection shutting down")),
469            }
470        };
471        let Some(deadline) = self.deadline.as_ref() else {
472            return interruptible.await;
473        };
474        let Some(remaining) = deadline.instant.checked_duration_since(Instant::now()) else {
475            return Err(Status::deadline_exceeded("deadline elapsed"));
476        };
477        let runtime = deadline.runtime.clone();
478        let timer = async move {
479            runtime.delay(remaining).await;
480            Err(Status::deadline_exceeded("deadline elapsed"))
481        };
482        futures_lite::future::or(interruptible, timer).await
483    }
484}
485
486/// Pick the response encoding by intersecting the client's
487/// `grpc-accept-encoding` with `Encoding::ALL` in build order (prefer gzip,
488/// deflate, zstd, then identity).
489fn negotiate_response_encoding(request_headers: &Headers) -> Encoding {
490    let Some(accepted) = request_headers.get_str("grpc-accept-encoding") else {
491        return Encoding::Identity;
492    };
493    let accepted: Vec<&str> = accepted.split(',').map(str::trim).collect();
494    Encoding::ALL
495        .iter()
496        .copied()
497        .filter(|e| !matches!(e, Encoding::Identity))
498        .find(|e| accepted.contains(&e.as_grpc_encoding()))
499        .unwrap_or(Encoding::Identity)
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Request headers carrying only the given `grpc-accept-encoding` value.
    fn headers_with(accept: &str) -> Headers {
        let mut h = Headers::new();
        h.insert("grpc-accept-encoding", accept.to_owned());
        h
    }

    #[test]
    fn no_accept_header_falls_back_to_identity() {
        assert_eq!(
            negotiate_response_encoding(&Headers::new()),
            Encoding::Identity
        );
    }

    #[test]
    fn empty_accept_header_falls_back_to_identity() {
        assert_eq!(
            negotiate_response_encoding(&headers_with("")),
            Encoding::Identity
        );
    }

    #[test]
    fn identity_only_means_identity() {
        assert_eq!(
            negotiate_response_encoding(&headers_with("identity")),
            Encoding::Identity
        );
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn picks_gzip_when_offered() {
        assert_eq!(
            negotiate_response_encoding(&headers_with("identity, gzip")),
            Encoding::Gzip
        );
    }

    // Tokens are trimmed before comparison, so irregular spacing around the
    // commas must not hide an offered codec.
    #[cfg(feature = "gzip")]
    #[test]
    fn tolerates_whitespace_around_tokens() {
        assert_eq!(
            negotiate_response_encoding(&headers_with("  identity ,gzip  ")),
            Encoding::Gzip
        );
    }

    #[cfg(all(feature = "gzip", feature = "zstd"))]
    #[test]
    fn prefers_build_order_over_client_order() {
        assert_eq!(
            negotiate_response_encoding(&headers_with("zstd, gzip")),
            Encoding::Gzip
        );
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn ignores_unknown_codecs() {
        assert_eq!(
            negotiate_response_encoding(&headers_with("snappy, gzip")),
            Encoding::Gzip
        );
        assert_eq!(
            negotiate_response_encoding(&headers_with("snappy")),
            Encoding::Identity
        );
    }
}