1use std::marker::PhantomData;
2use crate::apigateway::{ApiGatewayV2Api, ApiGatewayV2ApiProperties, ApiGatewayV2ApiRef, ApiGatewayV2ApiType, ApiGatewayV2Integration, ApiGatewayV2IntegrationProperties, ApiGatewayV2IntegrationType, ApiGatewayV2Route, ApiGatewayV2RouteProperties, ApiGatewayV2RouteType, ApiGatewayV2Stage, ApiGatewayV2StageProperties, ApiGatewayV2StageRef, ApiGatewayV2StageType, CorsConfiguration};
3use crate::intrinsic::{get_arn, get_ref, join, AWS_ACCOUNT_PSEUDO_PARAM, AWS_PARTITION_PSEUDO_PARAM, AWS_REGION_PSEUDO_PARAM};
4use crate::lambda::{FunctionRef, PermissionBuilder};
5use crate::shared::HttpMethod;
6use crate::shared::Id;
7use crate::stack::{Resource, StackBuilder};
8use serde_json::Value;
9use std::time::Duration;
10use crate::type_state;
11use crate::wrappers::LambdaPermissionAction;
12
13struct RouteInfo {
16 lambda_id: Id,
17 path: String,
18 method: Option<HttpMethod>,
19 resource_id: String,
20}
21
// Type-state markers for `ApiGatewayV2Builder`. `ApiGatewayV2APIState` is the trait bound
// on the builder's state parameter, and `StartState`, `HttpState` and `WebsocketState` are
// the states it moves through (`new` -> `http()` / `websocket()`). The exact items this
// expands to are defined by `crate::type_state`, which is not visible here.
type_state!(
    ApiGatewayV2APIState,
    StartState,
    HttpState,
    WebsocketState,
);
28
29pub struct ApiGatewayV2Builder<T: ApiGatewayV2APIState> {
55 phantom_data: PhantomData<T>,
56 id: Id,
57 name: Option<String>,
58 protocol_type: Option<String>,
59 disable_execute_api_endpoint: Option<bool>,
60 disable_schema_validation: Option<bool>,
61 cors_configuration: Option<CorsConfiguration>,
62 route_info: Vec<RouteInfo>,
63 route_selection_expression: Option<String>,
64}
65
66impl ApiGatewayV2Builder<StartState> {
67 pub fn new<T: Into<String>>(id: &str, name: T) -> ApiGatewayV2Builder<StartState> {
73 Self {
74 phantom_data: Default::default(),
75 id: Id(id.to_string()),
76 name: Some(name.into()), protocol_type: None,
78 disable_execute_api_endpoint: None,
79 disable_schema_validation: None,
80 cors_configuration: None,
81 route_selection_expression: None,
82 route_info: vec![],
83 }
84 }
85
86 pub fn http(self) -> ApiGatewayV2Builder<HttpState> {
87 ApiGatewayV2Builder {
88 phantom_data: Default::default(),
89 id: self.id,
90 name: self.name,
91 protocol_type: Some("HTTP".to_string()),
92 disable_execute_api_endpoint: self.disable_execute_api_endpoint,
93 cors_configuration: self.cors_configuration,
94 route_info: self.route_info,
95 disable_schema_validation: None,
96 route_selection_expression: None,
97 }
98 }
99
100 pub fn websocket<T: Into<String>>(self, route_selection_expression: T) -> ApiGatewayV2Builder<WebsocketState> {
101 ApiGatewayV2Builder {
102 phantom_data: Default::default(),
103 id: self.id,
104 name: self.name,
105 protocol_type: Some("WEBSOCKET".to_string()),
106 route_selection_expression: Some(route_selection_expression.into()),
107 disable_execute_api_endpoint: self.disable_execute_api_endpoint,
108 route_info: self.route_info,
109 disable_schema_validation: self.disable_schema_validation,
110 cors_configuration: None,
111 }
112 }
113}
114
115impl ApiGatewayV2Builder<HttpState> {
116 pub fn cors_configuration(self, config: CorsConfiguration) -> Self {
117 Self {
118 cors_configuration: Some(config),
119 ..self
120 }
121 }
122
123 pub fn add_route_lambda<T: Into<String>>(mut self, path: T, method: HttpMethod, lambda: &FunctionRef) -> Self {
127 let path = path.into();
128 let path = if path.starts_with("/") { path } else { format!("/{}", path) };
129
130 self.route_info.push(RouteInfo {
131 lambda_id: lambda.get_id().clone(),
132 path,
133 method: Some(method),
134 resource_id: lambda.get_resource_id().to_string(),
135 });
136 Self { ..self }
137 }
138
139 pub fn build(self, stack_builder: &mut StackBuilder) -> (
140 ApiGatewayV2ApiRef,
141 ApiGatewayV2StageRef,
142 ) {
143 self.build_internal(stack_builder)
144 }
145}
146
147impl ApiGatewayV2Builder<WebsocketState> {
148 pub fn disable_schema_validation(self, disable: bool) -> Self {
149 Self {
150 disable_schema_validation: Some(disable),
151 ..self
152 }
153 }
154
155 pub fn add_route_lambda<T: Into<String>>(mut self, route_key: T, lambda: &FunctionRef) -> Self {
159 self.route_info.push(RouteInfo {
160 lambda_id: lambda.get_id().clone(),
161 path: route_key.into(),
162 method: None,
163 resource_id: lambda.get_resource_id().to_string(),
164 });
165 Self { ..self }
166 }
167
168 pub fn build(self, stack_builder: &mut StackBuilder) -> (
169 ApiGatewayV2ApiRef,
170 ApiGatewayV2StageRef,
171 ) {
172 self.build_internal(stack_builder)
173 }
174}
175
176impl<T: ApiGatewayV2APIState> ApiGatewayV2Builder<T> {
177 pub fn disable_execute_api_endpoint(self, disable_api_endpoint: bool) -> Self {
178 Self {
179 disable_execute_api_endpoint: Some(disable_api_endpoint),
180 ..self
181 }
182 }
183
184 pub fn add_default_route_lambda(mut self, lambda: &FunctionRef) -> Self {
188 self.route_info.push(RouteInfo {
189 lambda_id: lambda.get_id().clone(),
190 path: "$default".to_string(),
191 method: None,
192 resource_id: lambda.get_resource_id().to_string(),
193 });
194 Self { ..self }
195 }
196
197 fn build_internal(
198 self, stack_builder: &mut StackBuilder
199 ) -> (
200 ApiGatewayV2ApiRef,
201 ApiGatewayV2StageRef,
202 ) {
203 let api_resource_id = Resource::generate_id("HttpApiGateway");
204 let stage_resource_id = Resource::generate_id("HttpApiStage");
205 let stage_id = Id::generate_id(&self.id, "Stage");
206
207 let protocol_type = self.protocol_type.expect("protocol type should be present, enforced by builder");
208
209 self
210 .route_info
211 .into_iter()
212 .for_each(|info| {
213 let route_id = Id::combine_with_resource_id(&self.id, &info.lambda_id);
214 let route_permission_id = Id::generate_id(&self.id, "Permission");
215 let route_integration_id = Id::generate_id(&self.id, "Integration");
216
217 let integration_resource_id = Resource::generate_id("HttpApiIntegration");
218 let route_resource_id = Resource::generate_id("HttpApiRoute");
219
220 PermissionBuilder::new(
221 &route_permission_id,
222 LambdaPermissionAction("lambda:InvokeFunction".to_string()),
223 get_arn(&info.resource_id),
224 "apigateway.amazonaws.com".to_string(),
225 )
226 .source_arn(join(
227 "",
228 vec![
229 Value::String("arn:".to_string()),
230 get_ref(AWS_PARTITION_PSEUDO_PARAM),
231 Value::String(":execute-api:".to_string()),
232 get_ref(AWS_REGION_PSEUDO_PARAM),
233 Value::String(":".to_string()),
234 get_ref(AWS_ACCOUNT_PSEUDO_PARAM),
235 Value::String(":".to_string()),
236 get_ref(&api_resource_id),
237 Value::String(format!("*/*{}", info.path)),
238 ],
239 ))
240 .build(stack_builder);
241
242 let integration = ApiGatewayV2Integration {
243 id: route_integration_id,
244 resource_id: integration_resource_id.clone(),
245 r#type: ApiGatewayV2IntegrationType::ApiGatewayV2IntegrationType,
246 properties: ApiGatewayV2IntegrationProperties {
247 api_id: get_ref(&api_resource_id),
248 integration_type: "AWS_PROXY".to_string(),
249 payload_format_version: if &protocol_type == "HTTP" { Some("2.0".to_string()) } else { Some("1.0".to_string()) },
250 integration_uri: Some(get_arn(&info.resource_id)),
251 content_handling_strategy: None, integration_method: None, passthrough_behavior: None,
255 request_parameters: None,
256 request_templates: None,
257 response_parameters: None,
258 timeout_in_millis: None,
259 },
260 };
261 stack_builder.add_resource(integration);
262
263 let route_key = if let Some(method) = info.method {
264 let method: String = method.into();
265 format!("{} {}", method, info.path)
266 } else {
267 info.path
268 };
269
270 let route = ApiGatewayV2Route {
271 id: route_id,
272 resource_id: route_resource_id.clone(),
273 r#type: ApiGatewayV2RouteType::ApiGatewayV2RouteType,
274 properties: ApiGatewayV2RouteProperties {
275 api_id: get_ref(&api_resource_id),
276 route_key,
277 target: Some(join(
278 "",
279 vec![Value::String("integrations/".to_string()), get_ref(&integration_resource_id)],
280 )),
281 },
282 };
283 stack_builder.add_resource(route);
284 });
285
286 stack_builder.add_resource(ApiGatewayV2Stage {
287 id: stage_id,
288 resource_id: stage_resource_id.clone(),
289 r#type: ApiGatewayV2StageType::ApiGatewayV2StageType,
290 properties: ApiGatewayV2StageProperties {
291 api_id: get_ref(&api_resource_id),
292 stage_name: if &protocol_type == "HTTP" { "$default".to_string() } else { "prod".to_string() }, auto_deploy: true,
294 default_route_settings: None,
295 route_settings: None,
296 },
297 });
298
299 stack_builder.add_resource(ApiGatewayV2Api {
300 id: self.id,
301 resource_id: api_resource_id.clone(),
302 r#type: ApiGatewayV2ApiType::ApiGatewayV2ApiType,
303 properties: ApiGatewayV2ApiProperties {
304 name: self.name,
305 protocol_type,
306 disable_execute_api_endpoint: self.disable_execute_api_endpoint,
307 disable_schema_validation: self.disable_schema_validation,
308 cors_configuration: self.cors_configuration,
309 route_selection_expression: self.route_selection_expression,
310 },
311 });
312
313 let stage = ApiGatewayV2StageRef::internal_new(stage_resource_id);
314 let api = ApiGatewayV2ApiRef::internal_new(api_resource_id);
315
316 (api, stage)
317 }
318}
319
/// Builder for the CORS settings of an HTTP API.
///
/// Every setting starts unset (`None`); `build` copies them into a `CorsConfiguration`.
pub struct CorsConfigurationBuilder {
    /// Allowed origins.
    allow_origins: Option<Vec<String>>,
    /// Allowed methods, stored as the string form of each `HttpMethod`.
    allow_methods: Option<Vec<String>>,
    /// Allowed request headers.
    allow_headers: Option<Vec<String>>,
    /// Response headers exposed to the client.
    expose_headers: Option<Vec<String>>,
    /// Whether credentials are allowed.
    allow_credentials: Option<bool>,
    /// Preflight cache duration in whole seconds.
    max_age: Option<u64>,
}
328
329impl Default for CorsConfigurationBuilder {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
335impl CorsConfigurationBuilder {
336 pub fn new() -> Self {
337 Self {
338 allow_credentials: None,
339 allow_headers: None,
340 allow_methods: None,
341 allow_origins: None,
342 expose_headers: None,
343 max_age: None,
344 }
345 }
346
347 pub fn allow_credentials(self, allow: bool) -> Self {
348 Self {
349 allow_credentials: Some(allow),
350 ..self
351 }
352 }
353
354 pub fn allow_headers(self, headers: Vec<String>) -> Self {
355 Self {
356 allow_headers: Some(headers),
357 ..self
358 }
359 }
360
361 pub fn allow_methods(self, methods: Vec<HttpMethod>) -> Self {
362 Self {
363 allow_methods: Some(methods.into_iter().map(Into::into).collect()),
364 ..self
365 }
366 }
367
368 pub fn allow_origins(self, origins: Vec<String>) -> Self {
369 Self {
370 allow_origins: Some(origins),
371 ..self
372 }
373 }
374
375 pub fn expose_headers(self, headers: Vec<String>) -> Self {
376 Self {
377 expose_headers: Some(headers),
378 ..self
379 }
380 }
381
382 pub fn max_age(self, age: Duration) -> Self {
383 Self {
384 max_age: Some(age.as_secs()),
385 ..self
386 }
387 }
388
389 #[must_use]
390 pub fn build(self) -> CorsConfiguration {
391 CorsConfiguration {
392 allow_credentials: self.allow_credentials,
393 allow_headers: self.allow_headers,
394 allow_methods: self.allow_methods,
395 allow_origins: self.allow_origins,
396 expose_headers: self.expose_headers,
397 max_age: self.max_age,
398 }
399 }
400}