use arrow::array::RecordBatch;
use crate::errors::SparkError;
use crate::plan::LogicalPlanBuilder;
use crate::session::SparkSession;
use crate::spark;
use crate::storage::StorageLevel;
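
/// User-facing catalog API for working with Spark metadata: catalogs,
/// databases, tables, functions, columns, and table caching.
///
/// Obtained from a session via `SparkSession::catalog`.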
#[derive(Debug, Clone)]
pub struct Catalog {
spark_session: SparkSession,
}
impl Catalog {
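    /// Creates a new `Catalog` backed by the given [`SparkSession`].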
pub fn new(spark_session: SparkSession) -> Self {
Self { spark_session }
}
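
    // Extracts the single boolean value that catalog commands such as
    // `DatabaseExists`, `TableExists`, and `IsCached` return in the first
    // column of a one-row RecordBatch.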
fn arrow_to_bool(record: RecordBatch) -> Result<bool, SparkError> {
let col = record.column(0);
let data: &arrow::array::BooleanArray = match col.data_type() {
arrow::datatypes::DataType::Boolean => col.as_any().downcast_ref().unwrap(),
            _ => unimplemented!("only Boolean data types are currently handled"),
};
Ok(data.value(0))
}
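
    /// Returns the name of the current catalog in this session.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `spark` is an active remote [`SparkSession`]:
    /// ```ignore
    /// let name = spark.catalog().current_catalog().await?;
    /// assert_eq!(name, "spark_catalog"); // the default catalog name
    /// ```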
pub async fn current_catalog(self) -> Result<String, SparkError> {
let cat_type = Some(spark::catalog::CatType::CurrentCatalog(
spark::CurrentCatalog {},
));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().to_first_value(plan).await
}
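
    /// Sets the current default catalog for this session.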
pub async fn set_current_catalog(self, catalog_name: &str) -> Result<(), SparkError> {
let cat_type = Some(spark::catalog::CatType::SetCurrentCatalog(
spark::SetCurrentCatalog {
catalog_name: catalog_name.to_string(),
},
));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().execute_command(plan).await
}
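
    /// Returns the catalogs available in this session as a [`RecordBatch`],
    /// optionally filtered by a name `pattern`.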
pub async fn list_catalogs(self, pattern: Option<&str>) -> Result<RecordBatch, SparkError> {
let pattern = pattern.map(|val| val.to_owned());
let cat_type = Some(spark::catalog::CatType::ListCatalogs(spark::ListCatalogs {
pattern,
}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().to_arrow(plan).await
}
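
    /// Returns the name of the current default database in this session.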
pub async fn current_database(self) -> Result<String, SparkError> {
let cat_type = Some(spark::catalog::CatType::CurrentDatabase(
spark::CurrentDatabase {},
));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().to_first_value(plan).await
}
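
    /// Sets the current default database for this session.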
pub async fn set_current_database(self, db_name: &str) -> Result<(), SparkError> {
let cat_type = Some(spark::catalog::CatType::SetCurrentDatabase(
spark::SetCurrentDatabase {
db_name: db_name.to_string(),
},
));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().execute_command(plan).await
}
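
    /// Returns the databases available in this session as a [`RecordBatch`],
    /// optionally filtered by a name `pattern`.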
pub async fn list_databases(self, pattern: Option<&str>) -> Result<RecordBatch, SparkError> {
let pattern = pattern.map(|val| val.to_owned());
let cat_type = Some(spark::catalog::CatType::ListDatabases(
spark::ListDatabases { pattern },
));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().to_arrow(plan).await
}
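
    /// Returns metadata for the database `db_name` as a single-row [`RecordBatch`].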
pub async fn get_database(self, db_name: &str) -> Result<RecordBatch, SparkError> {
let cat_type = Some(spark::catalog::CatType::GetDatabase(spark::GetDatabase {
db_name: db_name.to_string(),
}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().to_arrow(plan).await
}
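
    /// Returns `true` if a database with the given name exists.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `spark` is an active remote [`SparkSession`]:
    /// ```ignore
    /// assert!(spark.catalog().database_exists("default").await?);
    /// assert!(!spark.catalog().database_exists("not_real").await?);
    /// ```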
pub async fn database_exists(self, db_name: &str) -> Result<bool, SparkError> {
let cat_type = Some(spark::catalog::CatType::DatabaseExists(
spark::DatabaseExists {
db_name: db_name.to_string(),
},
));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
let record = self.spark_session.client().to_arrow(plan).await?;
Catalog::arrow_to_bool(record)
}
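
    /// Returns the tables and views in `db_name` (the current database when
    /// `None`) as a [`RecordBatch`], optionally filtered by a name `pattern`.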
pub async fn list_tables(
self,
pattern: Option<&str>,
db_name: Option<&str>,
) -> Result<RecordBatch, SparkError> {
let cat_type = Some(spark::catalog::CatType::ListTables(spark::ListTables {
db_name: db_name.map(|db| db.to_owned()),
pattern: pattern.map(|val| val.to_owned()),
}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().to_arrow(plan).await
}
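
    /// Returns metadata for the table or view `table_name`, resolved against
    /// the current database.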
pub async fn get_table(self, table_name: &str) -> Result<RecordBatch, SparkError> {
let cat_type = Some(spark::catalog::CatType::GetTable(spark::GetTable {
table_name: table_name.to_string(),
db_name: None,
}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().to_arrow(plan).await
}
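
    /// Returns the functions registered in `db_name` (the current database
    /// when `None`) as a [`RecordBatch`], optionally filtered by a name `pattern`.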
pub async fn list_functions(
self,
db_name: Option<&str>,
pattern: Option<&str>,
) -> Result<RecordBatch, SparkError> {
let cat_type = Some(spark::catalog::CatType::ListFunctions(
spark::ListFunctions {
db_name: db_name.map(|val| val.to_owned()),
pattern: pattern.map(|val| val.to_owned()),
},
));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().to_arrow(plan).await
}
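
    /// Returns `true` if a function with the given name exists in `db_name`
    /// (the current database when `None`).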
pub async fn function_exists(
self,
function_name: &str,
db_name: Option<&str>,
) -> Result<bool, SparkError> {
let cat_type = Some(spark::catalog::CatType::FunctionExists(
spark::FunctionExists {
function_name: function_name.to_string(),
db_name: db_name.map(|val| val.to_owned()),
},
));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
let record = self.spark_session.client().to_arrow(plan).await?;
Catalog::arrow_to_bool(record)
}
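
    /// Returns metadata for the function `function_name`, resolved against
    /// the current database.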
pub async fn get_function(self, function_name: &str) -> Result<RecordBatch, SparkError> {
let cat_type = Some(spark::catalog::CatType::GetFunction(spark::GetFunction {
function_name: function_name.to_string(),
db_name: None,
}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().to_arrow(plan).await
}
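
    /// Returns the columns of `table_name` as a [`RecordBatch`], one row per column.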
pub async fn list_columns(
self,
table_name: &str,
db_name: Option<&str>,
) -> Result<RecordBatch, SparkError> {
let cat_type = Some(spark::catalog::CatType::ListColumns(spark::ListColumns {
table_name: table_name.to_owned(),
db_name: db_name.map(|val| val.to_owned()),
}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().to_arrow(plan).await
}
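
    /// Returns `true` if a table or view with the given name exists in
    /// `db_name` (the current database when `None`).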
pub async fn table_exists(
self,
table_name: &str,
db_name: Option<&str>,
) -> Result<bool, SparkError> {
let cat_type = Some(spark::catalog::CatType::TableExists(spark::TableExists {
table_name: table_name.to_string(),
db_name: db_name.map(|val| val.to_owned()),
}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
let record = self.spark_session.client().to_arrow(plan).await?;
Catalog::arrow_to_bool(record)
}
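
    /// Drops the local temporary view `view_name`, returning `true` if the
    /// view existed and was dropped.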
pub async fn drop_temp_view(self, view_name: &str) -> Result<bool, SparkError> {
let cat_type = Some(spark::catalog::CatType::DropTempView(spark::DropTempView {
view_name: view_name.to_string(),
}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
let record = self.spark_session.client().to_arrow(plan).await?;
Catalog::arrow_to_bool(record)
}
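
    /// Drops the global temporary view `view_name`, returning `true` if the
    /// view existed and was dropped.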
pub async fn drop_global_temp_view(self, view_name: &str) -> Result<bool, SparkError> {
let cat_type = Some(spark::catalog::CatType::DropGlobalTempView(
spark::DropGlobalTempView {
view_name: view_name.to_string(),
},
));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
let record = self.spark_session.client().to_arrow(plan).await?;
Catalog::arrow_to_bool(record)
}
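
    /// Returns `true` if `table_name` is currently cached.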
pub async fn is_cached(self, table_name: &str) -> Result<bool, SparkError> {
let cat_type = Some(spark::catalog::CatType::IsCached(spark::IsCached {
table_name: table_name.to_string(),
}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
let record = self.spark_session.client().to_arrow(plan).await?;
Catalog::arrow_to_bool(record)
}
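
    /// Caches `table_name`, optionally at a specific [`StorageLevel`];
    /// when `None`, the server-side default level is used.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `spark` is an active session and `events`
    /// is an existing table:
    /// ```ignore
    /// spark.catalog().cache_table("events", None).await?;
    /// assert!(spark.catalog().is_cached("events").await?);
    /// ```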
pub async fn cache_table(
self,
table_name: &str,
storage_level: Option<StorageLevel>,
) -> Result<(), SparkError> {
let cat_type = Some(spark::catalog::CatType::CacheTable(spark::CacheTable {
table_name: table_name.to_string(),
            storage_level: storage_level.map(Into::into),
}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().execute_command(plan).await
}
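
    /// Removes `table_name` from the in-memory cache.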
pub async fn uncache_table(self, table_name: &str) -> Result<(), SparkError> {
let cat_type = Some(spark::catalog::CatType::UncacheTable(spark::UncacheTable {
table_name: table_name.to_string(),
}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().execute_command(plan).await
}
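
    /// Removes all cached tables from the in-memory cache.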
pub async fn clear_cache(self) -> Result<(), SparkError> {
let cat_type = Some(spark::catalog::CatType::ClearCache(spark::ClearCache {}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().execute_command(plan).await
}
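
    /// Invalidates and refreshes all cached data and metadata for `table_name`.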
pub async fn refresh_table(self, table_name: &str) -> Result<(), SparkError> {
let cat_type = Some(spark::catalog::CatType::RefreshTable(spark::RefreshTable {
table_name: table_name.to_string(),
}));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().execute_command(plan).await
}
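
    /// Recovers all partitions of `table_name` and updates the catalog.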
pub async fn recover_partitions(self, table_name: &str) -> Result<(), SparkError> {
let cat_type = Some(spark::catalog::CatType::RecoverPartitions(
spark::RecoverPartitions {
table_name: table_name.to_string(),
},
));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().execute_command(plan).await
}
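
    /// Invalidates and refreshes cached data and metadata for any table whose
    /// data is stored under the given `path`.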
pub async fn refresh_by_path(self, path: &str) -> Result<(), SparkError> {
let cat_type = Some(spark::catalog::CatType::RefreshByPath(
spark::RefreshByPath {
path: path.to_string(),
},
));
let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
self.spark_session.client().execute_command(plan).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::errors::SparkError;
use crate::SparkSessionBuilder;
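
    // NOTE: these are integration tests; they require a Spark Connect server
    // listening on 127.0.0.1:15002.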
async fn setup() -> SparkSession {
println!("SparkSession Setup");
let connection = "sc://127.0.0.1:15002/;user_id=rust_catalog";
SparkSessionBuilder::remote(connection)
.build()
.await
.unwrap()
}
#[tokio::test]
async fn test_current_catalog() -> Result<(), SparkError> {
let spark = setup().await;
let value = spark.catalog().current_catalog().await?;
assert_eq!(value, "spark_catalog".to_string());
Ok(())
}
#[tokio::test]
async fn test_set_current_catalog() -> Result<(), SparkError> {
let spark = setup().await;
        spark.catalog().set_current_catalog("spark_catalog").await?;
        let value = spark.catalog().current_catalog().await?;
        assert_eq!(value, "spark_catalog".to_string());
Ok(())
}
#[tokio::test]
#[should_panic]
    async fn test_set_current_catalog_panic() {
        let spark = setup().await;
        spark
            .catalog()
            .set_current_catalog("not_a_real_catalog")
            .await
            .unwrap();
    }
#[tokio::test]
async fn test_list_catalogs() -> Result<(), SparkError> {
let spark = setup().await;
let value = spark.catalog().list_catalogs(None).await?;
assert_eq!(2, value.num_columns());
assert_eq!(1, value.num_rows());
Ok(())
}
#[tokio::test]
async fn test_current_database() -> Result<(), SparkError> {
let spark = setup().await;
let value = spark.catalog().current_database().await?;
assert_eq!(value, "default".to_string());
Ok(())
}
#[tokio::test]
async fn test_set_current_database() -> Result<(), SparkError> {
let spark = setup().await;
spark.sql("CREATE SCHEMA current_db").await?;
        spark.catalog().set_current_database("current_db").await?;
        let value = spark.catalog().current_database().await?;
        assert_eq!(value, "current_db".to_string());
spark.sql("DROP SCHEMA current_db").await?;
Ok(())
}
#[tokio::test]
#[should_panic]
    async fn test_set_current_database_panic() {
        let spark = setup().await;
        spark
            .catalog()
            .set_current_database("not_a_real_db")
            .await
            .unwrap();
    }
#[tokio::test]
async fn test_get_database() -> Result<(), SparkError> {
let spark = setup().await;
spark.sql("CREATE SCHEMA get_db").await?;
        let res = spark.catalog().get_database("get_db").await?;
assert_eq!(res.num_rows(), 1);
spark.sql("DROP SCHEMA get_db").await?;
Ok(())
}
#[tokio::test]
async fn test_database_exists() -> Result<(), SparkError> {
let spark = setup().await;
let res = spark.catalog().database_exists("default").await?;
assert!(res);
let res = spark.catalog().database_exists("not_real").await?;
assert!(!res);
Ok(())
}
#[tokio::test]
async fn test_function_exists() -> Result<(), SparkError> {
let spark = setup().await;
let res = spark.catalog().function_exists("len", None).await?;
assert!(res);
Ok(())
}
#[tokio::test]
async fn test_list_columns() -> Result<(), SparkError> {
let spark = setup().await;
spark.sql("DROP TABLE IF EXISTS tmp_table").await?;
spark
.sql("CREATE TABLE tmp_table (name STRING, age INT) using parquet")
.await?;
let res = spark.catalog().list_columns("tmp_table", None).await?;
assert_eq!(res.num_rows(), 2);
spark.sql("DROP TABLE IF EXISTS tmp_table").await?;
Ok(())
}
#[tokio::test]
async fn test_drop_view() -> Result<(), SparkError> {
let spark = setup().await;
spark
.range(None, 2, 1, Some(1))
.create_or_replace_global_temp_view("tmp_view")
.await?;
let res = spark.catalog().drop_global_temp_view("tmp_view").await?;
assert!(res);
        spark
.range(None, 2, 1, Some(1))
.create_or_replace_temp_view("tmp_view")
.await?;
let res = spark.catalog().drop_temp_view("tmp_view").await?;
assert!(res);
Ok(())
}
#[tokio::test]
async fn test_cache_table() -> Result<(), SparkError> {
let spark = setup().await;
spark
.sql("CREATE TABLE cache_table (name STRING, age INT) using parquet")
.await?;
spark.catalog().cache_table("cache_table", None).await?;
let res = spark.catalog().is_cached("cache_table").await?;
assert!(res);
spark.catalog().uncache_table("cache_table").await?;
let res = spark.catalog().is_cached("cache_table").await?;
assert!(!res);
spark.sql("DROP TABLE cache_table").await?;
Ok(())
}
}