testcontainers_modules/arrow_flightsql/mod.rs
1use std::borrow::Cow;
2
3use testcontainers::{core::WaitFor, Image};
4
/// Docker Hub repository of the Flight SQL server image used by this module.
const NAME: &str = "voltrondata/flight-sql";
/// Image tag pinned by this module.
const TAG: &str = "v1.4.1-slim";
7
/// Module to work with [`Apache Arrow FlightSQL`] inside of tests.
///
/// This module is based on the [`voltrondata/flight-sql docker image`].
///
/// The container is started with the following configuration:
/// - Flight SQL password `test` (environment variable `FLIGHT_PASSWORD`);
/// - database file `test.duckdb` (environment variable `DATABASE_FILENAME`).
///
/// The container is considered ready once it prints `Flight SQL server - started`
/// on stdout. The example below reaches the service through container port `31337`
/// and authenticates as `flight_username` with the password above.
///
/// # Example
/// ```
/// use arrow_flight::{
///     flight_service_client::FlightServiceClient, sql::client::FlightSqlServiceClient,
/// };
/// use futures::TryStreamExt;
/// use testcontainers::runners::AsyncRunner;
/// use testcontainers_modules::arrow_flightsql::ArrowFlightSQL;
///
/// #[tokio::test]
/// async fn arrow_flightsql_select_version() -> Result<(), Box<dyn std::error::Error + 'static>> {
///     let image = ArrowFlightSQL::default();
///     let instance = image.start().await?;
///     let host = instance.get_host().await?;
///     let port = instance.get_host_port_ipv4(31337).await?;
///     let url = format!("http://{host}:{port}");
///     let service_client = FlightServiceClient::connect(url).await?;
///     let mut client = FlightSqlServiceClient::new_from_inner(service_client);
///     let _ = client.handshake("flight_username", "test").await?;
///
///     let mut statement = client
///         .prepare("SELECT VERSION();".to_string(), None)
///         .await?;
///     let flight_info = statement.execute().await?;
///
///     let ticket = flight_info.endpoint[0]
///         .ticket
///         .as_ref()
///         .expect("Ticket not present")
///         .clone();
///     let flight_data = client.do_get(ticket).await?;
///     let batches: Vec<_> = flight_data.try_collect().await?;
///     let batch = batches.first().expect("batch 0 not present");
///     let array = batch.columns().first().expect("column not present");
///     let data = array.to_data();
///     let buffers = data.buffers();
///     let buffer = buffers.get(1).expect("buffer not present");
///     let values = buffer.to_vec();
///     let version = String::from_utf8(values)?;
///
///     assert_eq!(version, "v1.0.0");
///     Ok(())
/// }
/// ```
///
/// [`Apache Arrow FlightSQL`]: https://arrow.apache.org/docs/format/FlightSql.html
/// [`voltrondata/flight-sql docker image`]: https://hub.docker.com/r/voltrondata/flight-sql
#[derive(Clone, Debug, Default)]
pub struct ArrowFlightSQL {}
61
62impl Image for ArrowFlightSQL {
63 fn name(&self) -> &str {
64 NAME
65 }
66
67 fn tag(&self) -> &str {
68 TAG
69 }
70
71 fn ready_conditions(&self) -> Vec<WaitFor> {
72 vec![WaitFor::message_on_stdout("Flight SQL server - started")]
73 }
74
75 fn env_vars(
76 &self,
77 ) -> impl IntoIterator<Item = (impl Into<Cow<'_, str>>, impl Into<Cow<'_, str>>)> {
78 [
79 ("FLIGHT_PASSWORD", "test"),
80 ("DATABASE_FILENAME", "test.duckdb"),
81 ]
82 }
83}
84
#[cfg(test)]
mod tests {
    use arrow_flight::{
        flight_service_client::FlightServiceClient, sql::client::FlightSqlServiceClient,
    };
    use futures::TryStreamExt;
    use testcontainers::runners::AsyncRunner;

    use crate::arrow_flightsql::ArrowFlightSQL;

    /// End-to-end check: start the container, authenticate, run `SELECT VERSION();`
    /// and compare the string that comes back.
    #[tokio::test]
    async fn arrow_flightsql_select_version() -> Result<(), Box<dyn std::error::Error + 'static>> {
        // Boot the container and resolve where its Flight SQL port is published.
        let container = ArrowFlightSQL::default().start().await?;
        let host = container.get_host().await?;
        let mapped_port = container.get_host_port_ipv4(31337).await?;
        let endpoint_url = format!("http://{host}:{mapped_port}");

        // Connect and authenticate with the password the image is configured with.
        let channel_client = FlightServiceClient::connect(endpoint_url).await?;
        let mut sql_client = FlightSqlServiceClient::new_from_inner(channel_client);
        let _ = sql_client.handshake("flight_username", "test").await?;

        // Prepare and execute the query; the server answers with a FlightInfo.
        let query = "SELECT VERSION();".to_string();
        let mut prepared = sql_client.prepare(query, None).await?;
        let info = prepared.execute().await?;

        // Redeem the ticket of the first endpoint to stream the result batches.
        let Some(ticket) = info.endpoint[0].ticket.clone() else {
            panic!("Ticket not present");
        };
        let stream = sql_client.do_get(ticket).await?;
        let batches = stream.try_collect::<Vec<_>>().await?;

        // Assumes the single result column is a string array, whose buffer 1
        // holds the raw UTF-8 value bytes.
        let first_batch = batches.first().expect("batch 0 not present");
        let column = first_batch.columns().first().expect("column not present");
        let column_data = column.to_data();
        let raw_bytes = column_data
            .buffers()
            .get(1)
            .expect("buffer not present")
            .to_vec();
        let version = String::from_utf8(raw_bytes)?;

        assert_eq!(version, "v1.0.0");
        Ok(())
    }
}