testcontainers_modules/arrow_flightsql/
mod.rs

1use std::borrow::Cow;
2
3use testcontainers::{core::WaitFor, Image};
4
/// Docker image repository returned by the `Image::name` implementation below.
const NAME: &str = "voltrondata/flight-sql";
/// Image tag returned by the `Image::tag` implementation below.
const TAG: &str = "v1.4.1-slim";
7
/// Module to work with [`Apache Arrow FlightSQL`] inside of tests.
///
/// This module is based on the [`voltrondata/flight-sql docker image`].
///
/// The type carries no configuration of its own: the image name, tag, readiness
/// condition and environment variables are all supplied by its `Image` implementation.
///
/// # Example
/// ```
/// use arrow_flight::{
///     flight_service_client::FlightServiceClient, sql::client::FlightSqlServiceClient,
/// };
/// use futures::TryStreamExt;
/// use testcontainers::runners::AsyncRunner;
/// use testcontainers_modules::arrow_flightsql::ArrowFlightSQL;
///
/// #[tokio::test]
/// async fn arrow_flightsql_select_version() -> Result<(), Box<dyn std::error::Error + 'static>> {
///     let image = ArrowFlightSQL::default();
///     let instance = image.start().await?;
///     let host = instance.get_host().await?;
///     let port = instance.get_host_port_ipv4(31337).await?;
///     let url = format!("http://{host}:{port}");
///     let service_client = FlightServiceClient::connect(url).await?;
///     let mut client = FlightSqlServiceClient::new_from_inner(service_client);
///     let _ = client.handshake("flight_username", "test").await?;
///
///     let mut statement = client
///         .prepare("SELECT VERSION();".to_string(), None)
///         .await?;
///     let flight_info = statement.execute().await?;
///
///     let ticket = flight_info.endpoint[0]
///         .ticket
///         .as_ref()
///         .expect("Ticket not present")
///         .clone();
///     let flight_data = client.do_get(ticket).await?;
///     let batches: Vec<_> = flight_data.try_collect().await?;
///     let batch = batches.first().expect("batch 0 not present");
///     let array = batch.columns().first().expect("column not present");
///     let data = array.to_data();
///     let buffers = data.buffers();
///     let buffer = buffers.get(1).expect("buffer not present");
///     let values = buffer.to_vec();
///     let version = String::from_utf8(values)?;
///
///     assert_eq!(version, "v1.0.0");
///     Ok(())
/// }
/// ```
///
/// [`Apache Arrow FlightSQL`]: https://arrow.apache.org/docs/format/FlightSql.html
/// [`voltrondata/flight-sql docker image`]: https://hub.docker.com/r/voltrondata/flight-sql
#[derive(Clone, Debug, Default)]
pub struct ArrowFlightSQL {}
61
62impl Image for ArrowFlightSQL {
63    fn name(&self) -> &str {
64        NAME
65    }
66
67    fn tag(&self) -> &str {
68        TAG
69    }
70
71    fn ready_conditions(&self) -> Vec<WaitFor> {
72        vec![WaitFor::message_on_stdout("Flight SQL server - started")]
73    }
74
75    fn env_vars(
76        &self,
77    ) -> impl IntoIterator<Item = (impl Into<Cow<'_, str>>, impl Into<Cow<'_, str>>)> {
78        [
79            ("FLIGHT_PASSWORD", "test"),
80            ("DATABASE_FILENAME", "test.duckdb"),
81        ]
82    }
83}
84
#[cfg(test)]
mod tests {
    use arrow_flight::flight_service_client::FlightServiceClient;
    use arrow_flight::sql::client::FlightSqlServiceClient;
    use futures::TryStreamExt;
    use testcontainers::runners::AsyncRunner;

    use crate::arrow_flightsql::ArrowFlightSQL;

    /// Starts the container, runs `SELECT VERSION();` over FlightSQL and checks the
    /// string decoded from the first column of the first returned batch.
    #[tokio::test]
    async fn arrow_flightsql_select_version() -> Result<(), Box<dyn std::error::Error + 'static>> {
        // Start the container and build the endpoint from its mapped host and port.
        let container = ArrowFlightSQL::default().start().await?;
        let host = container.get_host().await?;
        let port = container.get_host_port_ipv4(31337).await?;
        let endpoint = format!("http://{host}:{port}");

        // Connect and authenticate with the password set through the image's env vars.
        let channel = FlightServiceClient::connect(endpoint).await?;
        let mut sql_client = FlightSqlServiceClient::new_from_inner(channel);
        let _ = sql_client.handshake("flight_username", "test").await?;

        // Prepare and execute the query; the result is fetched through a ticket.
        let mut prepared = sql_client
            .prepare("SELECT VERSION();".to_string(), None)
            .await?;
        let info = prepared.execute().await?;
        let ticket = info.endpoint[0]
            .ticket
            .clone()
            .expect("Ticket not present");

        // Collect every record batch of the response stream.
        let stream = sql_client.do_get(ticket).await?;
        let batches = stream.try_collect::<Vec<_>>().await?;

        // Decode the bytes of the first column's buffer at index 1 as UTF-8.
        let first_batch = batches.first().expect("batch 0 not present");
        let first_column = first_batch.columns().first().expect("column not present");
        let column_data = first_column.to_data();
        let raw_bytes = column_data
            .buffers()
            .get(1)
            .expect("buffer not present")
            .to_vec();
        let version = String::from_utf8(raw_bytes)?;

        assert_eq!(version, "v1.0.0");
        Ok(())
    }
}