1use std::marker::PhantomData;
2use crate::apigateway::{ApiGatewayV2Api, ApiGatewayV2ApiProperties, ApiGatewayV2ApiRef, ApiGatewayV2Integration, ApiGatewayV2IntegrationProperties, ApiGatewayV2Route, ApiGatewayV2RouteProperties, ApiGatewayV2Stage, ApiGatewayV2StageProperties, ApiGatewayV2StageRef, CorsConfiguration};
3use crate::intrinsic::{get_arn, get_ref, join, AWS_ACCOUNT_PSEUDO_PARAM, AWS_PARTITION_PSEUDO_PARAM, AWS_REGION_PSEUDO_PARAM};
4use crate::lambda::{FunctionRef, PermissionBuilder};
5use crate::shared::HttpMethod;
6use crate::shared::Id;
7use crate::stack::{Resource, StackBuilder};
8use serde_json::Value;
9use std::time::Duration;
10use crate::type_state;
11use crate::wrappers::LambdaPermissionAction;
12
13struct RouteInfo {
16 lambda_id: Id,
17 path: String,
18 method: Option<HttpMethod>,
19 resource_id: String,
20}
21
// Type-state setup for `ApiGatewayV2Builder<T>`: `ApiGatewayV2APIState` is used as the bound on
// `T`, and `StartState`, `HttpState` and `WebsocketState` are the states the builder moves
// through (`new` -> `http()` / `websocket()`), so protocol-specific methods are only callable
// once a protocol has been chosen.
// NOTE(review): the exact items generated here depend on the `type_state!` macro, which is
// defined elsewhere in the crate — confirm there.
type_state!(
    ApiGatewayV2APIState,
    StartState,
    HttpState,
    WebsocketState,
);
28
29pub struct ApiGatewayV2Builder<T: ApiGatewayV2APIState> {
54 phantom_data: PhantomData<T>,
55 id: Id,
56 name: Option<String>,
57 protocol_type: Option<String>,
58 disable_execute_api_endpoint: Option<bool>,
59 disable_schema_validation: Option<bool>,
60 cors_configuration: Option<CorsConfiguration>,
61 route_info: Vec<RouteInfo>,
62 route_selection_expression: Option<String>,
63}
64
65impl ApiGatewayV2Builder<StartState> {
66 pub fn new<T: Into<String>>(id: &str, name: T) -> ApiGatewayV2Builder<StartState> {
72 Self {
73 phantom_data: Default::default(),
74 id: Id(id.to_string()),
75 name: Some(name.into()), protocol_type: None,
77 disable_execute_api_endpoint: None,
78 disable_schema_validation: None,
79 cors_configuration: None,
80 route_selection_expression: None,
81 route_info: vec![],
82 }
83 }
84
85 pub fn http(self) -> ApiGatewayV2Builder<HttpState> {
86 ApiGatewayV2Builder {
87 phantom_data: Default::default(),
88 id: self.id,
89 name: self.name,
90 protocol_type: Some("HTTP".to_string()),
91 disable_execute_api_endpoint: self.disable_execute_api_endpoint,
92 cors_configuration: self.cors_configuration,
93 route_info: self.route_info,
94 disable_schema_validation: None,
95 route_selection_expression: None,
96 }
97 }
98
99 pub fn websocket<T: Into<String>>(self, route_selection_expression: T) -> ApiGatewayV2Builder<WebsocketState> {
100 ApiGatewayV2Builder {
101 phantom_data: Default::default(),
102 id: self.id,
103 name: self.name,
104 protocol_type: Some("WEBSOCKET".to_string()),
105 route_selection_expression: Some(route_selection_expression.into()),
106 disable_execute_api_endpoint: self.disable_execute_api_endpoint,
107 route_info: self.route_info,
108 disable_schema_validation: self.disable_schema_validation,
109 cors_configuration: None,
110 }
111 }
112}
113
114impl ApiGatewayV2Builder<HttpState> {
115 pub fn cors_configuration(self, config: CorsConfiguration) -> Self {
116 Self {
117 cors_configuration: Some(config),
118 ..self
119 }
120 }
121
122 pub fn add_route_lambda<T: Into<String>>(mut self, path: T, method: HttpMethod, lambda: &FunctionRef) -> Self {
126 let path = path.into();
127 let path = if path.starts_with("/") { path } else { format!("/{}", path) };
128
129 self.route_info.push(RouteInfo {
130 lambda_id: lambda.get_id().clone(),
131 path,
132 method: Some(method),
133 resource_id: lambda.get_resource_id().to_string(),
134 });
135 Self { ..self }
136 }
137
138 pub fn build(self, stack_builder: &mut StackBuilder) -> (
139 ApiGatewayV2ApiRef,
140 ApiGatewayV2StageRef,
141 ) {
142 self.build_internal(stack_builder)
143 }
144}
145
146impl ApiGatewayV2Builder<WebsocketState> {
147 pub fn disable_schema_validation(self, disable: bool) -> Self {
148 Self {
149 disable_schema_validation: Some(disable),
150 ..self
151 }
152 }
153
154 pub fn add_route_lambda<T: Into<String>>(mut self, route_key: T, lambda: &FunctionRef) -> Self {
158 self.route_info.push(RouteInfo {
159 lambda_id: lambda.get_id().clone(),
160 path: route_key.into(),
161 method: None,
162 resource_id: lambda.get_resource_id().to_string(),
163 });
164 Self { ..self }
165 }
166
167 pub fn build(self, stack_builder: &mut StackBuilder) -> (
168 ApiGatewayV2ApiRef,
169 ApiGatewayV2StageRef,
170 ) {
171 self.build_internal(stack_builder)
172 }
173}
174
175impl<T: ApiGatewayV2APIState> ApiGatewayV2Builder<T> {
176 pub fn disable_execute_api_endpoint(self, disable_api_endpoint: bool) -> Self {
177 Self {
178 disable_execute_api_endpoint: Some(disable_api_endpoint),
179 ..self
180 }
181 }
182
183 pub fn add_default_route_lambda(mut self, lambda: &FunctionRef) -> Self {
187 self.route_info.push(RouteInfo {
188 lambda_id: lambda.get_id().clone(),
189 path: "$default".to_string(),
190 method: None,
191 resource_id: lambda.get_resource_id().to_string(),
192 });
193 Self { ..self }
194 }
195
196 fn build_internal(
197 self, stack_builder: &mut StackBuilder
198 ) -> (
199 ApiGatewayV2ApiRef,
200 ApiGatewayV2StageRef,
201 ) {
202 let api_resource_id = Resource::generate_id("HttpApiGateway");
203 let stage_resource_id = Resource::generate_id("HttpApiStage");
204 let stage_id = Id::generate_id(&self.id, "Stage");
205
206 let protocol_type = self.protocol_type.expect("protocol type should be present, enforced by builder");
207
208 self
209 .route_info
210 .into_iter()
211 .for_each(|info| {
212 let route_id = Id::combine_with_resource_id(&self.id, &info.lambda_id);
213 let route_permission_id = Id::generate_id(&self.id, "Permission");
214 let route_integration_id = Id::generate_id(&self.id, "Integration");
215
216 let integration_resource_id = Resource::generate_id("HttpApiIntegration");
217 let route_resource_id = Resource::generate_id("HttpApiRoute");
218
219 PermissionBuilder::new(
220 &route_permission_id,
221 LambdaPermissionAction("lambda:InvokeFunction".to_string()),
222 get_arn(&info.resource_id),
223 "apigateway.amazonaws.com".to_string(),
224 )
225 .source_arn(join(
226 "",
227 vec![
228 Value::String("arn:".to_string()),
229 get_ref(AWS_PARTITION_PSEUDO_PARAM),
230 Value::String(":execute-api:".to_string()),
231 get_ref(AWS_REGION_PSEUDO_PARAM),
232 Value::String(":".to_string()),
233 get_ref(AWS_ACCOUNT_PSEUDO_PARAM),
234 Value::String(":".to_string()),
235 get_ref(&api_resource_id),
236 Value::String(format!("*/*{}", info.path)),
237 ],
238 ))
239 .build(stack_builder);
240
241 let integration = ApiGatewayV2Integration {
242 id: route_integration_id,
243 resource_id: integration_resource_id.clone(),
244 r#type: "AWS::ApiGatewayV2::Integration".to_string(),
245 properties: ApiGatewayV2IntegrationProperties {
246 api_id: get_ref(&api_resource_id),
247 integration_type: "AWS_PROXY".to_string(),
248 payload_format_version: if &protocol_type == "HTTP" { Some("2.0".to_string()) } else { Some("1.0".to_string()) },
249 integration_uri: Some(get_arn(&info.resource_id)),
250 content_handling_strategy: None, integration_method: None, passthrough_behavior: None,
254 request_parameters: None,
255 request_templates: None,
256 response_parameters: None,
257 timeout_in_millis: None,
258 },
259 };
260 stack_builder.add_resource(integration);
261
262 let route_key = if let Some(method) = info.method {
263 let method: String = method.into();
264 format!("{} {}", method, info.path)
265 } else {
266 info.path
267 };
268
269 let route = ApiGatewayV2Route {
270 id: route_id,
271 resource_id: route_resource_id.clone(),
272 r#type: "AWS::ApiGatewayV2::Route".to_string(),
273 properties: ApiGatewayV2RouteProperties {
274 api_id: get_ref(&api_resource_id),
275 route_key,
276 target: Some(join(
277 "",
278 vec![Value::String("integrations/".to_string()), get_ref(&integration_resource_id)],
279 )),
280 },
281 };
282 stack_builder.add_resource(route);
283 });
284
285 stack_builder.add_resource(ApiGatewayV2Stage {
286 id: stage_id,
287 resource_id: stage_resource_id.clone(),
288 r#type: "AWS::ApiGatewayV2::Stage".to_string(),
289 properties: ApiGatewayV2StageProperties {
290 api_id: get_ref(&api_resource_id),
291 stage_name: if &protocol_type == "HTTP" { "$default".to_string() } else { "prod".to_string() }, auto_deploy: true,
293 default_route_settings: None,
294 route_settings: None,
295 },
296 });
297
298 stack_builder.add_resource(ApiGatewayV2Api {
299 id: self.id,
300 resource_id: api_resource_id.clone(),
301 r#type: "AWS::ApiGatewayV2::Api".to_string(),
302 properties: ApiGatewayV2ApiProperties {
303 name: self.name,
304 protocol_type,
305 disable_execute_api_endpoint: self.disable_execute_api_endpoint,
306 disable_schema_validation: self.disable_schema_validation,
307 cors_configuration: self.cors_configuration,
308 route_selection_expression: self.route_selection_expression,
309 },
310 });
311
312 let stage = ApiGatewayV2StageRef::internal_new(stage_resource_id);
313 let api = ApiGatewayV2ApiRef::internal_new(api_resource_id);
314
315 (api, stage)
316 }
317}
318
/// Builder for a [`CorsConfiguration`].
///
/// Every option starts out unset (`None`); `build` copies each field unchanged into the
/// field of the same name on `CorsConfiguration`.
pub struct CorsConfigurationBuilder {
    /// Value for `CorsConfiguration::allow_credentials`.
    allow_credentials: Option<bool>,
    /// Value for `CorsConfiguration::allow_headers`.
    allow_headers: Option<Vec<String>>,
    /// Value for `CorsConfiguration::allow_methods`, already converted to strings.
    allow_methods: Option<Vec<String>>,
    /// Value for `CorsConfiguration::allow_origins`.
    allow_origins: Option<Vec<String>>,
    /// Value for `CorsConfiguration::expose_headers`.
    expose_headers: Option<Vec<String>>,
    /// Value for `CorsConfiguration::max_age`, in whole seconds.
    max_age: Option<u64>,
}
327
328impl Default for CorsConfigurationBuilder {
329 fn default() -> Self {
330 Self::new()
331 }
332}
333
334impl CorsConfigurationBuilder {
335 pub fn new() -> Self {
336 Self {
337 allow_credentials: None,
338 allow_headers: None,
339 allow_methods: None,
340 allow_origins: None,
341 expose_headers: None,
342 max_age: None,
343 }
344 }
345
346 pub fn allow_credentials(self, allow: bool) -> Self {
347 Self {
348 allow_credentials: Some(allow),
349 ..self
350 }
351 }
352
353 pub fn allow_headers(self, headers: Vec<String>) -> Self {
354 Self {
355 allow_headers: Some(headers),
356 ..self
357 }
358 }
359
360 pub fn allow_methods(self, methods: Vec<HttpMethod>) -> Self {
361 Self {
362 allow_methods: Some(methods.into_iter().map(Into::into).collect()),
363 ..self
364 }
365 }
366
367 pub fn allow_origins(self, origins: Vec<String>) -> Self {
368 Self {
369 allow_origins: Some(origins),
370 ..self
371 }
372 }
373
374 pub fn expose_headers(self, headers: Vec<String>) -> Self {
375 Self {
376 expose_headers: Some(headers),
377 ..self
378 }
379 }
380
381 pub fn max_age(self, age: Duration) -> Self {
382 Self {
383 max_age: Some(age.as_secs()),
384 ..self
385 }
386 }
387
388 #[must_use]
389 pub fn build(self) -> CorsConfiguration {
390 CorsConfiguration {
391 allow_credentials: self.allow_credentials,
392 allow_headers: self.allow_headers,
393 allow_methods: self.allow_methods,
394 allow_origins: self.allow_origins,
395 expose_headers: self.expose_headers,
396 max_age: self.max_age,
397 }
398 }
399}