Skip to main content

rusty_cdk_core/apigateway/
builder.rs

1use std::marker::PhantomData;
2use crate::apigateway::{ApiGatewayV2Api, ApiGatewayV2ApiProperties, ApiGatewayV2ApiRef, ApiGatewayV2ApiType, ApiGatewayV2Integration, ApiGatewayV2IntegrationProperties, ApiGatewayV2IntegrationType, ApiGatewayV2Route, ApiGatewayV2RouteProperties, ApiGatewayV2RouteType, ApiGatewayV2Stage, ApiGatewayV2StageProperties, ApiGatewayV2StageRef, ApiGatewayV2StageType, CorsConfiguration};
3use crate::intrinsic::{get_arn, get_ref, join, AWS_ACCOUNT_PSEUDO_PARAM, AWS_PARTITION_PSEUDO_PARAM, AWS_REGION_PSEUDO_PARAM};
4use crate::lambda::{FunctionRef, PermissionBuilder};
5use crate::shared::HttpMethod;
6use crate::shared::Id;
7use crate::stack::{Resource, StackBuilder};
8use serde_json::Value;
9use std::time::Duration;
10use crate::type_state;
11use crate::wrappers::LambdaPermissionAction;
12
// TODO: authorizers and API keys are not supported yet.
14
15struct RouteInfo {
16    lambda_id: Id,
17    path: String,
18    method: Option<HttpMethod>,
19    resource_id: String,
20}
21
22type_state!(
23    ApiGatewayV2APIState,
24    StartState,
25    HttpState,
26    WebsocketState,
27);
28
29/// Builder for API Gateway V2 HTTP APIs.
30///
31/// See https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-apigatewayv2-api.html
32/// Creates an HTTP API with routes to Lambda functions. Automatically creates integrations and permissions for each route.
33///
34/// # Example
35///
36/// ```rust,no_run
37/// use rusty_cdk_core::stack::StackBuilder;
38/// use rusty_cdk_core::apigateway::ApiGatewayV2Builder;
39/// use rusty_cdk_core::shared::HttpMethod;
40/// use rusty_cdk_core::lambda::{FunctionBuilder, Architecture, Runtime, Zip};
41/// use rusty_cdk_core::wrappers::*;
42/// use rusty_cdk_macros::{memory, timeout, zip_file};
43///
44/// let mut stack_builder = StackBuilder::new();
45///
46/// let function = unimplemented!("create a function");
47///
48/// let (api, stage) = ApiGatewayV2Builder::new("my-api", "MyHttpApi")
49///     .http()
50///     .add_route_lambda("/hello", HttpMethod::Get, &function)
51///     .add_route_lambda("/world", HttpMethod::Post, &function)
52///     .build(&mut stack_builder);
53/// ```
54pub struct ApiGatewayV2Builder<T: ApiGatewayV2APIState> {
55    phantom_data: PhantomData<T>,
56    id: Id,
57    name: Option<String>,
58    protocol_type: Option<String>,
59    disable_execute_api_endpoint: Option<bool>,
60    disable_schema_validation: Option<bool>,
61    cors_configuration: Option<CorsConfiguration>,
62    route_info: Vec<RouteInfo>,
63    route_selection_expression: Option<String>,
64}
65
66impl ApiGatewayV2Builder<StartState> {
67    /// Creates a new API Gateway V2 HTTP API builder.
68    ///
69    /// # Arguments
70    /// * `id` - Unique identifier for the API Gateway
71    /// * `name` - Name of the API Gateway
72    pub fn new<T: Into<String>>(id: &str, name: T) -> ApiGatewayV2Builder<StartState> {
73        Self {
74            phantom_data: Default::default(),
75            id: Id(id.to_string()),
76            name: Some(name.into()), // name is required when not OpenAPI (so currently always)
77            protocol_type: None,
78            disable_execute_api_endpoint: None,
79            disable_schema_validation: None,
80            cors_configuration: None,
81            route_selection_expression: None,
82            route_info: vec![],
83        }
84    }
85
86    pub fn http(self) -> ApiGatewayV2Builder<HttpState> {
87        ApiGatewayV2Builder {
88            phantom_data: Default::default(),
89            id: self.id,
90            name: self.name,
91            protocol_type: Some("HTTP".to_string()),
92            disable_execute_api_endpoint: self.disable_execute_api_endpoint,
93            cors_configuration: self.cors_configuration,
94            route_info: self.route_info,
95            disable_schema_validation: None,
96            route_selection_expression: None,
97        }
98    }
99
100    pub fn websocket<T: Into<String>>(self, route_selection_expression: T) -> ApiGatewayV2Builder<WebsocketState> {
101        ApiGatewayV2Builder {
102            phantom_data: Default::default(),
103            id: self.id,
104            name: self.name,
105            protocol_type: Some("WEBSOCKET".to_string()),
106            route_selection_expression: Some(route_selection_expression.into()),
107            disable_execute_api_endpoint: self.disable_execute_api_endpoint,
108            route_info: self.route_info,
109            disable_schema_validation: self.disable_schema_validation,
110            cors_configuration: None,
111        }
112    }
113}
114
115impl ApiGatewayV2Builder<HttpState> {
116    pub fn cors_configuration(self, config: CorsConfiguration) -> Self {
117        Self {
118            cors_configuration: Some(config),
119            ..self
120        }
121    }
122
123    /// Adds a route for a specific HTTP method and path.
124    ///
125    /// Automatically creates the integration and Lambda permission.
126    pub fn add_route_lambda<T: Into<String>>(mut self, path: T, method: HttpMethod, lambda: &FunctionRef) -> Self {
127        let path = path.into();
128        let path = if path.starts_with("/") { path } else { format!("/{}", path) };
129
130        self.route_info.push(RouteInfo {
131            lambda_id: lambda.get_id().clone(),
132            path,
133            method: Some(method),
134            resource_id: lambda.get_resource_id().to_string(),
135        });
136        Self { ..self }
137    }
138
139    pub fn build(self, stack_builder: &mut StackBuilder) -> (
140        ApiGatewayV2ApiRef,
141        ApiGatewayV2StageRef,
142    ) {
143        self.build_internal(stack_builder)
144    }
145}
146
147impl ApiGatewayV2Builder<WebsocketState> {
148    pub fn disable_schema_validation(self, disable: bool) -> Self {
149        Self {
150            disable_schema_validation: Some(disable),
151            ..self
152        }
153    }
154
155    /// Adds a route for a specific route key.
156    ///
157    /// Automatically creates the integration and Lambda permission.
158    pub fn add_route_lambda<T: Into<String>>(mut self, route_key: T, lambda: &FunctionRef) -> Self {
159        self.route_info.push(RouteInfo {
160            lambda_id: lambda.get_id().clone(),
161            path: route_key.into(),
162            method: None,
163            resource_id: lambda.get_resource_id().to_string(),
164        });
165        Self { ..self }
166    }
167    
168    pub fn build(self, stack_builder: &mut StackBuilder) -> (
169        ApiGatewayV2ApiRef,
170        ApiGatewayV2StageRef,
171    ) {
172        self.build_internal(stack_builder)
173    }
174}
175
176impl<T: ApiGatewayV2APIState> ApiGatewayV2Builder<T> {
177    pub fn disable_execute_api_endpoint(self, disable_api_endpoint: bool) -> Self {
178        Self {
179            disable_execute_api_endpoint: Some(disable_api_endpoint),
180            ..self
181        }
182    }
183
184    /// Adds a default route that catches all requests not matching other routes.
185    ///
186    /// Automatically creates the integration and Lambda permission.
187    pub fn add_default_route_lambda(mut self, lambda: &FunctionRef) -> Self {
188        self.route_info.push(RouteInfo {
189            lambda_id: lambda.get_id().clone(),
190            path: "$default".to_string(),
191            method: None,
192            resource_id: lambda.get_resource_id().to_string(),
193        });
194        Self { ..self }
195    }
196
197    fn build_internal(
198        self, stack_builder: &mut StackBuilder
199    ) -> (
200        ApiGatewayV2ApiRef,
201        ApiGatewayV2StageRef,
202    ) {
203        let api_resource_id = Resource::generate_id("HttpApiGateway");
204        let stage_resource_id = Resource::generate_id("HttpApiStage");
205        let stage_id = Id::generate_id(&self.id, "Stage");
206
207        let protocol_type = self.protocol_type.expect("protocol type should be present, enforced by builder");
208
209        self
210            .route_info
211            .into_iter()
212            .for_each(|info| {
213                let route_id = Id::combine_with_resource_id(&self.id, &info.lambda_id);
214                let route_permission_id = Id::generate_id(&self.id, "Permission");
215                let route_integration_id = Id::generate_id(&self.id, "Integration");
216
217                let integration_resource_id = Resource::generate_id("HttpApiIntegration");
218                let route_resource_id = Resource::generate_id("HttpApiRoute");
219
220                PermissionBuilder::new(
221                    &route_permission_id,
222                    LambdaPermissionAction("lambda:InvokeFunction".to_string()),
223                    get_arn(&info.resource_id),
224                    "apigateway.amazonaws.com".to_string(),
225                )
226                    .source_arn(join(
227                        "",
228                        vec![
229                            Value::String("arn:".to_string()),
230                            get_ref(AWS_PARTITION_PSEUDO_PARAM),
231                            Value::String(":execute-api:".to_string()),
232                            get_ref(AWS_REGION_PSEUDO_PARAM),
233                            Value::String(":".to_string()),
234                            get_ref(AWS_ACCOUNT_PSEUDO_PARAM),
235                            Value::String(":".to_string()),
236                            get_ref(&api_resource_id),
237                            Value::String(format!("*/*{}", info.path)),
238                        ],
239                    ))
240                    .build(stack_builder);
241
242                let integration = ApiGatewayV2Integration {
243                    id: route_integration_id,
244                    resource_id: integration_resource_id.clone(),
245                    r#type: ApiGatewayV2IntegrationType::ApiGatewayV2IntegrationType,
246                    properties: ApiGatewayV2IntegrationProperties {
247                        api_id: get_ref(&api_resource_id),
248                        integration_type: "AWS_PROXY".to_string(),
249                        payload_format_version: if &protocol_type == "HTTP" { Some("2.0".to_string()) } else { Some("1.0".to_string()) },
250                        integration_uri: Some(get_arn(&info.resource_id)),
251                        // TODO allow passing these
252                        content_handling_strategy: None, // only for websocket
253                        integration_method: None, // only for websocket - set to post for lambda integration
254                        passthrough_behavior: None,
255                        request_parameters: None,
256                        request_templates: None,
257                        response_parameters: None,
258                        timeout_in_millis: None,
259                    },
260                };
261                stack_builder.add_resource(integration);
262
263                let route_key = if let Some(method) = info.method {
264                    let method: String = method.into();
265                    format!("{} {}", method, info.path)
266                } else {
267                    info.path
268                };
269
270                let route = ApiGatewayV2Route {
271                    id: route_id,
272                    resource_id: route_resource_id.clone(),
273                    r#type: ApiGatewayV2RouteType::ApiGatewayV2RouteType,
274                    properties: ApiGatewayV2RouteProperties {
275                        api_id: get_ref(&api_resource_id),
276                        route_key,
277                        target: Some(join(
278                            "",
279                            vec![Value::String("integrations/".to_string()), get_ref(&integration_resource_id)],
280                        )),
281                    },
282                };
283                stack_builder.add_resource(route);
284            });
285
286        stack_builder.add_resource(ApiGatewayV2Stage {
287            id: stage_id,
288            resource_id: stage_resource_id.clone(),
289            r#type: ApiGatewayV2StageType::ApiGatewayV2StageType,
290            properties: ApiGatewayV2StageProperties {
291                api_id: get_ref(&api_resource_id),
292                stage_name: if &protocol_type == "HTTP" { "$default".to_string() } else { "prod".to_string() }, // in the future, expose this
293                auto_deploy: true,
294                default_route_settings: None,
295                route_settings: None,
296            },
297        });
298
299        stack_builder.add_resource(ApiGatewayV2Api {
300            id: self.id,
301            resource_id: api_resource_id.clone(),
302            r#type: ApiGatewayV2ApiType::ApiGatewayV2ApiType,
303            properties: ApiGatewayV2ApiProperties {
304                name: self.name,
305                protocol_type,
306                disable_execute_api_endpoint: self.disable_execute_api_endpoint,
307                disable_schema_validation: self.disable_schema_validation,
308                cors_configuration: self.cors_configuration,
309                route_selection_expression: self.route_selection_expression,
310            },
311        });
312
313        let stage = ApiGatewayV2StageRef::internal_new(stage_resource_id);
314        let api = ApiGatewayV2ApiRef::internal_new(api_resource_id);
315
316        (api, stage)
317    }
318}
319
/// Builder for the [`CorsConfiguration`] of an HTTP API.
///
/// Every setting starts as `None` and stays that way unless its setter is called.
pub struct CorsConfigurationBuilder {
    /// Value for `allow_credentials`.
    allow_credentials: Option<bool>,
    /// Headers for `allow_headers`.
    allow_headers: Option<Vec<String>>,
    /// HTTP method names (already converted to strings) for `allow_methods`.
    allow_methods: Option<Vec<String>>,
    /// Origins for `allow_origins`.
    allow_origins: Option<Vec<String>>,
    /// Headers for `expose_headers`.
    expose_headers: Option<Vec<String>>,
    /// Max age in whole seconds (the setter takes a `Duration` and truncates it).
    max_age: Option<u64>,
}
328
329impl Default for CorsConfigurationBuilder {
330    fn default() -> Self {
331        Self::new()
332    }
333}
334
335impl CorsConfigurationBuilder {
336    pub fn new() -> Self {
337        Self {
338            allow_credentials: None,
339            allow_headers: None,
340            allow_methods: None,
341            allow_origins: None,
342            expose_headers: None,
343            max_age: None,
344        }
345    }
346
347    pub fn allow_credentials(self, allow: bool) -> Self {
348        Self {
349            allow_credentials: Some(allow),
350            ..self
351        }
352    }
353
354    pub fn allow_headers(self, headers: Vec<String>) -> Self {
355        Self {
356            allow_headers: Some(headers),
357            ..self
358        }
359    }
360
361    pub fn allow_methods(self, methods: Vec<HttpMethod>) -> Self {
362        Self {
363            allow_methods: Some(methods.into_iter().map(Into::into).collect()),
364            ..self
365        }
366    }
367
368    pub fn allow_origins(self, origins: Vec<String>) -> Self {
369        Self {
370            allow_origins: Some(origins),
371            ..self
372        }
373    }
374
375    pub fn expose_headers(self, headers: Vec<String>) -> Self {
376        Self {
377            expose_headers: Some(headers),
378            ..self
379        }
380    }
381
382    pub fn max_age(self, age: Duration) -> Self {
383        Self {
384            max_age: Some(age.as_secs()),
385            ..self
386        }
387    }
388
389    #[must_use]
390    pub fn build(self) -> CorsConfiguration {
391        CorsConfiguration {
392            allow_credentials: self.allow_credentials,
393            allow_headers: self.allow_headers,
394            allow_methods: self.allow_methods,
395            allow_origins: self.allow_origins,
396            expose_headers: self.expose_headers,
397            max_age: self.max_age,
398        }
399    }
400}