// shuttle_proto/lib.rs

1mod generated;
2
3// useful re-exports if types are needed in other crates
4pub use prost;
5pub use prost_types;
6pub use tonic;
7
#[cfg(feature = "provisioner")]
pub mod provisioner {
    pub use super::generated::provisioner::*;

    #[cfg(feature = "provisioner-client")]
    pub use super::_provisioner_client::*;

    use shuttle_common::{
        database::{self, AwsRdsEngine, SharedEngine},
        DatabaseInfo,
    };

    /// Unpack a provisioner `DatabaseResponse` into the common [`DatabaseInfo`] type.
    impl From<DatabaseResponse> for DatabaseInfo {
        fn from(response: DatabaseResponse) -> Self {
            let DatabaseResponse {
                engine,
                username,
                password,
                database_name,
                port,
                address_private,
                address_public,
                ..
            } = response;

            Self::new(
                engine,
                username,
                password,
                database_name,
                port,
                address_private,
                address_public,
            )
        }
    }

    /// Translate a common database type into the provisioner request's `DbType`.
    ///
    /// Shared engines carry an empty string payload and RDS engines an empty `RdsConfig`.
    impl From<database::Type> for database_request::DbType {
        fn from(db_type: database::Type) -> Self {
            match db_type {
                database::Type::Shared(shared_engine) => {
                    let engine = Some(match shared_engine {
                        SharedEngine::Postgres => shared::Engine::Postgres(String::new()),
                        SharedEngine::MongoDb => shared::Engine::Mongodb(String::new()),
                    });
                    Self::Shared(Shared { engine })
                }
                database::Type::AwsRds(rds_engine) => {
                    let engine = Some(match rds_engine {
                        AwsRdsEngine::Postgres => aws_rds::Engine::Postgres(RdsConfig {}),
                        AwsRdsEngine::MariaDB => aws_rds::Engine::Mariadb(RdsConfig {}),
                        AwsRdsEngine::MySql => aws_rds::Engine::Mysql(RdsConfig {}),
                    });
                    Self::AwsRds(AwsRds { engine })
                }
            }
        }
    }

    /// Translate a request's `DbType` back into a common database type.
    ///
    /// Yields `None` when the request does not specify an engine.
    impl From<database_request::DbType> for Option<database::Type> {
        fn from(db_type: database_request::DbType) -> Self {
            let converted = match db_type {
                database_request::DbType::Shared(Shared { engine }) => {
                    database::Type::Shared(match engine? {
                        shared::Engine::Postgres(_) => SharedEngine::Postgres,
                        shared::Engine::Mongodb(_) => SharedEngine::MongoDb,
                    })
                }
                database_request::DbType::AwsRds(AwsRds { engine }) => {
                    database::Type::AwsRds(match engine? {
                        aws_rds::Engine::Postgres(_) => AwsRdsEngine::Postgres,
                        aws_rds::Engine::Mysql(_) => AwsRdsEngine::MySql,
                        aws_rds::Engine::Mariadb(_) => AwsRdsEngine::MariaDB,
                    })
                }
            };

            Some(converted)
        }
    }

    /// Lowercase engine name: `mariadb`, `mysql` or `postgres`.
    impl std::fmt::Display for aws_rds::Engine {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let name = match self {
                Self::Mariadb(_) => "mariadb",
                Self::Mysql(_) => "mysql",
                Self::Postgres(_) => "postgres",
            };
            f.write_str(name)
        }
    }
}

#[cfg(feature = "provisioner-client")]
mod _provisioner_client {
    use http::Uri;
    use shuttle_common::claims::{ClaimLayer, ClaimService, InjectPropagation, InjectPropagationLayer};
    use tonic::transport::{Channel, Endpoint};

    use super::provisioner::provisioner_client::ProvisionerClient;

    /// Provisioner gRPC client with the claim and propagation-injection layers applied.
    pub type Client = ProvisionerClient<ClaimService<InjectPropagation<Channel>>>;

    /// Largest message, in bytes, the client will encode or decode (50 MB) so that dumps fit.
    const MAX_MESSAGE_BYTES: usize = 50 * 1024 * 1024;

    /// Get a provisioner client that is correctly configured for all services
    ///
    /// # Panics
    /// Panics if a connection to `provisioner_uri` cannot be established.
    pub async fn get_client(provisioner_uri: Uri) -> Client {
        let endpoint = Endpoint::from(provisioner_uri);
        let channel = endpoint
            .connect()
            .await
            .expect("failed to connect to provisioner");

        let service = tower::ServiceBuilder::new()
            .layer(ClaimLayer)
            .layer(InjectPropagationLayer)
            .service(channel);

        let client = Client::new(service);
        client
            .max_decoding_message_size(MAX_MESSAGE_BYTES)
            .max_encoding_message_size(MAX_MESSAGE_BYTES)
    }
}

/// Generated runtime protocol types, plus the client helpers when `runtime-client` is enabled.
#[cfg(feature = "runtime")]
pub mod runtime {
    #[cfg(feature = "runtime-client")]
    pub use super::_runtime_client::*;

    pub use super::generated::runtime::*;
}

#[cfg(feature = "runtime-client")]
mod _runtime_client {
    use super::runtime::*;

    use std::time::Duration;

    use anyhow::Context;
    use tonic::transport::Endpoint;
    use tracing::{info, trace};

    /// Runtime gRPC client with the propagation-injection layer applied.
    pub type Client = runtime_client::RuntimeClient<
        shuttle_common::claims::InjectPropagation<tonic::transport::Channel>,
    >;

    /// Get a runtime client that is correctly configured
    ///
    /// The runtime's control port may not be open yet when this is called, so connecting is
    /// retried with a capped exponential backoff until a 7 second deadline.
    ///
    /// # Errors
    /// Fails if `address` is not a valid endpoint, or if no connection succeeds before the
    /// deadline.
    #[cfg(feature = "client")]
    pub async fn get_client(address: String) -> anyhow::Result<Client> {
        // Overall time allowed for the control port to start accepting connections.
        const DEADLINE: Duration = Duration::from_secs(7);
        // Timeout for a single connection attempt.
        const ATTEMPT_TIMEOUT: Duration = Duration::from_secs(5);
        // Delay after the first failed attempt; doubled after each further failure.
        const INITIAL_BACKOFF: Duration = Duration::from_millis(5);
        // Upper bound on the delay between attempts.
        const MAX_BACKOFF: Duration = Duration::from_millis(500);

        info!("connecting runtime client");
        let conn = Endpoint::new(address)
            .context("creating runtime client endpoint")?
            .connect_timeout(ATTEMPT_TIMEOUT);

        // Wait for the spawned process to open the control port.
        // Connecting instantly does not give it enough time.
        let channel = tokio::time::timeout(DEADLINE, async move {
            let mut backoff = INITIAL_BACKOFF;
            loop {
                match conn.connect().await {
                    Ok(channel) => break channel,
                    Err(error) => {
                        trace!(error = %error, "waiting for runtime control port to open");
                    }
                }
                tokio::time::sleep(backoff).await;
                // Without a cap the delay grows to several seconds. The last retry would then
                // land well before the deadline, leaving its tail without any attempt.
                backoff = (backoff * 2).min(MAX_BACKOFF);
            }
        })
        .await
        .context("runtime control port did not open in time")?;

        let runtime_service = tower::ServiceBuilder::new()
            .layer(shuttle_common::claims::InjectPropagationLayer)
            .service(channel);

        Ok(Client::new(runtime_service))
    }
}

#[cfg(feature = "resource-recorder")]
pub mod resource_recorder {
    pub use super::generated::resource_recorder::*;

    #[cfg(feature = "resource-recorder-client")]
    pub use super::_resource_recorder_client::*;

    use std::str::FromStr;

    use anyhow::Context;

    impl TryFrom<record_request::Resource> for shuttle_common::resource::Response {
        type Error = anyhow::Error;

        fn try_from(resource: record_request::Resource) -> Result<Self, Self::Error> {
            response_from_parts(&resource.r#type, &resource.config, &resource.data)
        }
    }

    impl TryFrom<Resource> for shuttle_common::resource::Response {
        type Error = anyhow::Error;

        fn try_from(resource: Resource) -> Result<Self, Self::Error> {
            response_from_parts(&resource.r#type, &resource.config, &resource.data)
        }
    }

    /// Build a resource `Response` from the type string and JSON-encoded config/data bytes.
    ///
    /// These fields are shared by [`Resource`] and [`record_request::Resource`].
    ///
    /// # Errors
    /// Fails if `type_name` is not a valid resource type, or if `config` or `data` is not valid
    /// JSON for the corresponding response field.
    fn response_from_parts(
        type_name: &str,
        config: &[u8],
        data: &[u8],
    ) -> anyhow::Result<shuttle_common::resource::Response> {
        let resource_type = shuttle_common::resource::Type::from_str(type_name)
            .map_err(anyhow::Error::msg)
            .context("resource type should have a valid resource string")?;

        // `with_context` only builds the message when deserialization actually fails.
        let config = serde_json::from_slice(config)
            .with_context(|| format!("{} resource config should be valid JSON", resource_type))?;
        let data = serde_json::from_slice(data)
            .with_context(|| format!("{} resource data should be valid JSON", resource_type))?;

        Ok(shuttle_common::resource::Response {
            r#type: resource_type,
            config,
            data,
        })
    }
}

#[cfg(feature = "resource-recorder-client")]
mod _resource_recorder_client {
    use http::Uri;
    use shuttle_common::claims::{ClaimLayer, ClaimService, InjectPropagation, InjectPropagationLayer};
    use tonic::transport::{Channel, Endpoint};

    use super::resource_recorder::resource_recorder_client::ResourceRecorderClient;

    /// Resource recorder gRPC client with the claim and propagation-injection layers applied.
    pub type Client = ResourceRecorderClient<ClaimService<InjectPropagation<Channel>>>;

    /// Get a resource recorder client that is correctly configured for all services
    ///
    /// # Panics
    /// Panics if a connection to `resource_recorder_uri` cannot be established.
    pub async fn get_client(resource_recorder_uri: Uri) -> Client {
        let endpoint = Endpoint::from(resource_recorder_uri);
        let channel = endpoint
            .connect()
            .await
            .expect("failed to connect to resource recorder");

        let layers = tower::ServiceBuilder::new()
            .layer(ClaimLayer)
            .layer(InjectPropagationLayer);

        Client::new(layers.service(channel))
    }
}

#[cfg(feature = "logger")]
pub mod logger {
    pub use super::generated::logger::*;

    #[cfg(feature = "logger-client")]
    pub use super::_logger_client::*;

    use std::str::FromStr;
    use std::time::Duration;

    use chrono::{NaiveDateTime, TimeZone, Utc};
    use prost::bytes::Bytes;
    use tokio::{select, sync::mpsc, time::interval};
    use tonic::{
        async_trait,
        codegen::{Body, StdError},
        Request,
    };
    use tracing::error;

    use shuttle_common::{
        log::{Backend, LogItem as LogItemCommon, LogRecorder},
        DeploymentId,
    };

    /// Encode a common log item as a wire `LogItem`.
    ///
    /// The origin is written using its `Debug` representation. [`LogLine::to_log_item_with_id`]
    /// parses it back with `Backend::from_str`, so the two presumably need to agree.
    impl From<LogItemCommon> for LogItem {
        fn from(value: LogItemCommon) -> Self {
            Self {
                deployment_id: value.id.to_string(),
                log_line: Some(LogLine {
                    tx_timestamp: Some(prost_types::Timestamp {
                        seconds: value.timestamp.timestamp(),
                        // chrono keeps sub-second nanos below 2_000_000_000 (leap seconds
                        // included), so this always fits in an i32.
                        nanos: value.timestamp.timestamp_subsec_nanos() as i32,
                    }),
                    service_name: format!("{:?}", value.internal_origin),
                    data: value.line.into_bytes(),
                }),
            }
        }
    }

    /// Decode a wire `LogItem` into a common log item.
    ///
    /// A deployment id that fails to parse is replaced by `DeploymentId::default()`.
    ///
    /// # Panics
    /// Panics if `log_line` is missing, or in the cases listed on
    /// [`LogLine::to_log_item_with_id`].
    impl From<LogItem> for LogItemCommon {
        fn from(value: LogItem) -> Self {
            value
                .log_line
                .expect("log item to have log line")
                .to_log_item_with_id(value.deployment_id.parse().unwrap_or_default())
        }
    }

    impl LogLine {
        /// Convert this wire log line into a common log item attributed to `deployment_id`.
        ///
        /// Negative nanos are treated as 0. A timestamp chrono cannot represent falls back to
        /// the default `NaiveDateTime`. Invalid UTF-8 in `data` is replaced with U+FFFD rather
        /// than rejected.
        ///
        /// # Panics
        /// Panics if `tx_timestamp` is missing or `service_name` is not a valid [`Backend`].
        pub fn to_log_item_with_id(self, deployment_id: DeploymentId) -> LogItemCommon {
            let LogLine {
                service_name,
                tx_timestamp,
                data,
            } = self;
            let tx_timestamp = tx_timestamp.expect("log to have timestamp");

            LogItemCommon {
                id: deployment_id,
                internal_origin: Backend::from_str(&service_name)
                    .expect("backend name to be valid"),
                timestamp: Utc.from_utc_datetime(
                    // `from_timestamp_opt` is deprecated in newer chrono releases.
                    #[allow(deprecated)]
                    &NaiveDateTime::from_timestamp_opt(
                        tx_timestamp.seconds,
                        tx_timestamp.nanos.try_into().unwrap_or_default(),
                    )
                    .unwrap_or_default(),
                ),
                // The bytes come off the wire: degrade gracefully instead of panicking on
                // malformed UTF-8.
                line: String::from_utf8(data).unwrap_or_else(|error| {
                    String::from_utf8_lossy(error.as_bytes()).into_owned()
                }),
            }
        }
    }

    /// Recording a log converts it to a wire `LogItem` and queues it on the batcher.
    impl<I> LogRecorder for Batcher<I>
    where
        I: VecReceiver<Item = LogItem> + Clone + 'static,
    {
        fn record(&self, log: LogItemCommon) {
            self.send(log.into());
        }
    }

    /// Adapter to some client which expects to receive a vector of items
    #[async_trait]
    pub trait VecReceiver: Send {
        type Item;

        async fn receive(&mut self, items: Vec<Self::Item>);
    }

    /// Forwards each batch to the logger's `store_logs` RPC; failures are logged, not returned.
    #[async_trait]
    impl<T> VecReceiver for logger_client::LoggerClient<T>
    where
        T: tonic::client::GrpcService<tonic::body::BoxBody> + Send + Sync + Clone,
        T::Error: Into<StdError>,
        T::ResponseBody: Body<Data = Bytes> + Send + 'static,
        T::Future: Send,
        <T::ResponseBody as Body>::Error: Into<StdError> + Send,
    {
        type Item = LogItem;

        async fn receive(&mut self, items: Vec<Self::Item>) {
            if let Err(error) = self
                .store_logs(Request::new(StoreLogsRequest { logs: items }))
                .await
            {
                error!(
                    error = &error as &dyn std::error::Error,
                    "failed to send batch logs to logger"
                );
            }
        }
    }

    /// Wrapper to batch together items before forwarding them to some vector receiver
    ///
    /// Clones share the same background batching task.
    pub struct Batcher<I: VecReceiver> {
        tx: mpsc::UnboundedSender<I::Item>,
    }

    // Written by hand because `#[derive(Clone)]` would add an unnecessary `I: Clone` bound;
    // cloning only duplicates the channel sender.
    impl<I: VecReceiver> Clone for Batcher<I> {
        fn clone(&self) -> Self {
            Self {
                tx: self.tx.clone(),
            }
        }
    }

    impl<I: VecReceiver + 'static> Batcher<I>
    where
        I::Item: Send,
    {
        /// Create a new batcher around inner with the given batch capacity.
        /// Items will be sent when the batch has reached capacity or at the set interval,
        /// whichever comes first. A `capacity` of 0 forwards every item as soon as it arrives.
        /// Items still buffered when the last clone of the batcher is dropped are flushed.
        ///
        /// # Panics
        /// Panics if `interval` is zero, or if called outside a Tokio runtime (the batching loop
        /// is started with `tokio::spawn`).
        pub fn new(inner: I, capacity: usize, interval: Duration) -> Self {
            // `tokio::time::interval` panics on a zero period. Check here so the caller sees the
            // failure, instead of the background task dying and a later `send` panicking.
            assert!(!interval.is_zero(), "batch interval must be non-zero");

            let (tx, rx) = mpsc::unbounded_channel();

            tokio::spawn(Self::batch(inner, rx, capacity, interval));

            Self { tx }
        }

        /// Create a batcher around inner. It will send a batch of items to inner if a capacity of 256 is reached
        /// or if an interval of 1 second is reached.
        pub fn wrap(inner: I) -> Self {
            Self::new(inner, 256, Duration::from_secs(1))
        }

        /// Send a single item into this batcher
        ///
        /// # Panics
        /// Panics if the background task is no longer receiving. This should only happen if
        /// that task has panicked or its runtime has shut down.
        pub fn send(&self, item: I::Item) {
            if self.tx.send(item).is_err() {
                unreachable!("the receiver will never drop");
            }
        }

        /// Background task that forwards items when the batch reaches capacity or the interval
        /// elapses. Once every sender has been dropped, any remaining items are flushed before
        /// the task exits.
        async fn batch(
            mut inner: I,
            mut rx: mpsc::UnboundedReceiver<I::Item>,
            capacity: usize,
            interval_duration: Duration,
        ) {
            let mut interval = interval(interval_duration);

            // Without this, the default behaviour will burst any missed tickers until they are caught up.
            // This will cause a flood which we want to avoid.
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

            // Get past the first tick
            interval.tick().await;

            let mut cache = Vec::with_capacity(capacity);

            loop {
                select! {
                    item = rx.recv() => {
                        if let Some(item) = item {
                            cache.push(item);

                            if cache.len() >= capacity {
                                let full =
                                    std::mem::replace(&mut cache, Vec::with_capacity(capacity));
                                inner.receive(full).await;
                            }
                        } else {
                            // Every sender dropped: flush the partial batch so it is not lost.
                            if !cache.is_empty() {
                                inner.receive(cache).await;
                            }
                            return;
                        }
                    },
                    _ = interval.tick() => {
                        if !cache.is_empty() {
                            let pending =
                                std::mem::replace(&mut cache, Vec::with_capacity(capacity));
                            inner.receive(pending).await;
                        }
                    }
                }
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use std::{
            sync::{Arc, Mutex},
            time::Duration,
        };

        use tokio::time::sleep;
        use tonic::async_trait;

        use super::{Batcher, VecReceiver};

        #[derive(Default, Clone)]
        struct MockGroupReceiver(Arc<Mutex<Option<Vec<u32>>>>);

        #[async_trait]
        impl VecReceiver for MockGroupReceiver {
            type Item = u32;

            async fn receive(&mut self, items: Vec<Self::Item>) {
                *self.0.lock().unwrap() = Some(items);
            }
        }

        #[tokio::test]
        async fn capacity_reached() {
            let mock = MockGroupReceiver::default();
            let batcher = Batcher::new(mock.clone(), 2, Duration::from_secs(120));

            batcher.send(1);
            sleep(Duration::from_millis(50)).await;
            assert_eq!(*mock.0.lock().unwrap(), None);

            batcher.send(2);
            sleep(Duration::from_millis(50)).await;
            assert_eq!(*mock.0.lock().unwrap(), Some(vec![1, 2]));

            batcher.send(3);
            sleep(Duration::from_millis(50)).await;
            assert_eq!(*mock.0.lock().unwrap(), Some(vec![1, 2]));

            batcher.send(4);
            sleep(Duration::from_millis(50)).await;
            assert_eq!(*mock.0.lock().unwrap(), Some(vec![3, 4]));
        }

        #[tokio::test]
        async fn interval_reached() {
            let mock = MockGroupReceiver::default();
            let batcher = Batcher::new(mock.clone(), 2, Duration::from_millis(300));

            sleep(Duration::from_millis(500)).await;
            assert_eq!(
                *mock.0.lock().unwrap(),
                None,
                "we should never send something when the cache is empty"
            );

            batcher.send(1);
            sleep(Duration::from_millis(50)).await;
            assert_eq!(*mock.0.lock().unwrap(), None);

            sleep(Duration::from_millis(500)).await;
            assert_eq!(*mock.0.lock().unwrap(), Some(vec![1]));
        }

        #[tokio::test]
        async fn sender_dropped_flushes_remaining_items() {
            let mock = MockGroupReceiver::default();
            let batcher = Batcher::new(mock.clone(), 10, Duration::from_secs(120));

            batcher.send(1);
            batcher.send(2);
            drop(batcher);
            sleep(Duration::from_millis(50)).await;

            assert_eq!(*mock.0.lock().unwrap(), Some(vec![1, 2]));
        }
    }
}

#[cfg(feature = "logger-client")]
mod _logger_client {
    use http::Uri;
    use shuttle_common::claims::{ClaimLayer, ClaimService, InjectPropagation, InjectPropagationLayer};
    use tonic::transport::{Channel, Endpoint};

    use super::logger::logger_client::LoggerClient;

    /// Logger gRPC client with the claim and propagation-injection layers applied.
    pub type Client = LoggerClient<ClaimService<InjectPropagation<Channel>>>;

    /// Get a logger client that is correctly configured for all services
    ///
    /// # Panics
    /// Panics if a connection to `logger_uri` cannot be established.
    pub async fn get_client(logger_uri: Uri) -> Client {
        let connection = Endpoint::from(logger_uri).connect().await;
        let channel = connection.expect("failed to connect to logger");

        Client::new(
            tower::ServiceBuilder::new()
                .layer(ClaimLayer)
                .layer(InjectPropagationLayer)
                .service(channel),
        )
    }
}