1mod generated;
2
3pub use prost;
5pub use prost_types;
6pub use tonic;
7
8#[cfg(feature = "provisioner")]
9pub mod provisioner {
10 pub use super::generated::provisioner::*;
11
12 #[cfg(feature = "provisioner-client")]
13 pub use super::_provisioner_client::*;
14
15 use shuttle_common::{
16 database::{self, AwsRdsEngine, SharedEngine},
17 DatabaseInfo,
18 };
19
20 impl From<DatabaseResponse> for DatabaseInfo {
21 fn from(response: DatabaseResponse) -> Self {
22 DatabaseInfo::new(
23 response.engine,
24 response.username,
25 response.password,
26 response.database_name,
27 response.port,
28 response.address_private,
29 response.address_public,
30 )
31 }
32 }
33
34 impl From<database::Type> for database_request::DbType {
35 fn from(db_type: database::Type) -> Self {
36 match db_type {
37 database::Type::Shared(engine) => {
38 let engine = match engine {
39 SharedEngine::Postgres => shared::Engine::Postgres(String::new()),
40 SharedEngine::MongoDb => shared::Engine::Mongodb(String::new()),
41 };
42 database_request::DbType::Shared(Shared {
43 engine: Some(engine),
44 })
45 }
46 database::Type::AwsRds(engine) => {
47 let config = RdsConfig {};
48 let engine = match engine {
49 AwsRdsEngine::Postgres => aws_rds::Engine::Postgres(config),
50 AwsRdsEngine::MariaDB => aws_rds::Engine::Mariadb(config),
51 AwsRdsEngine::MySql => aws_rds::Engine::Mysql(config),
52 };
53 database_request::DbType::AwsRds(AwsRds {
54 engine: Some(engine),
55 })
56 }
57 }
58 }
59 }
60
61 impl From<database_request::DbType> for Option<database::Type> {
62 fn from(db_type: database_request::DbType) -> Self {
63 match db_type {
64 database_request::DbType::Shared(Shared {
65 engine: Some(engine),
66 }) => match engine {
67 shared::Engine::Postgres(_) => {
68 Some(database::Type::Shared(SharedEngine::Postgres))
69 }
70 shared::Engine::Mongodb(_) => {
71 Some(database::Type::Shared(SharedEngine::MongoDb))
72 }
73 },
74 database_request::DbType::AwsRds(AwsRds {
75 engine: Some(engine),
76 }) => match engine {
77 aws_rds::Engine::Postgres(_) => {
78 Some(database::Type::AwsRds(AwsRdsEngine::Postgres))
79 }
80 aws_rds::Engine::Mysql(_) => Some(database::Type::AwsRds(AwsRdsEngine::MySql)),
81 aws_rds::Engine::Mariadb(_) => {
82 Some(database::Type::AwsRds(AwsRdsEngine::MariaDB))
83 }
84 },
85 database_request::DbType::Shared(Shared { engine: None })
86 | database_request::DbType::AwsRds(AwsRds { engine: None }) => None,
87 }
88 }
89 }
90
91 impl std::fmt::Display for aws_rds::Engine {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 match self {
94 Self::Mariadb(_) => write!(f, "mariadb"),
95 Self::Mysql(_) => write!(f, "mysql"),
96 Self::Postgres(_) => write!(f, "postgres"),
97 }
98 }
99 }
100}
101
#[cfg(feature = "provisioner-client")]
mod _provisioner_client {
    use super::provisioner::*;

    use http::Uri;

    /// Provisioner gRPC client whose requests pass through the claim layer and the
    /// trace-propagation layer.
    pub type Client = provisioner_client::ProvisionerClient<
        shuttle_common::claims::ClaimService<
            shuttle_common::claims::InjectPropagation<tonic::transport::Channel>,
        >,
    >;

    /// Largest gRPC message, in bytes, the client will encode or decode (50 MiB).
    const MAX_MESSAGE_SIZE: usize = 50 * 1024 * 1024;

    /// Connects to the provisioner at `provisioner_uri` and wraps the channel with claim
    /// handling and trace-context propagation.
    ///
    /// # Panics
    ///
    /// Panics if the connection to the provisioner cannot be established.
    pub async fn get_client(provisioner_uri: Uri) -> Client {
        let endpoint = tonic::transport::Endpoint::from(provisioner_uri);
        let channel = endpoint
            .connect()
            .await
            .expect("failed to connect to provisioner");

        let layered = tower::ServiceBuilder::new()
            .layer(shuttle_common::claims::ClaimLayer)
            .layer(shuttle_common::claims::InjectPropagationLayer)
            .service(channel);

        Client::new(layered)
            .max_decoding_message_size(MAX_MESSAGE_SIZE)
            .max_encoding_message_size(MAX_MESSAGE_SIZE)
    }
}
132
/// gRPC types for the runtime service, re-exported from the generated code.
///
/// With the `runtime-client` feature enabled, the client helpers (`Client`, `get_client`)
/// are re-exported here as well.
#[cfg(feature = "runtime")]
pub mod runtime {
    pub use super::generated::runtime::*;

    #[cfg(feature = "runtime-client")]
    pub use super::_runtime_client::*;
}
140
#[cfg(feature = "runtime-client")]
mod _runtime_client {
    use super::runtime::*;

    use std::time::Duration;

    use anyhow::Context;
    use tonic::transport::Endpoint;
    use tracing::{info, trace};

    /// Runtime gRPC client whose requests carry the propagated tracing context.
    pub type Client = runtime_client::RuntimeClient<
        shuttle_common::claims::InjectPropagation<tonic::transport::Channel>,
    >;

    /// Connects to the runtime's control port at `address`.
    ///
    /// Retries with capped exponential backoff until the port accepts a connection or the
    /// startup deadline (7 s) passes.
    ///
    /// # Errors
    ///
    /// Returns an error if `address` is not a valid endpoint, or if no connection could be
    /// established before the startup deadline.
    #[cfg(feature = "client")]
    pub async fn get_client(address: String) -> anyhow::Result<Client> {
        // Longest a single connection attempt may take.
        const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
        // Total time the runtime gets to open its control port.
        const STARTUP_DEADLINE: Duration = Duration::from_millis(7000);
        // First retry delay; doubled after each failed attempt.
        const INITIAL_BACKOFF: Duration = Duration::from_millis(5);
        // Uncapped doubling reaches a 5 s sleep that overshoots the deadline, so a port
        // opening late in the window would never be retried. Cap the delay to keep polling.
        const MAX_BACKOFF: Duration = Duration::from_millis(500);

        info!("connecting runtime client");
        let conn = Endpoint::new(address)
            .context("creating runtime client endpoint")?
            .connect_timeout(CONNECT_TIMEOUT);

        let channel = tokio::time::timeout(STARTUP_DEADLINE, async move {
            let mut backoff = INITIAL_BACKOFF;
            loop {
                match conn.connect().await {
                    Ok(channel) => break channel,
                    Err(error) => trace!(%error, "waiting for runtime control port to open"),
                }
                tokio::time::sleep(backoff).await;
                backoff = (backoff * 2).min(MAX_BACKOFF);
            }
        })
        .await
        .context("runtime control port did not open in time")?;

        let runtime_service = tower::ServiceBuilder::new()
            .layer(shuttle_common::claims::InjectPropagationLayer)
            .service(channel);

        Ok(Client::new(runtime_service))
    }
}
187
188#[cfg(feature = "resource-recorder")]
189pub mod resource_recorder {
190 pub use super::generated::resource_recorder::*;
191
192 #[cfg(feature = "resource-recorder-client")]
193 pub use super::_resource_recorder_client::*;
194
195 use std::str::FromStr;
196
197 use anyhow::Context;
198
199 impl TryFrom<record_request::Resource> for shuttle_common::resource::Response {
200 type Error = anyhow::Error;
201
202 fn try_from(resource: record_request::Resource) -> Result<Self, Self::Error> {
203 let r#type = shuttle_common::resource::Type::from_str(resource.r#type.as_str())
204 .map_err(anyhow::Error::msg)
205 .context("resource type should have a valid resource string")?;
206 let response = shuttle_common::resource::Response {
207 r#type,
208 config: serde_json::from_slice(&resource.config)
209 .context(format!("{} resource config should be valid JSON", r#type))?,
210 data: serde_json::from_slice(&resource.data)
211 .context(format!("{} resource data should be valid JSON", r#type))?,
212 };
213
214 Ok(response)
215 }
216 }
217
218 impl TryFrom<Resource> for shuttle_common::resource::Response {
219 type Error = anyhow::Error;
220
221 fn try_from(resource: Resource) -> Result<Self, Self::Error> {
222 let r#type = shuttle_common::resource::Type::from_str(resource.r#type.as_str())
223 .map_err(anyhow::Error::msg)
224 .context("resource type should have a valid resource string")?;
225
226 let response = shuttle_common::resource::Response {
227 r#type,
228 config: serde_json::from_slice(&resource.config)
229 .context(format!("{} resource config should be valid JSON", r#type))?,
230 data: serde_json::from_slice(&resource.data)
231 .context(format!("{} resource data should be valid JSON", r#type))?,
232 };
233
234 Ok(response)
235 }
236 }
237}
238
#[cfg(feature = "resource-recorder-client")]
mod _resource_recorder_client {
    use http::Uri;
    use shuttle_common::claims;

    /// Resource-recorder gRPC client whose requests pass through the claim layer and the
    /// trace-propagation layer.
    pub type Client = super::resource_recorder::resource_recorder_client::ResourceRecorderClient<
        claims::ClaimService<claims::InjectPropagation<tonic::transport::Channel>>,
    >;

    /// Connects to the resource recorder at `resource_recorder_uri` and wraps the channel with
    /// claim handling and trace-context propagation.
    ///
    /// # Panics
    ///
    /// Panics if the connection to the resource recorder cannot be established.
    pub async fn get_client(resource_recorder_uri: Uri) -> Client {
        let endpoint = tonic::transport::Endpoint::from(resource_recorder_uri);
        let channel = endpoint
            .connect()
            .await
            .expect("failed to connect to resource recorder");

        Client::new(
            tower::ServiceBuilder::new()
                .layer(claims::ClaimLayer)
                .layer(claims::InjectPropagationLayer)
                .service(channel),
        )
    }
}
264
265#[cfg(feature = "logger")]
266pub mod logger {
267 pub use super::generated::logger::*;
268
269 #[cfg(feature = "logger-client")]
270 pub use super::_logger_client::*;
271
272 use std::str::FromStr;
273 use std::time::Duration;
274
275 use chrono::{NaiveDateTime, TimeZone, Utc};
276 use prost::bytes::Bytes;
277 use tokio::{select, sync::mpsc, time::interval};
278 use tonic::{
279 async_trait,
280 codegen::{Body, StdError},
281 Request,
282 };
283 use tracing::error;
284
285 use shuttle_common::{
286 log::{Backend, LogItem as LogItemCommon, LogRecorder},
287 DeploymentId,
288 };
289
290 impl From<LogItemCommon> for LogItem {
291 fn from(value: LogItemCommon) -> Self {
292 Self {
293 deployment_id: value.id.to_string(),
294 log_line: Some(LogLine {
295 tx_timestamp: Some(prost_types::Timestamp {
296 seconds: value.timestamp.timestamp(),
297 nanos: value.timestamp.timestamp_subsec_nanos() as i32,
298 }),
299 service_name: format!("{:?}", value.internal_origin),
300 data: value.line.into_bytes(),
301 }),
302 }
303 }
304 }
305
306 impl From<LogItem> for LogItemCommon {
307 fn from(value: LogItem) -> Self {
308 value
309 .log_line
310 .expect("log item to have log line")
311 .to_log_item_with_id(value.deployment_id.parse().unwrap_or_default())
312 }
313 }
314
315 impl LogLine {
316 pub fn to_log_item_with_id(self, deployment_id: DeploymentId) -> LogItemCommon {
317 let LogLine {
318 service_name,
319 tx_timestamp,
320 data,
321 } = self;
322 let tx_timestamp = tx_timestamp.expect("log to have timestamp");
323
324 LogItemCommon {
325 id: deployment_id,
326 internal_origin: Backend::from_str(&service_name)
327 .expect("backend name to be valid"),
328 timestamp: Utc.from_utc_datetime(
329 #[allow(deprecated)]
330 &NaiveDateTime::from_timestamp_opt(
331 tx_timestamp.seconds,
332 tx_timestamp.nanos.try_into().unwrap_or_default(),
333 )
334 .unwrap_or_default(),
335 ),
336 line: String::from_utf8(data).expect("line to be utf-8"),
337 }
338 }
339 }
340
341 impl<I> LogRecorder for Batcher<I>
342 where
343 I: VecReceiver<Item = LogItem> + Clone + 'static,
344 {
345 fn record(&self, log: LogItemCommon) {
346 self.send(log.into());
347 }
348 }
349
    /// A sink that accepts items a whole batch at a time, as produced by [`Batcher`].
    #[async_trait]
    pub trait VecReceiver: Send {
        /// The type of item collected into each batch.
        type Item;

        /// Handles one batch of items.
        ///
        /// Nothing is returned, so any failure must be handled inside the implementation.
        async fn receive(&mut self, items: Vec<Self::Item>);
    }
357
358 #[async_trait]
359 impl<T> VecReceiver for logger_client::LoggerClient<T>
360 where
361 T: tonic::client::GrpcService<tonic::body::BoxBody> + Send + Sync + Clone,
362 T::Error: Into<StdError>,
363 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
364 T::Future: Send,
365 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
366 {
367 type Item = LogItem;
368
369 async fn receive(&mut self, items: Vec<Self::Item>) {
370 if let Err(error) = self
371 .store_logs(Request::new(StoreLogsRequest { logs: items }))
372 .await
373 {
374 error!(
375 error = &error as &dyn std::error::Error,
376 "failed to send batch logs to logger"
377 );
378 }
379 }
380 }
381
382 #[derive(Clone)]
384 pub struct Batcher<I: VecReceiver> {
385 tx: mpsc::UnboundedSender<I::Item>,
386 }
387
388 impl<I: VecReceiver + 'static> Batcher<I>
389 where
390 I::Item: Send,
391 {
392 pub fn new(inner: I, capacity: usize, interval: Duration) -> Self {
395 let (tx, rx) = mpsc::unbounded_channel();
396
397 tokio::spawn(Self::batch(inner, rx, capacity, interval));
398
399 Self { tx }
400 }
401
402 pub fn wrap(inner: I) -> Self {
405 Self::new(inner, 256, Duration::from_secs(1))
406 }
407
408 pub fn send(&self, item: I::Item) {
410 if self.tx.send(item).is_err() {
411 unreachable!("the receiver will never drop");
412 }
413 }
414
415 async fn batch(
417 mut inner: I,
418 mut rx: mpsc::UnboundedReceiver<I::Item>,
419 capacity: usize,
420 interval_duration: Duration,
421 ) {
422 let mut interval = interval(interval_duration);
423
424 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
427
428 interval.tick().await;
430
431 let mut cache = Vec::with_capacity(capacity);
432
433 loop {
434 select! {
435 item = rx.recv() => {
436 if let Some(item) = item {
437 cache.push(item);
438
439 if cache.len() == capacity {
440 let old_cache = cache;
441 cache = Vec::with_capacity(capacity);
442
443 inner.receive(old_cache).await;
444 }
445 } else {
446 return;
448 }
449 },
450 _ = interval.tick() => {
451 if !cache.is_empty() {
452 let old_cache = cache;
453 cache = Vec::with_capacity(capacity);
454
455 inner.receive(old_cache).await;
456 }
457 }
458 }
459 }
460 }
461 }
462
463 #[cfg(test)]
464 mod tests {
465 use std::{
466 sync::{Arc, Mutex},
467 time::Duration,
468 };
469
470 use tokio::time::sleep;
471 use tonic::async_trait;
472
473 use super::{Batcher, VecReceiver};
474
475 #[derive(Default, Clone)]
476 struct MockGroupReceiver(Arc<Mutex<Option<Vec<u32>>>>);
477
478 #[async_trait]
479 impl VecReceiver for MockGroupReceiver {
480 type Item = u32;
481
482 async fn receive(&mut self, items: Vec<Self::Item>) {
483 *self.0.lock().unwrap() = Some(items);
484 }
485 }
486
487 #[tokio::test]
488 async fn capacity_reached() {
489 let mock = MockGroupReceiver::default();
490 let batcher = Batcher::new(mock.clone(), 2, Duration::from_secs(120));
491
492 batcher.send(1);
493 sleep(Duration::from_millis(50)).await;
494 assert_eq!(*mock.0.lock().unwrap(), None);
495
496 batcher.send(2);
497 sleep(Duration::from_millis(50)).await;
498 assert_eq!(*mock.0.lock().unwrap(), Some(vec![1, 2]));
499
500 batcher.send(3);
501 sleep(Duration::from_millis(50)).await;
502 assert_eq!(*mock.0.lock().unwrap(), Some(vec![1, 2]));
503
504 batcher.send(4);
505 sleep(Duration::from_millis(50)).await;
506 assert_eq!(*mock.0.lock().unwrap(), Some(vec![3, 4]));
507 }
508
509 #[tokio::test]
510 async fn interval_reached() {
511 let mock = MockGroupReceiver::default();
512 let batcher = Batcher::new(mock.clone(), 2, Duration::from_millis(300));
513
514 sleep(Duration::from_millis(500)).await;
515 assert_eq!(
516 *mock.0.lock().unwrap(),
517 None,
518 "we should never send something when the cache is empty"
519 );
520
521 batcher.send(1);
522 sleep(Duration::from_millis(50)).await;
523 assert_eq!(*mock.0.lock().unwrap(), None);
524
525 sleep(Duration::from_millis(500)).await;
526 assert_eq!(*mock.0.lock().unwrap(), Some(vec![1]));
527 }
528 }
529}
#[cfg(feature = "logger-client")]
mod _logger_client {
    use super::logger::*;

    use http::Uri;
    use shuttle_common::claims;

    /// Logger gRPC client whose requests pass through the claim layer and the
    /// trace-propagation layer.
    pub type Client = logger_client::LoggerClient<
        claims::ClaimService<claims::InjectPropagation<tonic::transport::Channel>>,
    >;

    /// Connects to the logger at `logger_uri` and wraps the channel with claim handling and
    /// trace-context propagation.
    ///
    /// # Panics
    ///
    /// Panics if the connection to the logger cannot be established.
    pub async fn get_client(logger_uri: Uri) -> Client {
        let endpoint = tonic::transport::Endpoint::from(logger_uri);
        let channel = endpoint
            .connect()
            .await
            .expect("failed to connect to logger");

        Client::new(
            tower::ServiceBuilder::new()
                .layer(claims::ClaimLayer)
                .layer(claims::InjectPropagationLayer)
                .service(channel),
        )
    }
}