Skip to main content

shaperail_runtime/grpc/
server.rs

1//! gRPC server builder (M16).
2//!
3//! Builds a Tonic gRPC server with dynamic resource services, JWT auth
4//! interceptor, server reflection, and health check.
5
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9use http_body_util::BodyExt;
10use prost::bytes::Bytes;
11use shaperail_core::{GrpcConfig, ResourceDefinition};
12use tokio::task::JoinHandle;
13use tonic::server::NamedService;
14use tonic::transport::Server;
15use tonic::Status;
16
17use super::service;
18use crate::auth::extractor::AuthenticatedUser;
19use crate::auth::jwt::JwtConfig;
20use crate::handlers::crud::AppState;
21
22/// Handle to a running gRPC server — can be used to await or abort.
23pub struct GrpcServerHandle {
24    pub handle: JoinHandle<Result<(), tonic::transport::Error>>,
25    pub addr: SocketAddr,
26}
27
28/// Dynamic gRPC service that routes to resource handlers based on path.
29#[derive(Clone)]
30pub struct ShaperailGrpcService {
31    state: Arc<AppState>,
32    resources: Vec<ResourceDefinition>,
33    jwt_config: Option<Arc<JwtConfig>>,
34}
35
36impl ShaperailGrpcService {
37    pub fn new(
38        state: Arc<AppState>,
39        resources: Vec<ResourceDefinition>,
40        jwt_config: Option<Arc<JwtConfig>>,
41    ) -> Self {
42        Self {
43            state,
44            resources,
45            jwt_config,
46        }
47    }
48
49    /// Parse a gRPC path like `/shaperail.v1.users.UserService/GetUser`
50    /// into (resource_name, method_name).
51    pub fn parse_grpc_path(path: &str) -> Option<(String, String)> {
52        let path = path.strip_prefix('/')?;
53        let (service_part, method) = path.split_once('/')?;
54        let parts: Vec<&str> = service_part.split('.').collect();
55        if parts.len() >= 4 && parts[0] == "shaperail" {
56            let resource_name = parts[2].to_string();
57            Some((resource_name, method.to_string()))
58        } else {
59            None
60        }
61    }
62
63    /// Handle a unary or server-streaming gRPC call.
64    async fn handle_request(
65        &self,
66        resource_name: &str,
67        method_name: &str,
68        user: Option<&AuthenticatedUser>,
69        body: &[u8],
70    ) -> Result<GrpcResponse, Status> {
71        let resource = self
72            .resources
73            .iter()
74            .find(|r| r.resource == resource_name)
75            .ok_or_else(|| Status::not_found(format!("Unknown resource: {resource_name}")))?;
76
77        if method_name.starts_with("Get") {
78            let data = service::handle_get(self.state.clone(), resource, user, body).await?;
79            Ok(GrpcResponse::Unary(data))
80        } else if method_name.starts_with("Stream") {
81            let items =
82                service::handle_stream_list(self.state.clone(), resource, user, body).await?;
83            Ok(GrpcResponse::Stream(items))
84        } else if method_name.starts_with("List") {
85            let data = service::handle_list(self.state.clone(), resource, user, body).await?;
86            Ok(GrpcResponse::Unary(data))
87        } else if method_name.starts_with("Create") {
88            let data = service::handle_create(self.state.clone(), resource, user, body).await?;
89            Ok(GrpcResponse::Unary(data))
90        } else if method_name.starts_with("Update") {
91            Err(Status::unimplemented("Update not yet implemented"))
92        } else if method_name.starts_with("Delete") {
93            let data = service::handle_delete(self.state.clone(), resource, user, body).await?;
94            Ok(GrpcResponse::Unary(data))
95        } else {
96            Err(Status::unimplemented(format!(
97                "Unknown method: {method_name}"
98            )))
99        }
100    }
101}
102
103enum GrpcResponse {
104    Unary(Bytes),
105    Stream(Vec<Bytes>),
106}
107
108/// The tonic body type used in 0.12.
109type TonicBody = tonic::body::BoxBody;
110
111/// Wrapper implementing tonic's Service trait for dynamic dispatch.
112#[derive(Clone)]
113struct ShaperailGrpcServiceServer {
114    inner: ShaperailGrpcService,
115}
116
117impl NamedService for ShaperailGrpcServiceServer {
118    const NAME: &'static str = "shaperail";
119}
120
121impl tower::Service<http::Request<TonicBody>> for ShaperailGrpcServiceServer {
122    type Response = http::Response<TonicBody>;
123    type Error = std::convert::Infallible;
124    type Future = std::pin::Pin<
125        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
126    >;
127
128    fn poll_ready(
129        &mut self,
130        _cx: &mut std::task::Context<'_>,
131    ) -> std::task::Poll<Result<(), Self::Error>> {
132        std::task::Poll::Ready(Ok(()))
133    }
134
135    fn call(&mut self, req: http::Request<TonicBody>) -> Self::Future {
136        let inner = self.inner.clone();
137
138        Box::pin(async move {
139            let path = req.uri().path().to_string();
140
141            // Extract auth from headers
142            let user = extract_user_from_headers(req.headers(), inner.jwt_config.as_deref());
143
144            // Collect body bytes
145            let body_bytes = collect_body(req.into_body()).await;
146
147            // Strip gRPC framing: 1 byte compression + 4 bytes length
148            let message_data = if body_bytes.len() >= 5 {
149                &body_bytes[5..]
150            } else {
151                &body_bytes[..]
152            };
153
154            // Parse path and dispatch
155            let (resource_name, method_name) = match ShaperailGrpcService::parse_grpc_path(&path) {
156                Some(v) => v,
157                None => {
158                    return Ok(grpc_error_response(
159                        tonic::Code::Unimplemented,
160                        &format!("Unknown path: {path}"),
161                    ));
162                }
163            };
164
165            match inner
166                .handle_request(&resource_name, &method_name, user.as_ref(), message_data)
167                .await
168            {
169                Ok(GrpcResponse::Unary(data)) => Ok(grpc_data_response(&data)),
170                Ok(GrpcResponse::Stream(items)) => {
171                    let mut combined = Vec::new();
172                    for item in &items {
173                        let len = item.len() as u32;
174                        combined.push(0u8);
175                        combined.extend_from_slice(&len.to_be_bytes());
176                        combined.extend_from_slice(item);
177                    }
178                    Ok(grpc_data_response(&combined))
179                }
180                Err(status) => Ok(grpc_error_response(status.code(), status.message())),
181            }
182        })
183    }
184}
185
186/// Extract a user from HTTP headers (for JWT auth via gRPC metadata).
187fn extract_user_from_headers(
188    headers: &http::HeaderMap,
189    jwt_config: Option<&JwtConfig>,
190) -> Option<AuthenticatedUser> {
191    let auth_str = headers.get("authorization")?.to_str().ok()?;
192    let token = auth_str.strip_prefix("Bearer ")?;
193    let jwt = jwt_config?;
194    let claims = jwt.decode(token).ok()?;
195    if claims.token_type != "access" {
196        return None;
197    }
198    Some(AuthenticatedUser {
199        id: claims.sub,
200        role: claims.role,
201    })
202}
203
204/// Collect body bytes from a tonic BoxBody.
205async fn collect_body(body: TonicBody) -> Bytes {
206    use http_body_util::BodyExt;
207    match body.collect().await {
208        Ok(collected) => collected.to_bytes(),
209        Err(_) => Bytes::new(),
210    }
211}
212
213/// Build a successful gRPC response with data.
214fn grpc_data_response(data: &[u8]) -> http::Response<TonicBody> {
215    // gRPC frame: 0 (no compression) + 4 byte big-endian length + data
216    let mut frame = Vec::with_capacity(5 + data.len());
217    frame.push(0u8);
218    let len = data.len() as u32;
219    frame.extend_from_slice(&len.to_be_bytes());
220    frame.extend_from_slice(data);
221
222    let body = http_body_util::Full::new(Bytes::from(frame))
223        .map_err(|never: std::convert::Infallible| match never {});
224    let boxed = TonicBody::new(body);
225
226    http::Response::builder()
227        .status(200)
228        .header("content-type", "application/grpc")
229        .header("grpc-status", "0")
230        .body(boxed)
231        .unwrap_or_else(|_| empty_grpc_response(13, "Internal error"))
232}
233
234/// Build a gRPC error response.
235fn grpc_error_response(code: tonic::Code, message: &str) -> http::Response<TonicBody> {
236    empty_grpc_response(code as i32, message)
237}
238
239/// Build an empty gRPC response with status and message headers.
240fn empty_grpc_response(code: i32, message: &str) -> http::Response<TonicBody> {
241    let body = http_body_util::Full::new(Bytes::new())
242        .map_err(|never: std::convert::Infallible| match never {});
243    let boxed = TonicBody::new(body);
244
245    http::Response::builder()
246        .status(200)
247        .header("content-type", "application/grpc")
248        .header("grpc-status", code.to_string())
249        .header("grpc-message", message)
250        .body(boxed)
251        .unwrap_or_else(|_| {
252            // Last resort fallback
253            let fb = http_body_util::Full::new(Bytes::new())
254                .map_err(|never: std::convert::Infallible| match never {});
255            http::Response::new(TonicBody::new(fb))
256        })
257}
258
259/// Build and start the gRPC server.
260///
261/// Returns a `GrpcServerHandle` that can be awaited or aborted.
262/// The server runs on a separate port from the HTTP REST/GraphQL server.
263pub async fn build_grpc_server(
264    state: Arc<AppState>,
265    resources: Vec<ResourceDefinition>,
266    jwt_config: Option<Arc<JwtConfig>>,
267    grpc_config: Option<&GrpcConfig>,
268) -> Result<GrpcServerHandle, Box<dyn std::error::Error + Send + Sync>> {
269    let port = grpc_config.map(|c| c.port).unwrap_or(50051);
270    let reflection_enabled = grpc_config.map(|c| c.reflection).unwrap_or(true);
271
272    let addr: SocketAddr = format!("0.0.0.0:{port}").parse()?;
273
274    let svc = ShaperailGrpcService::new(state, resources.clone(), jwt_config);
275    let grpc_service = ShaperailGrpcServiceServer { inner: svc };
276
277    // Health service
278    let (mut health_reporter, health_service) = tonic_health::server::health_reporter();
279    health_reporter
280        .set_serving::<ShaperailGrpcServiceServer>()
281        .await;
282
283    for resource in &resources {
284        let pascal = to_pascal_case(&to_singular(&resource.resource));
285        let service_name = format!(
286            "shaperail.v{}.{}.{}Service",
287            resource.version, resource.resource, pascal
288        );
289        health_reporter
290            .set_service_status(&service_name, tonic_health::ServingStatus::Serving)
291            .await;
292    }
293
294    let mut builder = Server::builder();
295
296    let handle = if reflection_enabled {
297        let reflection_service = tonic_reflection::server::Builder::configure()
298            .build_v1()
299            .map_err(|e| format!("Failed to build reflection service: {e}"))?;
300
301        let router = builder
302            .add_service(health_service)
303            .add_service(reflection_service)
304            .add_service(grpc_service);
305
306        tokio::spawn(async move { router.serve(addr).await })
307    } else {
308        let router = builder
309            .add_service(health_service)
310            .add_service(grpc_service);
311
312        tokio::spawn(async move { router.serve(addr).await })
313    };
314
315    tracing::info!("gRPC server listening on {addr}");
316
317    Ok(GrpcServerHandle { handle, addr })
318}
319
/// Convert a `snake_case` identifier to `PascalCase`.
///
/// The first character of every `_`-separated word is upper-cased (a char
/// may expand to several, e.g. `ß` → `SS`). The rest of each word is kept
/// verbatim, and empty words from repeated or edge underscores are dropped.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
334
/// Heuristically singularise an English resource name (`users` → `user`).
///
/// Rules, in order:
/// 1. Names ending in `status`, `bus`, `alias` or `canvas` are returned unchanged.
/// 2. A trailing `ies` becomes `y`.
/// 3. A trailing `ses`, `xes` or `zes` loses its final `es`.
/// 4. A single trailing `s` is dropped, unless it is part of `ss`.
///
/// These rules must produce the same service names as the proto generator,
/// so they are intentionally not "improved" here.
fn to_singular(s: &str) -> String {
    const EXCEPTIONS: &[&str] = &["status", "bus", "alias", "canvas"];
    const ES_SUFFIXES: [&str; 3] = ["ses", "xes", "zes"];

    if EXCEPTIONS.iter().any(|word| s.ends_with(word)) {
        return s.to_owned();
    }
    if let Some(stem) = s.strip_suffix("ies") {
        return format!("{stem}y");
    }
    if ES_SUFFIXES.iter().any(|suffix| s.ends_with(suffix)) {
        // The suffix is ASCII, so `len - 2` is always a char boundary.
        return s[..s.len() - 2].to_owned();
    }
    match s.strip_suffix('s') {
        Some(stem) if !stem.ends_with('s') => stem.to_owned(),
        _ => s.to_owned(),
    }
}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Signing secret shared by the JWT tests.
    const TEST_SECRET: &str = "test-secret-key-at-least-32-bytes-long!";

    /// Expected `parse_grpc_path` result for a `(resource, method)` pair.
    fn route(resource: &str, method: &str) -> Option<(String, String)> {
        Some((resource.to_string(), method.to_string()))
    }

    /// Header map holding a single `authorization` header with `value`.
    fn auth_headers(value: &str) -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert("authorization", http::HeaderValue::from_str(value).unwrap());
        headers
    }

    #[test]
    fn parse_grpc_path_valid() {
        assert_eq!(
            ShaperailGrpcService::parse_grpc_path("/shaperail.v1.users.UserService/GetUser"),
            route("users", "GetUser")
        );
    }

    #[test]
    fn parse_grpc_path_list() {
        assert_eq!(
            ShaperailGrpcService::parse_grpc_path("/shaperail.v1.orders.OrderService/ListOrders"),
            route("orders", "ListOrders")
        );
    }

    #[test]
    fn parse_grpc_path_invalid() {
        for path in ["/invalid", ""] {
            assert!(ShaperailGrpcService::parse_grpc_path(path).is_none());
        }
    }

    #[test]
    fn parse_grpc_path_stream() {
        assert_eq!(
            ShaperailGrpcService::parse_grpc_path("/shaperail.v1.users.UserService/StreamUsers"),
            route("users", "StreamUsers")
        );
    }

    #[test]
    fn pascal_and_singular() {
        assert_eq!(to_pascal_case("user"), "User");
        assert_eq!(to_pascal_case("blog_post"), "BlogPost");
        assert_eq!(to_singular("users"), "user");
        assert_eq!(to_singular("categories"), "category");
    }

    #[test]
    fn extract_user_no_header() {
        assert!(extract_user_from_headers(&http::HeaderMap::new(), None).is_none());
    }

    #[test]
    fn extract_user_valid_token() {
        let jwt = JwtConfig::new(TEST_SECRET, 3600, 86400);
        let token = jwt.encode_access("user-1", "admin").unwrap();
        let headers = auth_headers(&format!("Bearer {token}"));

        let user = extract_user_from_headers(&headers, Some(&jwt))
            .expect("a valid access token should authenticate");
        assert_eq!(user.id, "user-1");
        assert_eq!(user.role, "admin");
    }

    #[test]
    fn extract_user_invalid_token() {
        let jwt = JwtConfig::new(TEST_SECRET, 3600, 86400);
        let headers = auth_headers("Bearer invalid.token.here");

        assert!(extract_user_from_headers(&headers, Some(&jwt)).is_none());
    }
}