use crate::column::Column;
use crate::errors::SparkError;
use crate::expressions::{ToExpr, ToFilterExpr, ToVecExpr};
use crate::group::GroupedData;
use crate::plan::LogicalPlanBuilder;
use crate::session::SparkSession;
use crate::storage;
pub use crate::readwriter::{DataFrameReader, DataFrameWriter, DataFrameWriterV2};
pub use crate::streaming::{DataStreamReader, DataStreamWriter, OutputMode, StreamingQuery};
use crate::spark;
pub use spark::aggregate::GroupType;
pub use spark::analyze_plan_request::explain::ExplainMode;
pub use spark::join::JoinType;
use spark::relation::RelType;
pub use spark::write_operation::SaveMode;
use arrow::array::PrimitiveArray;
use arrow::datatypes::{DataType, Float64Type};
use arrow::json::ArrayWriter;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty;
#[cfg(feature = "datafusion")]
use datafusion::execution::context::SessionContext;
#[cfg(feature = "polars")]
use polars;
#[cfg(feature = "polars")]
use polars_arrow;
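/// A lazily evaluated, distributed collection of data grouped into named columns.
///
/// A `DataFrame` pairs a `SparkSession` with a `LogicalPlanBuilder`: transformations such
/// as `select`, `filter`, and `join` only extend the logical plan, and nothing is sent to
/// the Spark Connect server until an action such as `collect`, `show`, or `count` is
/// awaited.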
#[derive(Clone, Debug)]
pub struct DataFrame {
pub(crate) spark_session: Box<SparkSession>,
pub(crate) plan: LogicalPlanBuilder,
}
impl DataFrame {
pub fn new(spark_session: SparkSession, plan: LogicalPlanBuilder) -> DataFrame {
DataFrame {
spark_session: Box::new(spark_session),
plan,
}
}
fn check_same_session(&self, other: &DataFrame) -> Result<(), SparkError> {
if self.spark_session.session_id() != other.spark_session.session_id() {
return Err(SparkError::AnalysisException(
"Spark Session is not the same!".to_string(),
));
};
Ok(())
}
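/// Aggregates over the entire `DataFrame` without groups; shorthand for
/// `group_by::<Column>(None).agg(exprs)`.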
pub fn agg<T: ToVecExpr>(self, exprs: T) -> DataFrame {
self.group_by::<Column>(None).agg(exprs)
}
pub fn alias(self, alias: &str) -> DataFrame {
let plan = self.plan.alias(alias);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
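/// Persists the `DataFrame` with the default storage level
/// (`StorageLevel::MemoryAndDiskDeser`).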
pub async fn cache(self) -> DataFrame {
self.persist(storage::StorageLevel::MemoryAndDiskDeser)
.await
}
pub fn coalesce(self, num_partitions: u32) -> DataFrame {
self.repartition(num_partitions, Some(false))
}
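/// Returns the number of rows by running an ungrouped `count()` on the server and
/// reading the single `Int64` value from the resulting record batch.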
pub async fn count(self) -> Result<i64, SparkError> {
let res = self.group_by::<Column>(None).count().collect().await?;
let col = res.column(0);
let data: &arrow::array::Int64Array = match col.data_type() {
arrow::datatypes::DataType::Int64 => col.as_any().downcast_ref().unwrap(),
_ => unimplemented!("only Utf8 data types are currently handled currently."),
};
Ok(data.value(0))
}
pub fn col_regex(self, col_name: &str) -> Column {
let expr = spark::Expression {
expr_type: Some(spark::expression::ExprType::UnresolvedRegex(
spark::expression::UnresolvedRegex {
col_name: col_name.to_string(),
plan_id: Some(self.plan.plan_id()),
},
)),
};
Column::from(expr)
}
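/// Executes the logical plan on the Spark Connect server and returns the results as a
/// single Arrow `RecordBatch`.
///
/// A minimal sketch (not compiled as a doctest) assuming an already-connected
/// `SparkSession` named `spark`, as in the tests at the bottom of this module:
///
/// ```ignore
/// let rows = spark
///     .sql("SELECT 'spark' AS name")
///     .await?
///     .collect()
///     .await?;
/// assert_eq!(rows.num_rows(), 1);
/// ```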
pub async fn collect(self) -> Result<RecordBatch, SparkError> {
let plan = LogicalPlanBuilder::plan_root(self.plan);
self.spark_session.client().to_arrow(plan).await
}
pub async fn columns(self) -> Result<Vec<String>, SparkError> {
let schema = self.schema().await?;
let struct_val = schema.kind.expect("Unwrapped an empty schema");
let cols = match struct_val {
spark::data_type::Kind::Struct(val) => val
.fields
.iter()
.map(|field| field.name.to_string())
.collect(),
_ => unimplemented!("Unexpected schema response"),
};
Ok(cols)
}
pub async fn corr(self, col1: &str, col2: &str) -> Result<f64, SparkError> {
let plan = self.plan.corr(col1, col2);
let df = DataFrame {
spark_session: self.spark_session,
plan,
};
let result = df.collect().await?;
let col = result.column(0);
let data: &PrimitiveArray<Float64Type> = match col.data_type() {
DataType::Float64 => col
.as_any()
.downcast_ref()
.expect("failed to unwrap result"),
_ => panic!("Expected Float64 in response type"),
};
Ok(data.value(0))
}
pub async fn cov(self, col1: &str, col2: &str) -> Result<f64, SparkError> {
let plan = self.plan.cov(col1, col2);
let df = DataFrame {
spark_session: self.spark_session,
plan,
};
let result = df.collect().await?;
let col = result.column(0);
let data: &PrimitiveArray<Float64Type> = match col.data_type() {
DataType::Float64 => col
.as_any()
.downcast_ref()
.expect("failed to unwrap result"),
_ => panic!("Expected Float64 in response type"),
};
Ok(data.value(0))
}
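/// Creates a local temporary view from this `DataFrame`; the command is sent with
/// `replace = false`, so it errors if a view with that name already exists. See
/// `create_or_replace_temp_view` and the `*_global_temp_view` variants for the other
/// combinations.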
pub async fn create_temp_view(self, name: &str) -> Result<(), SparkError> {
self.create_view_cmd(name, false, false).await
}
pub async fn create_global_temp_view(self, name: &str) -> Result<(), SparkError> {
self.create_view_cmd(name, true, false).await
}
pub async fn create_or_replace_global_temp_view(self, name: &str) -> Result<(), SparkError> {
self.create_view_cmd(name, true, true).await
}
pub async fn create_or_replace_temp_view(self, name: &str) -> Result<(), SparkError> {
self.create_view_cmd(name, false, true).await
}
async fn create_view_cmd(
self,
name: &str,
is_global: bool,
replace: bool,
) -> Result<(), SparkError> {
let command_type =
spark::command::CommandType::CreateDataframeView(spark::CreateDataFrameViewCommand {
input: Some(self.plan.relation()),
name: name.to_string(),
is_global,
replace,
});
let plan = LogicalPlanBuilder::plan_cmd(command_type);
self.spark_session.client().execute_command(plan).await?;
Ok(())
}
pub fn cross_join(self, other: DataFrame) -> DataFrame {
let plan = self
.plan
.join(other.plan, None::<&str>, JoinType::Cross, vec![]);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn crosstab(self, col1: &str, col2: &str) -> DataFrame {
let plan = self.plan.crosstab(col1, col2);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn cube<T: ToVecExpr>(self, cols: T) -> GroupedData {
GroupedData::new(self, GroupType::Cube, cols.to_vec_expr(), None, None)
}
pub fn describe<'a, I>(self, cols: Option<I>) -> DataFrame
where
I: IntoIterator<Item = &'a str> + std::default::Default,
{
let plan = self.plan.describe(cols);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn distinct(self) -> DataFrame {
let plan = self.plan.distinct();
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn drop<T: ToVecExpr>(self, cols: T) -> DataFrame {
let plan = self.plan.drop(cols);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn drop_duplicates(self, cols: Option<Vec<&str>>) -> DataFrame {
let plan = self.plan.drop_duplicates(cols);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn dropna(self, how: &str, threshold: Option<i32>, subset: Option<Vec<&str>>) -> DataFrame {
let plan = self.plan.dropna(how, threshold, subset);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub async fn dtypes(self) -> Result<Vec<(String, spark::data_type::Kind)>, SparkError> {
let schema = self.schema().await?;
let struct_val = schema.kind.expect("unwrapped an empty schema");
let dtypes = match struct_val {
spark::data_type::Kind::Struct(val) => val
.fields
.iter()
.map(|field| {
(
field.name.to_string(),
field.data_type.clone().unwrap().kind.unwrap(),
)
})
.collect(),
_ => unimplemented!("Unexpected schema response"),
};
Ok(dtypes)
}
pub fn except_all(self, other: DataFrame) -> DataFrame {
self.check_same_session(&other).unwrap();
let plan = self.plan.except_all(other.plan);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
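/// Prints and returns the string form of the query plan from the server; falls back to
/// `ExplainMode::Simple` when no mode is given.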
pub async fn explain(self, mode: Option<ExplainMode>) -> Result<String, SparkError> {
let explain_mode = match mode {
Some(mode) => mode,
None => ExplainMode::Simple,
};
let plan = LogicalPlanBuilder::plan_root(self.plan);
let analyze =
spark::analyze_plan_request::Analyze::Explain(spark::analyze_plan_request::Explain {
plan: Some(plan),
explain_mode: explain_mode.into(),
});
let mut client = self.spark_session.client();
let explain = client.analyze(analyze).await?.explain()?;
println!("{}", explain);
Ok(explain)
}
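/// Filters rows using the given condition, which may be a SQL expression string or a
/// `Column` expression (anything implementing `ToFilterExpr`).
///
/// Sketch (not compiled as a doctest), assuming a connected session and the mock data
/// used in the tests below:
///
/// ```ignore
/// let adults = df.clone().filter("age > 20").count().await?;
/// assert_eq!(adults, 1);
/// ```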
pub fn filter<T: ToFilterExpr>(self, condition: T) -> DataFrame {
let plan = self.plan.filter(condition);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub async fn first(self) -> Result<RecordBatch, SparkError> {
self.head(None).await
}
pub fn freq_items<'a, I>(self, cols: I, support: Option<f64>) -> DataFrame
where
I: IntoIterator<Item = &'a str>,
{
let plan = self.plan.freq_items(cols, support);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
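/// Groups the `DataFrame` using the given columns so aggregations can be run on the
/// resulting `GroupedData`; passing `None` groups over the whole frame, which is how
/// `agg` and `count` use it internally.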
pub fn group_by<T: ToVecExpr>(self, cols: Option<T>) -> GroupedData {
let grouping_cols = match cols {
Some(cols) => cols.to_vec_expr(),
None => vec![],
};
GroupedData::new(self, GroupType::Groupby, grouping_cols, None, None)
}
pub async fn head(self, n: Option<i32>) -> Result<RecordBatch, SparkError> {
self.limit(n.unwrap_or(1)).collect().await
}
pub fn hint<T: ToVecExpr>(self, name: &str, parameters: Option<T>) -> DataFrame {
let plan = self.plan.hint(name, parameters);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub async fn input_files(self) -> Result<Vec<String>, SparkError> {
let input_files = spark::analyze_plan_request::Analyze::InputFiles(
spark::analyze_plan_request::InputFiles {
plan: Some(LogicalPlanBuilder::plan_root(self.plan)),
},
);
let mut client = self.spark_session.client();
client.analyze(input_files).await?.input_files()
}
pub fn intersect(self, other: DataFrame) -> DataFrame {
self.check_same_session(&other).unwrap();
let plan = self.plan.intersect(other.plan);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn intersect_all(self, other: DataFrame) -> DataFrame {
self.check_same_session(&other).unwrap();
let plan = self.plan.intersect_all(other.plan);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub async fn is_empty(self) -> Result<bool, SparkError> {
let val = &self.select("*").limit(1).collect().await?;
Ok(val.num_rows() == 0)
}
pub async fn is_streaming(self) -> Result<bool, SparkError> {
let is_streaming = spark::analyze_plan_request::Analyze::IsStreaming(
spark::analyze_plan_request::IsStreaming {
plan: Some(LogicalPlanBuilder::plan_root(self.plan)),
},
);
let mut client = self.spark_session.client();
client.analyze(is_streaming).await?.is_streaming()
}
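/// Joins this `DataFrame` with `other` using the optional join expression `on` and the
/// given `JoinType`; use `cross_join` when no join condition applies.
///
/// Sketch (not compiled as a doctest), mirroring the alias/join pattern in the tests
/// below:
///
/// ```ignore
/// let joined = df1.join(
///     df2,
///     Some(col("df1.name").eq(col("df2.name"))),
///     JoinType::Inner,
/// );
/// ```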
pub fn join<T: ToExpr>(self, other: DataFrame, on: Option<T>, how: JoinType) -> DataFrame {
let plan = self.plan.join(other.plan, on, how, vec![]);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn limit(self, limit: i32) -> DataFrame {
let plan = self.plan.limit(limit);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn melt<I, K>(
self,
ids: I,
values: Option<K>,
variable_column_name: &str,
value_column_name: &str,
) -> DataFrame
where
I: ToVecExpr,
K: ToVecExpr,
{
self.unpivot(ids, values, variable_column_name, value_column_name)
}
pub fn offset(self, num: i32) -> DataFrame {
let plan = self.plan.offset(num);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn order_by<I>(self, cols: I) -> DataFrame
where
I: IntoIterator<Item = Column>,
{
let plan = self.plan.sort(cols);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
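/// Persists the `DataFrame` with the given `StorageLevel` by sending an analyze
/// `Persist` request to the server; the logical plan itself is returned unchanged.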
pub async fn persist(self, storage_level: storage::StorageLevel) -> DataFrame {
let analyze =
spark::analyze_plan_request::Analyze::Persist(spark::analyze_plan_request::Persist {
relation: Some(self.plan.clone().relation()),
storage_level: Some(storage_level.into()),
});
let mut client = self.spark_session.clone().client();
client.analyze(analyze).await.unwrap();
let plan = self.plan;
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub async fn print_schema(self, level: Option<i32>) -> Result<String, SparkError> {
let tree_string = spark::analyze_plan_request::Analyze::TreeString(
spark::analyze_plan_request::TreeString {
plan: Some(LogicalPlanBuilder::plan_root(self.plan)),
level,
},
);
let mut client = self.spark_session.client();
client.analyze(tree_string).await?.tree_string()
}
pub fn repartition(self, num_partitions: u32, shuffle: Option<bool>) -> DataFrame {
let plan = self.plan.repartition(num_partitions, shuffle);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn rollup<T: ToVecExpr>(self, cols: T) -> GroupedData {
GroupedData::new(self, GroupType::Rollup, cols.to_vec_expr(), None, None)
}
pub async fn same_semantics(self, other: DataFrame) -> Result<bool, SparkError> {
let target_plan = Some(LogicalPlanBuilder::plan_root(self.plan));
let other_plan = Some(LogicalPlanBuilder::plan_root(other.plan));
let same_semantics = spark::analyze_plan_request::Analyze::SameSemantics(
spark::analyze_plan_request::SameSemantics {
target_plan,
other_plan,
},
);
let mut client = self.spark_session.client();
client.analyze(same_semantics).await?.same_semantics()
}
pub fn sample(
self,
lower_bound: f64,
upper_bound: f64,
with_replacement: Option<bool>,
seed: Option<i64>,
) -> DataFrame {
let plan = self
.plan
.sample(lower_bound, upper_bound, with_replacement, seed);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub async fn schema(self) -> Result<spark::DataType, SparkError> {
let plan = LogicalPlanBuilder::plan_root(self.plan);
let schema =
spark::analyze_plan_request::Analyze::Schema(spark::analyze_plan_request::Schema {
plan: Some(plan),
});
let mut client = self.spark_session.client();
client.analyze(schema).await?.schema()
}
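/// Projects a set of expressions and returns a new `DataFrame`; accepts anything
/// implementing `ToVecExpr`, such as a single column name, a slice of names, or
/// `Column` expressions.
///
/// Sketch (not compiled as a doctest), matching the patterns exercised in the tests
/// below:
///
/// ```ignore
/// let a = df.clone().select("*").collect().await?;
/// let b = df.clone().select(["name", "age"]).collect().await?;
/// let c = df.select([col("name"), col("age")]).collect().await?;
/// ```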
pub fn select<T: ToVecExpr>(self, cols: T) -> DataFrame {
let plan = self.plan.select(cols);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn select_expr<'a, I>(self, cols: I) -> DataFrame
where
I: IntoIterator<Item = &'a str>,
{
let plan = self.plan.select_expr(cols);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub async fn semantic_hash(self) -> Result<i32, SparkError> {
let plan = LogicalPlanBuilder::plan_root(self.plan);
let semantic_hash = spark::analyze_plan_request::Analyze::SemanticHash(
spark::analyze_plan_request::SemanticHash { plan: Some(plan) },
);
let mut client = self.spark_session.client();
client.analyze(semantic_hash).await?.semantic_hash()
}
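/// Displays the `DataFrame` by asking the server for a `ShowString` relation and
/// pretty-printing the returned batch; `num_rows` defaults to 10, `truncate` to 0
/// (no truncation), and `vertical` to false.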
pub async fn show(
self,
num_rows: Option<i32>,
truncate: Option<i32>,
vertical: Option<bool>,
) -> Result<(), SparkError> {
let show_expr = RelType::ShowString(Box::new(spark::ShowString {
input: self.plan.relation_input(),
num_rows: num_rows.unwrap_or(10),
truncate: truncate.unwrap_or(0),
vertical: vertical.unwrap_or(false),
}));
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(show_expr));
let rows = self.spark_session.client().to_arrow(plan).await?;
Ok(pretty::print_batches(&[rows])?)
}
pub fn sort<I>(self, cols: I) -> DataFrame
where
I: IntoIterator<Item = Column>,
{
let plan = self.plan.sort(cols);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn spark_session(self) -> Box<SparkSession> {
self.spark_session
}
pub async fn storage_level(self) -> Result<storage::StorageLevel, SparkError> {
let storage_level = spark::analyze_plan_request::Analyze::GetStorageLevel(
spark::analyze_plan_request::GetStorageLevel {
relation: Some(self.plan.relation()),
},
);
let mut client = self.spark_session.client();
let storage = client.analyze(storage_level).await?.get_storage_level();
Ok(storage?.into())
}
pub fn subtract(self, other: DataFrame) -> DataFrame {
self.check_same_session(&other).unwrap();
let plan = self.plan.substract(other.plan);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub async fn tail(self, limit: i32) -> Result<RecordBatch, SparkError> {
let limit_expr = RelType::Tail(Box::new(spark::Tail {
input: self.plan.relation_input(),
limit,
}));
let plan = LogicalPlanBuilder::from(limit_expr);
let df = DataFrame {
spark_session: self.spark_session,
plan,
};
df.collect().await
}
pub async fn take(self, n: i32) -> Result<RecordBatch, SparkError> {
self.limit(n).collect().await
}
pub fn to_df<'a, I>(self, cols: I) -> DataFrame
where
I: IntoIterator<Item = &'a str>,
{
let plan = self.plan.to_df(cols);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub async fn to_json(self) -> Result<String, SparkError> {
if self.clone().is_empty().await? {
return Ok(String::from("[]"));
};
let batches = self.collect().await?;
let buf = Vec::new();
let mut writer = ArrayWriter::new(buf);
writer.write_batches(&[&batches])?;
writer.finish()?;
let res = String::from_utf8_lossy(&writer.into_inner()).into_owned();
Ok(res)
}
#[cfg(feature = "datafusion")]
#[cfg(any(feature = "default", feature = "datafusion"))]
pub async fn to_datafusion(
self,
ctx: &SessionContext,
) -> Result<datafusion::dataframe::DataFrame, SparkError> {
let batch = self.collect().await?;
Ok(ctx.read_batch(batch)?)
}
#[cfg(feature = "polars")]
#[cfg(any(feature = "default", feature = "polars"))]
pub async fn to_polars(self) -> Result<polars::frame::DataFrame, SparkError> {
let batch = self.collect().await?;
let schema = batch.schema();
let mut columns = Vec::with_capacity(batch.num_columns());
for (i, column) in batch.columns().iter().enumerate() {
let arrow = Box::<dyn polars_arrow::array::Array>::from(&**column);
columns.push(polars::series::Series::from_arrow(
schema.fields().get(i).unwrap().name(),
arrow,
)?);
}
Ok(polars::frame::DataFrame::from_iter(columns))
}
pub fn transform<F>(self, mut func: F) -> DataFrame
where
F: FnMut(DataFrame) -> DataFrame,
{
func(self)
}
pub fn union(self, other: DataFrame) -> DataFrame {
self.check_same_session(&other).unwrap();
let plan = self.plan.union_all(other.plan);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn union_all(self, other: DataFrame) -> DataFrame {
self.check_same_session(&other).unwrap();
let plan = self.plan.union_all(other.plan);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn union_by_name(self, other: DataFrame, allow_missing_columns: Option<bool>) -> DataFrame {
self.check_same_session(&other).unwrap();
let plan = self.plan.union_by_name(other.plan, allow_missing_columns);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub async fn unpersist(self, blocking: Option<bool>) -> DataFrame {
let unpersist = spark::analyze_plan_request::Analyze::Unpersist(
spark::analyze_plan_request::Unpersist {
relation: Some(self.plan.clone().relation()),
blocking,
},
);
let mut client = self.spark_session.clone().client();
client.analyze(unpersist).await.unwrap();
DataFrame {
spark_session: self.spark_session,
plan: self.plan,
}
}
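/// Unpivots columns from wide to long format: the `ids` columns are kept, and the
/// `values` columns (all remaining columns when `None`, following Spark's unpivot
/// semantics) are folded into a `variable_column_name` / `value_column_name` pair.
/// `melt` is an alias for this method.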
pub fn unpivot<I, K>(
self,
ids: I,
values: Option<K>,
variable_column_name: &str,
value_column_name: &str,
) -> DataFrame
where
I: ToVecExpr,
K: ToVecExpr,
{
let ids = ids.to_vec_expr();
let values = values.map(|val| val.to_vec_expr());
let plan = self
.plan
.unpivot(ids, values, variable_column_name, value_column_name);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn with_column(self, col_name: &str, col: Column) -> DataFrame {
let plan = self.plan.with_column(col_name, col);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn with_columns<I, K>(self, col_map: I) -> DataFrame
where
I: IntoIterator<Item = (K, Column)>,
K: ToString,
{
let plan = self.plan.with_columns(col_map);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
pub fn with_columns_renamed<I, K, V>(self, cols: I) -> DataFrame
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<str>,
V: AsRef<str>,
{
let plan = self.plan.with_columns_renamed(cols);
DataFrame {
spark_session: self.spark_session,
plan,
}
}
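/// Returns a `DataFrameWriter` for saving this `DataFrame` to external storage in batch
/// mode; see `write_to` for the V2 table writer and `write_stream` for streaming sinks.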
pub fn write(self) -> DataFrameWriter {
DataFrameWriter::new(self)
}
pub fn write_to(self, table: &str) -> DataFrameWriterV2 {
DataFrameWriterV2::new(self, table)
}
pub fn write_stream(self) -> DataStreamWriter {
DataStreamWriter::new(self)
}
}
#[cfg(test)]
mod tests {
use arrow::{
array::{ArrayRef, Float32Array, Float64Array, Int64Array, StringArray},
datatypes::{DataType, Field, Schema},
record_batch::RecordBatch,
};
use std::{collections::HashMap, sync::Arc};
use super::*;
use crate::functions::*;
use crate::SparkSessionBuilder;
async fn setup() -> SparkSession {
println!("SparkSession Setup");
let connection =
"sc://127.0.0.1:15002/;user_id=rust_df;session_id=b5714cb4-6bb4-4c02-90b1-b9b93c70b323";
SparkSessionBuilder::remote(connection)
.build()
.await
.unwrap()
}
fn mock_data() -> RecordBatch {
let name: ArrayRef = Arc::new(StringArray::from(vec!["Tom", "Alice", "Bob"]));
let age: ArrayRef = Arc::new(Int64Array::from(vec![14, 23, 16]));
RecordBatch::try_from_iter(vec![("name", name), ("age", age)]).unwrap()
}
#[tokio::test]
async fn test_df_alias() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let df_as1 = df.clone().alias("df_as1");
let df_as2 = df.alias("df_as2");
let condition = Some(col("df_as1.name").eq(col("df_as2.name")));
let joined_df = df_as1.join(df_as2, condition, JoinType::Inner);
let name: ArrayRef = Arc::new(StringArray::from(vec!["Tom", "Alice", "Bob"]));
let age: ArrayRef = Arc::new(Int64Array::from(vec![14, 23, 16]));
let expected =
RecordBatch::try_from_iter(vec![("name", name.clone()), ("name", name), ("age", age)])?;
let res = joined_df
.clone()
.select(["df_as1.name", "df_as2.name", "df_as2.age"])
.sort([asc("df_as1.name")])
.collect()
.await?;
assert_eq!(expected, res);
Ok(())
}
#[tokio::test]
async fn test_df_cache() -> Result<(), SparkError> {
let spark = setup().await;
let df = spark.range(None, 2, 1, None);
df.clone().cache().await;
let exp = df.clone().explain(None).await?;
assert!(exp.contains("InMemoryTableScan"));
Ok(())
}
#[tokio::test]
async fn test_df_coalesce() -> Result<(), SparkError> {
let spark = setup().await;
let val = spark
.range(None, 10, 1, Some(5))
.coalesce(1)
.select(spark_partition_id().alias("partition"))
.distinct()
.collect()
.await?;
assert_eq!(1, val.num_rows());
Ok(())
}
#[tokio::test]
async fn test_df_colregex() -> Result<(), SparkError> {
let spark = setup().await;
let col1: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
let col2: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
let data = RecordBatch::try_from_iter(vec![("col1", col1), ("col2", col2)])?;
let df = spark.create_dataframe(&data)?;
let res = df
.clone()
.select(df.col_regex("`(Col1)?+.+`"))
.columns()
.await?;
assert_eq!(vec!["col2".to_string(),], res);
Ok(())
}
#[tokio::test]
async fn test_df_columns() -> Result<(), SparkError> {
let spark = setup().await;
let age: ArrayRef = Arc::new(Int64Array::from(vec![14, 23, 16]));
let name: ArrayRef = Arc::new(StringArray::from(vec!["Tom", "Alice", "Bob"]));
let state: ArrayRef = Arc::new(StringArray::from(vec!["CA", "NY", "TX"]));
let data =
RecordBatch::try_from_iter(vec![("age", age), ("name", name), ("state", state)])?;
let df = spark.create_dataframe(&data)?;
let cols = df.clone().columns().await?;
assert_eq!(
vec!["age".to_string(), "name".to_string(), "state".to_string()],
cols
);
let select_cols: Vec<String> = cols.into_iter().filter(|c| c != "age").collect();
let cols = df.select(select_cols.clone()).columns().await?;
assert_eq!(select_cols, cols);
Ok(())
}
#[tokio::test]
async fn test_df_corr() -> Result<(), SparkError> {
let spark = setup().await;
let c1: ArrayRef = Arc::new(Int64Array::from(vec![1, 10, 19]));
let c2: ArrayRef = Arc::new(Int64Array::from(vec![12, 1, 8]));
let data = RecordBatch::try_from_iter(vec![("c1", c1), ("c2", c2)])?;
let val = spark.create_dataframe(&data)?.corr("c1", "c2").await?;
assert_eq!(-0.3592106040535498_f64, val);
Ok(())
}
#[tokio::test]
async fn test_df_count() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
assert_eq!(3, df.count().await?);
Ok(())
}
#[tokio::test]
async fn test_df_cov() -> Result<(), SparkError> {
let spark = setup().await;
let c1: ArrayRef = Arc::new(Int64Array::from(vec![1, 10, 19]));
let c2: ArrayRef = Arc::new(Int64Array::from(vec![12, 1, 8]));
let data = RecordBatch::try_from_iter(vec![("c1", c1), ("c2", c2)])?;
let val = spark.create_dataframe(&data)?.cov("c1", "c2").await?;
assert_eq!(-18.0_f64, val);
let small: ArrayRef = Arc::new(Int64Array::from(vec![11, 10, 9]));
let big: ArrayRef = Arc::new(Int64Array::from(vec![12, 11, 10]));
let data = RecordBatch::try_from_iter(vec![("small", small), ("big", big)])?;
let val = spark.create_dataframe(&data)?.cov("small", "big").await?;
assert_eq!(1.0_f64, val);
Ok(())
}
#[tokio::test]
async fn test_df_view() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
spark
.create_dataframe(&data)?
.create_or_replace_global_temp_view("people")
.await?;
let rows = spark
.sql("SELECT * FROM global_temp.people")
.await?
.collect()
.await?;
assert_eq!(rows, data);
Ok(())
}
#[tokio::test]
async fn test_df_crosstab() -> Result<(), SparkError> {
let spark = setup().await;
let df = spark
.range(None, 5, 1, Some(1))
.select(vec![col("id").alias("col1"), col("id").alias("col2")])
.crosstab("col1", "col2");
let res = df.clone().collect().await?;
assert!(df.columns().await?.contains(&"col1_col2".to_string()));
assert_eq!(6, res.num_columns());
Ok(())
}
#[tokio::test]
async fn test_df_crossjoin() -> Result<(), SparkError> {
let spark = setup().await;
let name: ArrayRef = Arc::new(StringArray::from(vec!["Tom", "Alice", "Bob"]));
let age: ArrayRef = Arc::new(Int64Array::from(vec![14, 23, 16]));
let height: ArrayRef = Arc::new(Int64Array::from(vec![60, 55, 63]));
let data = RecordBatch::try_from_iter(vec![("name", name.clone()), ("age", age)])?;
let data2 = RecordBatch::try_from_iter(vec![("name", name), ("height", height)])?;
let df = spark.create_dataframe(&data)?;
let df2 = spark.create_dataframe(&data2)?;
let rows = df
.cross_join(df2.select("height"))
.select(["age", "name", "height"])
.collect()
.await?;
let name: ArrayRef = Arc::new(StringArray::from(vec![
"Tom", "Tom", "Tom", "Alice", "Alice", "Alice", "Bob", "Bob", "Bob",
]));
let age: ArrayRef = Arc::new(Int64Array::from(vec![14, 14, 14, 23, 23, 23, 16, 16, 16]));
let height: ArrayRef = Arc::new(Int64Array::from(vec![60, 55, 63, 60, 55, 63, 60, 55, 63]));
let data =
RecordBatch::try_from_iter(vec![("age", age), ("name", name), ("height", height)])?;
assert_eq!(data, rows);
Ok(())
}
#[tokio::test]
async fn test_df_describe() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let res = spark
.create_dataframe(&data)?
.describe(Some(["age"]))
.collect()
.await?;
let summary: ArrayRef = Arc::new(StringArray::from(vec![
"count", "mean", "stddev", "min", "max",
]));
let age: ArrayRef = Arc::new(StringArray::from(vec![
"3",
"17.666666666666668",
"4.725815626252608",
"14",
"23",
]));
let schema = Schema::new(vec![
Field::new("summary", DataType::Utf8, true),
Field::new("age", DataType::Utf8, true),
]);
let expected = RecordBatch::try_new(Arc::new(schema), vec![summary, age])?;
assert_eq!(expected, res);
Ok(())
}
#[tokio::test]
async fn test_df_distinct() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let val = spark.create_dataframe(&data)?.distinct().count().await?;
assert_eq!(3_i64, val);
Ok(())
}
#[tokio::test]
async fn test_df_drop() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let cols = df.clone().drop("age").columns().await?;
assert_eq!(vec![String::from("name")], cols);
let cols = df
.clone()
.with_column("val", lit(1))
.drop([col("age"), col("name")])
.columns()
.await?;
assert_eq!(vec![String::from("val")], cols);
Ok(())
}
#[tokio::test]
async fn test_df_drop_duplicates() -> Result<(), SparkError> {
let spark = setup().await;
let name: ArrayRef = Arc::new(StringArray::from(vec!["Alice", "Alice", "Alice"]));
let age: ArrayRef = Arc::new(Int64Array::from(vec![5, 5, 10]));
let height: ArrayRef = Arc::new(Int64Array::from(vec![80, 80, 80]));
let data =
RecordBatch::try_from_iter(vec![("name", name), ("age", age), ("height", height)])?;
let df = spark.create_dataframe(&data)?;
let res = df.clone().drop_duplicates(None).count().await?;
assert_eq!(res, 2);
let res = df
.drop_duplicates(Some(vec!["name", "height"]))
.count()
.await?;
assert_eq!(res, 1);
Ok(())
}
#[tokio::test]
async fn test_df_dropna() -> Result<(), SparkError> {
let spark = setup().await;
let name: ArrayRef = Arc::new(StringArray::from(vec![
Some("Alice"),
Some("Bob"),
Some("Tom"),
None,
]));
let age: ArrayRef = Arc::new(Int64Array::from(vec![Some(10), Some(5), None, None]));
let height: ArrayRef = Arc::new(Int64Array::from(vec![Some(80), None, None, None]));
let schema = Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("age", DataType::Int64, true),
Field::new("height", DataType::Int64, true),
]);
let data = RecordBatch::try_new(Arc::new(schema), vec![name, age, height])?;
let df = spark.create_dataframe(&data)?;
let res = df.clone().dropna("any", None, None).count().await?;
assert_eq!(res, 1);
let res = df.clone().dropna("all", None, None).count().await?;
assert_eq!(res, 3);
let res = df
.clone()
.dropna("any", None, Some(vec!["name"]))
.count()
.await?;
assert_eq!(res, 3);
Ok(())
}
#[tokio::test]
async fn test_df_dtypes() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let res = df.dtypes().await?;
assert!(&res.iter().any(|(col, _kind)| col == "name"));
Ok(())
}
#[tokio::test]
async fn test_df_except_all() -> Result<(), SparkError> {
let spark = setup().await;
let c1: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 10, 19]));
let c2: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 2, 8]));
let data = RecordBatch::try_from_iter(vec![("c1", c1), ("c2", c2)])?;
let c1: ArrayRef = Arc::new(Int64Array::from(vec![1, 19]));
let c2: ArrayRef = Arc::new(Int64Array::from(vec![1, 8]));
let data2 = RecordBatch::try_from_iter(vec![("c1", c1), ("c2", c2)])?;
let df1 = spark.create_dataframe(&data)?;
let df2 = spark.create_dataframe(&data2)?;
let output = df1.except_all(df2).collect().await?;
let c1: ArrayRef = Arc::new(Int64Array::from(vec![1, 10]));
let c2: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
let expected = RecordBatch::try_from_iter(vec![("c1", c1), ("c2", c2)])?;
assert_eq!(output, expected);
Ok(())
}
#[tokio::test]
async fn test_df_explain() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let val = df.explain(None).await?;
assert!(val.contains("== Physical Plan =="));
Ok(())
}
#[tokio::test]
async fn test_df_filter() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let output = df.clone().filter("age > 20").count().await?;
assert_eq!(output, 1);
let output = df.filter("age == 16").count().await?;
assert_eq!(output, 1);
Ok(())
}
#[tokio::test]
async fn test_df_first() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let val = df.first().await?;
let name: ArrayRef = Arc::new(StringArray::from(vec!["Tom"]));
let age: ArrayRef = Arc::new(Int64Array::from(vec![14]));
let expected = RecordBatch::try_from_iter(vec![("name", name), ("age", age)])?;
assert_eq!(val, expected);
Ok(())
}
#[tokio::test]
async fn test_df_group_by() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let val = df
.clone()
.group_by::<Column>(None)
.avg("age")
.collect()
.await?;
let age: ArrayRef = Arc::new(Float64Array::from(vec![17.666666666666668]));
let schema = Schema::new(vec![Field::new("avg(age)", DataType::Float64, true)]);
let expected = RecordBatch::try_new(Arc::new(schema), vec![age])?;
assert_eq!(val, expected);
let val = df
.clone()
.group_by::<Column>(None)
.max("age")
.collect()
.await?;
let age: ArrayRef = Arc::new(Int64Array::from(vec![23]));
let schema = Schema::new(vec![Field::new("max(age)", DataType::Int64, true)]);
let expected = RecordBatch::try_new(Arc::new(schema), vec![age])?;
assert_eq!(val, expected);
let val = df
.clone()
.group_by::<Column>(None)
.sum("age")
.collect()
.await?;
let age: ArrayRef = Arc::new(Int64Array::from(vec![53]));
let schema = Schema::new(vec![Field::new("sum(age)", DataType::Int64, true)]);
let expected = RecordBatch::try_new(Arc::new(schema), vec![age])?;
assert_eq!(val, expected);
Ok(())
}
#[tokio::test]
async fn test_df_head() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let val = df.head(None).await?;
let name: ArrayRef = Arc::new(StringArray::from(vec!["Tom"]));
let age: ArrayRef = Arc::new(Int64Array::from(vec![14]));
let expected = RecordBatch::try_from_iter(vec![("name", name), ("age", age)])?;
assert_eq!(val, expected);
Ok(())
}
#[tokio::test]
async fn test_df_hint() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?.alias("df1");
let df2 = spark.create_dataframe(&data)?.alias("df2");
let df = df.join(
df2.hint::<Column>("broadcast", None),
Some(col("df1.name").eq("df2.name")),
JoinType::Inner,
);
let plan = df.explain(Some(ExplainMode::Extended)).await?;
assert!(plan.contains("UnresolvedHint broadcast"));
Ok(())
}
#[tokio::test]
async fn test_df_input_files() -> Result<(), SparkError> {
let spark = setup().await;
let path = ["/opt/spark/work-dir/datasets/people.csv"];
let df = spark
.read()
.format("csv")
.option("header", "True")
.option("delimiter", ";")
.load(path)?;
let res = df.input_files().await?;
assert_eq!(res.len(), 1);
Ok(())
}
#[tokio::test]
async fn test_df_intersect() -> Result<(), SparkError> {
let spark = setup().await;
let c1: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 10, 19]));
let c2: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 2, 8]));
let data = RecordBatch::try_from_iter(vec![("c1", c1), ("c2", c2)])?;
let c1: ArrayRef = Arc::new(Int64Array::from(vec![1, 19]));
let c2: ArrayRef = Arc::new(Int64Array::from(vec![1, 8]));
let data2 = RecordBatch::try_from_iter(vec![("c1", c1), ("c2", c2)])?;
let df1 = spark.create_dataframe(&data)?;
let df2 = spark.create_dataframe(&data2)?;
let output = df1.intersect(df2).collect().await?;
let c1: ArrayRef = Arc::new(Int64Array::from(vec![1, 19]));
let c2: ArrayRef = Arc::new(Int64Array::from(vec![1, 8]));
let expected = RecordBatch::try_from_iter(vec![("c1", c1), ("c2", c2)])?;
assert_eq!(output, expected);
Ok(())
}
#[tokio::test]
async fn test_df_is_empty() -> Result<(), SparkError> {
let spark = setup().await;
let records: ArrayRef = Arc::new(Int64Array::from(vec![] as Vec<i64>));
let schema = Schema::new(vec![Field::new("record", DataType::Int64, true)]);
let data = RecordBatch::try_new(Arc::new(schema), vec![records])?;
let df = spark.create_dataframe(&data)?;
assert!(df.is_empty().await?);
let records: ArrayRef = Arc::new(Int64Array::from(vec![None]));
let schema = Schema::new(vec![Field::new("record", DataType::Int64, true)]);
let data = RecordBatch::try_new(Arc::new(schema), vec![records])?;
let df = spark.create_dataframe(&data)?;
assert!(!df.is_empty().await?);
let records: ArrayRef = Arc::new(Int64Array::from(vec![1]));
let schema = Schema::new(vec![Field::new("record", DataType::Int64, true)]);
let data = RecordBatch::try_new(Arc::new(schema), vec![records])?;
let df = spark.create_dataframe(&data)?;
assert!(!df.is_empty().await?);
Ok(())
}
#[tokio::test]
async fn test_df_join() -> Result<(), SparkError> {
let spark = setup().await;
let data1 = mock_data();
let name: ArrayRef = Arc::new(StringArray::from(vec!["Tom", "Bob"]));
let height: ArrayRef = Arc::new(Int64Array::from(vec![80, 85]));
let data2 = RecordBatch::try_from_iter(vec![("name", name), ("height", height)])?;
let name: ArrayRef = Arc::new(StringArray::from(vec!["Alice", "Bob"]));
let age: ArrayRef = Arc::new(Int64Array::from(vec![23, 16]));
let data3 = RecordBatch::try_from_iter(vec![("name", name), ("age", age)])?;
let name: ArrayRef = Arc::new(StringArray::from(vec![
Some("Alice"),
Some("Bob"),
Some("Tom"),
None,
]));
let age: ArrayRef = Arc::new(Int64Array::from(vec![Some(18), Some(16), None, None]));
let height: ArrayRef = Arc::new(Int64Array::from(vec![Some(80), None, None, None]));
let schema = Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("age", DataType::Int64, true),
Field::new("height", DataType::Int64, true),
]);
let data4 = RecordBatch::try_new(Arc::new(schema), vec![name, age, height])?;
let df1 = spark.create_dataframe(&data1)?.alias("df1");
let df2 = spark.create_dataframe(&data2)?.alias("df2");
let df3 = spark.create_dataframe(&data3)?.alias("df3");
let df4 = spark.create_dataframe(&data4)?.alias("df4");
let condition = Some(col("df1.name").eq(col("df2.name")));
let res = df1
.clone()
.join(df2.clone(), condition, JoinType::Inner)
.select(["df1.name", "df2.height"])
.collect()
.await?;
assert_eq!(2, res.num_rows());
let name = col("df1.name").eq(col("df4.name"));
let age = col("df1.age").eq(col("df4.age"));
let condition = Some(name.and(age));
let res = df1
.clone()
.join(df4.clone(), condition, JoinType::Inner)
.collect()
.await?;
assert_eq!(1, res.num_rows());
let condition = Some(col("df1.name").eq(col("df2.name")));
let res = df1
.clone()
.join(df2.clone(), condition, JoinType::FullOuter)
.select(["df1.name", "df2.height"])
.collect()
.await?;
assert_eq!(3, res.num_rows());
let name = col("df1.name").eq(col("df3.name"));
let age = col("df1.age").eq(col("df3.age"));
let condition = Some(name.and(age));
let res = df1
.clone()
.join(df3.clone(), condition, JoinType::FullOuter)
.select(["df1.name", "df3.age"])
.collect()
.await?;
assert_eq!(3, res.num_rows());
Ok(())
}
#[tokio::test]
async fn test_df_limit() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let val = df.clone().limit(1).collect().await?;
let name: ArrayRef = Arc::new(StringArray::from(vec!["Tom"]));
let age: ArrayRef = Arc::new(Int64Array::from(vec![14]));
let expected = RecordBatch::try_from_iter(vec![("name", name), ("age", age)])?;
assert_eq!(val, expected);
let val = df.clone().limit(0).collect().await?;
assert_eq!(0, val.num_rows());
Ok(())
}
#[tokio::test]
async fn test_df_select() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let val = df.clone().select("*").collect().await?;
assert_eq!(2, val.num_columns());
let val = df.clone().select("name").collect().await?;
assert_eq!(1, val.num_columns());
let val = df.clone().select(["name", "age"]).collect().await?;
assert_eq!(2, val.num_columns());
let val = df.clone().select(vec!["name", "age"]).collect().await?;
assert_eq!(2, val.num_columns());
let val = df
.clone()
.select([col("name"), col("age")])
.collect()
.await?;
assert_eq!(2, val.num_columns());
Ok(())
}
#[tokio::test]
async fn test_df_select_expr() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let val = df.select_expr(["age * 2", "abs(age)"]).collect().await?;
let age: ArrayRef = Arc::new(Int64Array::from(vec![14, 23, 16]));
let age2: ArrayRef = Arc::new(Int64Array::from(vec![28, 46, 32]));
let expected = RecordBatch::try_from_iter(vec![("(age * 2)", age2), ("abs(age)", age)])?;
assert_eq!(val, expected);
Ok(())
}
#[tokio::test]
async fn test_df_with_columns() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let cols = [("age2", col("age") + lit(2)), ("age3", col("age") + lit(3))];
let val = df
.clone()
.with_columns(cols)
.select(["name", "age", "age2", "age3"])
.collect()
.await?;
let name: ArrayRef = Arc::new(StringArray::from(vec!["Tom", "Alice", "Bob"]));
let age: ArrayRef = Arc::new(Int64Array::from(vec![14, 23, 16]));
let age2: ArrayRef = Arc::new(Int64Array::from(vec![16, 25, 18]));
let age3: ArrayRef = Arc::new(Int64Array::from(vec![17, 26, 19]));
let expected = RecordBatch::try_from_iter(vec![
("name", name),
("age", age),
("age2", age2),
("age3", age3),
])?;
assert_eq!(&val, &expected);
let cols = HashMap::from([("age2", col("age") + lit(2)), ("age3", col("age") + lit(3))]);
let val = df
.clone()
.with_columns(cols)
.select(["name", "age", "age2", "age3"])
.collect()
.await?;
assert_eq!(&val, &expected);
Ok(())
}
#[tokio::test]
async fn test_df_sort() -> Result<(), SparkError> {
let spark = setup().await;
let df = spark.range(None, 100, 1, Some(1));
let rows = df
.clone()
.sort([col("id").desc()])
.limit(1)
.collect()
.await?;
let a: ArrayRef = Arc::new(Int64Array::from(vec![99]));
let expected_batch = RecordBatch::try_from_iter(vec![("id", a)])?;
assert_eq!(expected_batch, rows);
let rows = df.sort([col("id")]).limit(1).collect().await?;
let a: ArrayRef = Arc::new(Int64Array::from(vec![0]));
let expected_batch = RecordBatch::try_from_iter(vec![("id", a)])?;
assert_eq!(expected_batch, rows);
Ok(())
}
#[tokio::test]
async fn test_df_unpivot() -> Result<(), SparkError> {
let spark = setup().await;
let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
let ints: ArrayRef = Arc::new(Int64Array::from(vec![11, 12]));
let floats: ArrayRef = Arc::new(Float32Array::from(vec![1.1, 1.2]));
let data = RecordBatch::try_from_iter(vec![("id", ids), ("int", ints), ("float", floats)])?;
let df = spark.create_dataframe(&data)?;
let df = df.unpivot("id", Some(["int", "float"]), "var", "val");
let res = df.collect().await?;
let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 2, 2]));
let var: ArrayRef = Arc::new(StringArray::from(vec!["int", "float", "int", "float"]));
let val: ArrayRef = Arc::new(Float32Array::from(vec![11.0, 1.1, 12.0, 1.2]));
let expected = RecordBatch::try_from_iter(vec![("id", ids), ("var", var), ("val", val)])?;
assert_eq!(expected, res);
Ok(())
}
#[tokio::test]
async fn test_df_transform() -> Result<(), SparkError> {
let spark = setup().await;
let df = spark.range(None, 1, 1, None);
let val: i64 = 100;
let func =
|df: DataFrame| -> DataFrame { df.with_column("new_col", lit(val)).select("new_col") };
let res = df.transform(func).collect().await?;
let col: ArrayRef = Arc::new(Int64Array::from(vec![val]));
let expected = RecordBatch::try_from_iter(vec![("new_col", col)])?;
assert_eq!(expected, res);
Ok(())
}
#[tokio::test]
async fn test_df_to_json() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let val = df.to_json().await?;
let expected = String::from("[{\"name\":\"Tom\",\"age\":14},{\"name\":\"Alice\",\"age\":23},{\"name\":\"Bob\",\"age\":16}]");
assert_eq!(expected, val);
let df = spark.range(Some(0), 0, 1, None);
let val = df.to_json().await?;
assert_eq!(String::from("[]"), val);
Ok(())
}
#[tokio::test]
#[cfg(feature = "datafusion")]
async fn test_df_to_datafusion() -> Result<(), SparkError> {
use datafusion::prelude::SessionContext;
let spark = setup().await;
let ctx = SessionContext::new();
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let df_output = df.to_datafusion(&ctx).await?.collect().await?;
let df_expected = ctx.read_batch(data)?.collect().await?;
assert_eq!(df_expected, df_output);
let df = spark.range(Some(0), 0, 1, None);
let val = df.to_datafusion(&ctx).await?.collect().await?;
assert_eq!(0, val[0].num_rows());
Ok(())
}
#[tokio::test]
#[cfg(feature = "polars")]
async fn test_df_to_polars() -> Result<(), SparkError> {
let spark = setup().await;
let data = mock_data();
let schema = data.schema();
let mut columns = Vec::with_capacity(data.num_columns());
for (i, column) in data.columns().iter().enumerate() {
let arrow = Box::<dyn polars_arrow::array::Array>::from(&**column);
columns.push(polars::series::Series::from_arrow(
schema.fields().get(i).unwrap().name(),
arrow,
)?);
}
let df_expected = polars::frame::DataFrame::from_iter(columns);
let df = spark.create_dataframe(&data)?;
let df_output = df.to_polars().await?;
assert_eq!(df_expected, df_output);
Ok(())
}
#[tokio::test]
async fn test_df_explain_concurrent() -> Result<(), SparkError> {
let spark = setup().await;
let spark_clone = spark.clone();
let data = mock_data();
let df = spark.create_dataframe(&data)?;
let df_clone = spark_clone.create_dataframe(&data)?;
let (res, res_clone) = futures::join!(df.explain(None), df_clone.explain(None));
let (val, val_clone) = (res?, res_clone?);
assert!(val.contains("== Physical Plan =="));
assert!(val_clone.contains("== Physical Plan =="));
Ok(())
}
}