Skip to main content

rusty_cdk_core/apigateway/
builder.rs

1use std::marker::PhantomData;
2use crate::apigateway::{ApiGatewayV2Api, ApiGatewayV2ApiProperties, ApiGatewayV2ApiRef, ApiGatewayV2Integration, ApiGatewayV2IntegrationProperties, ApiGatewayV2Route, ApiGatewayV2RouteProperties, ApiGatewayV2Stage, ApiGatewayV2StageProperties, ApiGatewayV2StageRef, CorsConfiguration};
3use crate::intrinsic::{get_arn, get_ref, join, AWS_ACCOUNT_PSEUDO_PARAM, AWS_PARTITION_PSEUDO_PARAM, AWS_REGION_PSEUDO_PARAM};
4use crate::lambda::{FunctionRef, PermissionBuilder};
5use crate::shared::HttpMethod;
6use crate::shared::Id;
7use crate::stack::{Resource, StackBuilder};
8use serde_json::Value;
9use std::time::Duration;
10use crate::type_state;
11use crate::wrappers::LambdaPermissionAction;
12
// TODO: authorizers and API keys are not supported yet.
14
15struct RouteInfo {
16    lambda_id: Id,
17    path: String,
18    method: Option<HttpMethod>,
19    resource_id: String,
20}
21
// Type-state markers for `ApiGatewayV2Builder`: `StartState` before a protocol is chosen,
// then `HttpState` (after `http()`) or `WebsocketState` (after `websocket(..)`).
// `ApiGatewayV2APIState` is the bound shared by all three; the `type_state!` macro itself is
// imported from the crate root.
type_state!(
    ApiGatewayV2APIState,
    StartState,
    HttpState,
    WebsocketState,
);
28
/// Builder for API Gateway V2 APIs, either HTTP or WebSocket.
///
/// Start with [`ApiGatewayV2Builder::new`], then choose the protocol with `http()` or
/// `websocket(..)`. The type-state parameter `T` determines which methods are available
/// (e.g. CORS only for HTTP, schema validation only for WebSocket).
///
/// When built, the builder adds the API, a single stage, and, for every registered route, an
/// integration, a route and a Lambda permission to the stack.
///
/// # Example
///
/// ```rust,no_run
/// use rusty_cdk_core::stack::StackBuilder;
/// use rusty_cdk_core::apigateway::ApiGatewayV2Builder;
/// use rusty_cdk_core::shared::HttpMethod;
/// use rusty_cdk_core::lambda::{FunctionBuilder, Architecture, Runtime, Zip};
/// use rusty_cdk_core::wrappers::*;
/// use rusty_cdk_macros::{memory, timeout, zip_file};
///
/// let mut stack_builder = StackBuilder::new();
///
/// let function = unimplemented!("create a function");
///
/// let (api, stage) = ApiGatewayV2Builder::new("my-api", "MyHttpApi")
///     .http()
///     .add_route_lambda("/hello", HttpMethod::Get, &function)
///     .add_route_lambda("/world", HttpMethod::Post, &function)
///     .build(&mut stack_builder);
/// ```
pub struct ApiGatewayV2Builder<T: ApiGatewayV2APIState> {
    // Zero-sized marker carrying the type-state `T`.
    phantom_data: PhantomData<T>,
    id: Id,
    // Always `Some` after `new`.
    name: Option<String>,
    // "HTTP" or "WEBSOCKET"; set by `http()` / `websocket(..)`, `None` in `StartState`.
    protocol_type: Option<String>,
    disable_execute_api_endpoint: Option<bool>,
    // Only settable in `WebsocketState`.
    disable_schema_validation: Option<bool>,
    // Only settable in `HttpState`.
    cors_configuration: Option<CorsConfiguration>,
    // Routes registered so far; each becomes an integration, route and permission on build.
    route_info: Vec<RouteInfo>,
    // Only set by `websocket(..)`.
    route_selection_expression: Option<String>,
}
64
65impl ApiGatewayV2Builder<StartState> {
66    /// Creates a new API Gateway V2 HTTP API builder.
67    ///
68    /// # Arguments
69    /// * `id` - Unique identifier for the API Gateway
70    /// * `name` - Name of the API Gateway
71    pub fn new<T: Into<String>>(id: &str, name: T) -> ApiGatewayV2Builder<StartState> {
72        Self {
73            phantom_data: Default::default(),
74            id: Id(id.to_string()),
75            name: Some(name.into()), // name is required when not OpenAPI (so currently always)
76            protocol_type: None,
77            disable_execute_api_endpoint: None,
78            disable_schema_validation: None,
79            cors_configuration: None,
80            route_selection_expression: None,
81            route_info: vec![],
82        }
83    }
84
85    pub fn http(self) -> ApiGatewayV2Builder<HttpState> {
86        ApiGatewayV2Builder {
87            phantom_data: Default::default(),
88            id: self.id,
89            name: self.name,
90            protocol_type: Some("HTTP".to_string()),
91            disable_execute_api_endpoint: self.disable_execute_api_endpoint,
92            cors_configuration: self.cors_configuration,
93            route_info: self.route_info,
94            disable_schema_validation: None,
95            route_selection_expression: None,
96        }
97    }
98
99    pub fn websocket<T: Into<String>>(self, route_selection_expression: T) -> ApiGatewayV2Builder<WebsocketState> {
100        ApiGatewayV2Builder {
101            phantom_data: Default::default(),
102            id: self.id,
103            name: self.name,
104            protocol_type: Some("WEBSOCKET".to_string()),
105            route_selection_expression: Some(route_selection_expression.into()),
106            disable_execute_api_endpoint: self.disable_execute_api_endpoint,
107            route_info: self.route_info,
108            disable_schema_validation: self.disable_schema_validation,
109            cors_configuration: None,
110        }
111    }
112}
113
114impl ApiGatewayV2Builder<HttpState> {
115    pub fn cors_configuration(self, config: CorsConfiguration) -> Self {
116        Self {
117            cors_configuration: Some(config),
118            ..self
119        }
120    }
121
122    /// Adds a route for a specific HTTP method and path.
123    ///
124    /// Automatically creates the integration and Lambda permission.
125    pub fn add_route_lambda<T: Into<String>>(mut self, path: T, method: HttpMethod, lambda: &FunctionRef) -> Self {
126        let path = path.into();
127        let path = if path.starts_with("/") { path } else { format!("/{}", path) };
128
129        self.route_info.push(RouteInfo {
130            lambda_id: lambda.get_id().clone(),
131            path,
132            method: Some(method),
133            resource_id: lambda.get_resource_id().to_string(),
134        });
135        Self { ..self }
136    }
137
138    pub fn build(self, stack_builder: &mut StackBuilder) -> (
139        ApiGatewayV2ApiRef,
140        ApiGatewayV2StageRef,
141    ) {
142        self.build_internal(stack_builder)
143    }
144}
145
146impl ApiGatewayV2Builder<WebsocketState> {
147    pub fn disable_schema_validation(self, disable: bool) -> Self {
148        Self {
149            disable_schema_validation: Some(disable),
150            ..self
151        }
152    }
153
154    /// Adds a route for a specific route key.
155    ///
156    /// Automatically creates the integration and Lambda permission.
157    pub fn add_route_lambda<T: Into<String>>(mut self, route_key: T, lambda: &FunctionRef) -> Self {
158        self.route_info.push(RouteInfo {
159            lambda_id: lambda.get_id().clone(),
160            path: route_key.into(),
161            method: None,
162            resource_id: lambda.get_resource_id().to_string(),
163        });
164        Self { ..self }
165    }
166    
167    pub fn build(self, stack_builder: &mut StackBuilder) -> (
168        ApiGatewayV2ApiRef,
169        ApiGatewayV2StageRef,
170    ) {
171        self.build_internal(stack_builder)
172    }
173}
174
175impl<T: ApiGatewayV2APIState> ApiGatewayV2Builder<T> {
176    pub fn disable_execute_api_endpoint(self, disable_api_endpoint: bool) -> Self {
177        Self {
178            disable_execute_api_endpoint: Some(disable_api_endpoint),
179            ..self
180        }
181    }
182
183    /// Adds a default route that catches all requests not matching other routes.
184    ///
185    /// Automatically creates the integration and Lambda permission.
186    pub fn add_default_route_lambda(mut self, lambda: &FunctionRef) -> Self {
187        self.route_info.push(RouteInfo {
188            lambda_id: lambda.get_id().clone(),
189            path: "$default".to_string(),
190            method: None,
191            resource_id: lambda.get_resource_id().to_string(),
192        });
193        Self { ..self }
194    }
195
196    fn build_internal(
197        self, stack_builder: &mut StackBuilder
198    ) -> (
199        ApiGatewayV2ApiRef,
200        ApiGatewayV2StageRef,
201    ) {
202        let api_resource_id = Resource::generate_id("HttpApiGateway");
203        let stage_resource_id = Resource::generate_id("HttpApiStage");
204        let stage_id = Id::generate_id(&self.id, "Stage");
205
206        let protocol_type = self.protocol_type.expect("protocol type should be present, enforced by builder");
207
208        self
209            .route_info
210            .into_iter()
211            .for_each(|info| {
212                let route_id = Id::combine_with_resource_id(&self.id, &info.lambda_id);
213                let route_permission_id = Id::generate_id(&self.id, "Permission");
214                let route_integration_id = Id::generate_id(&self.id, "Integration");
215
216                let integration_resource_id = Resource::generate_id("HttpApiIntegration");
217                let route_resource_id = Resource::generate_id("HttpApiRoute");
218
219                PermissionBuilder::new(
220                    &route_permission_id,
221                    LambdaPermissionAction("lambda:InvokeFunction".to_string()),
222                    get_arn(&info.resource_id),
223                    "apigateway.amazonaws.com".to_string(),
224                )
225                    .source_arn(join(
226                        "",
227                        vec![
228                            Value::String("arn:".to_string()),
229                            get_ref(AWS_PARTITION_PSEUDO_PARAM),
230                            Value::String(":execute-api:".to_string()),
231                            get_ref(AWS_REGION_PSEUDO_PARAM),
232                            Value::String(":".to_string()),
233                            get_ref(AWS_ACCOUNT_PSEUDO_PARAM),
234                            Value::String(":".to_string()),
235                            get_ref(&api_resource_id),
236                            Value::String(format!("*/*{}", info.path)),
237                        ],
238                    ))
239                    .build(stack_builder);
240
241                let integration = ApiGatewayV2Integration {
242                    id: route_integration_id,
243                    resource_id: integration_resource_id.clone(),
244                    r#type: "AWS::ApiGatewayV2::Integration".to_string(),
245                    properties: ApiGatewayV2IntegrationProperties {
246                        api_id: get_ref(&api_resource_id),
247                        integration_type: "AWS_PROXY".to_string(),
248                        payload_format_version: if &protocol_type == "HTTP" { Some("2.0".to_string()) } else { Some("1.0".to_string()) },
249                        integration_uri: Some(get_arn(&info.resource_id)),
250                        // TODO allow passing these
251                        content_handling_strategy: None, // only for websocket
252                        integration_method: None, // only for websocket - set to post for lambda integration
253                        passthrough_behavior: None,
254                        request_parameters: None,
255                        request_templates: None,
256                        response_parameters: None,
257                        timeout_in_millis: None,
258                    },
259                };
260                stack_builder.add_resource(integration);
261
262                let route_key = if let Some(method) = info.method {
263                    let method: String = method.into();
264                    format!("{} {}", method, info.path)
265                } else {
266                    info.path
267                };
268
269                let route = ApiGatewayV2Route {
270                    id: route_id,
271                    resource_id: route_resource_id.clone(),
272                    r#type: "AWS::ApiGatewayV2::Route".to_string(),
273                    properties: ApiGatewayV2RouteProperties {
274                        api_id: get_ref(&api_resource_id),
275                        route_key,
276                        target: Some(join(
277                            "",
278                            vec![Value::String("integrations/".to_string()), get_ref(&integration_resource_id)],
279                        )),
280                    },
281                };
282                stack_builder.add_resource(route);
283            });
284
285        stack_builder.add_resource(ApiGatewayV2Stage {
286            id: stage_id,
287            resource_id: stage_resource_id.clone(),
288            r#type: "AWS::ApiGatewayV2::Stage".to_string(),
289            properties: ApiGatewayV2StageProperties {
290                api_id: get_ref(&api_resource_id),
291                stage_name: if &protocol_type == "HTTP" { "$default".to_string() } else { "prod".to_string() }, // in the future, expose this
292                auto_deploy: true,
293                default_route_settings: None,
294                route_settings: None,
295            },
296        });
297
298        stack_builder.add_resource(ApiGatewayV2Api {
299            id: self.id,
300            resource_id: api_resource_id.clone(),
301            r#type: "AWS::ApiGatewayV2::Api".to_string(),
302            properties: ApiGatewayV2ApiProperties {
303                name: self.name,
304                protocol_type,
305                disable_execute_api_endpoint: self.disable_execute_api_endpoint,
306                disable_schema_validation: self.disable_schema_validation,
307                cors_configuration: self.cors_configuration,
308                route_selection_expression: self.route_selection_expression,
309            },
310        });
311
312        let stage = ApiGatewayV2StageRef::internal_new(stage_resource_id);
313        let api = ApiGatewayV2ApiRef::internal_new(api_resource_id);
314
315        (api, stage)
316    }
317}
318
/// Builder for the [`CorsConfiguration`] of an HTTP API.
///
/// Every setting starts unset (`None`); only the settings that are explicitly configured are
/// passed on to the resulting [`CorsConfiguration`].
#[derive(Debug, Clone)]
pub struct CorsConfigurationBuilder {
    allow_credentials: Option<bool>,
    allow_headers: Option<Vec<String>>,
    // Stored as strings, converted from `HttpMethod` by `allow_methods`.
    allow_methods: Option<Vec<String>>,
    allow_origins: Option<Vec<String>>,
    expose_headers: Option<Vec<String>>,
    // Whole seconds, converted from a `Duration` by `max_age`.
    max_age: Option<u64>,
}
327
328impl Default for CorsConfigurationBuilder {
329    fn default() -> Self {
330        Self::new()
331    }
332}
333
334impl CorsConfigurationBuilder {
335    pub fn new() -> Self {
336        Self {
337            allow_credentials: None,
338            allow_headers: None,
339            allow_methods: None,
340            allow_origins: None,
341            expose_headers: None,
342            max_age: None,
343        }
344    }
345
346    pub fn allow_credentials(self, allow: bool) -> Self {
347        Self {
348            allow_credentials: Some(allow),
349            ..self
350        }
351    }
352
353    pub fn allow_headers(self, headers: Vec<String>) -> Self {
354        Self {
355            allow_headers: Some(headers),
356            ..self
357        }
358    }
359
360    pub fn allow_methods(self, methods: Vec<HttpMethod>) -> Self {
361        Self {
362            allow_methods: Some(methods.into_iter().map(Into::into).collect()),
363            ..self
364        }
365    }
366
367    pub fn allow_origins(self, origins: Vec<String>) -> Self {
368        Self {
369            allow_origins: Some(origins),
370            ..self
371        }
372    }
373
374    pub fn expose_headers(self, headers: Vec<String>) -> Self {
375        Self {
376            expose_headers: Some(headers),
377            ..self
378        }
379    }
380
381    pub fn max_age(self, age: Duration) -> Self {
382        Self {
383            max_age: Some(age.as_secs()),
384            ..self
385        }
386    }
387
388    #[must_use]
389    pub fn build(self) -> CorsConfiguration {
390        CorsConfiguration {
391            allow_credentials: self.allow_credentials,
392            allow_headers: self.allow_headers,
393            allow_methods: self.allow_methods,
394            allow_origins: self.allow_origins,
395            expose_headers: self.expose_headers,
396            max_age: self.max_age,
397        }
398    }
399}