use std::collections::HashMap;

use crate::errors::SparkError;
use crate::plan::LogicalPlanBuilder;
use crate::session::SparkSession;
use crate::spark;
use crate::DataFrame;

use spark::write_operation::SaveMode;
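
/// DataFrameReader represents the interface to load a [DataFrame] from an
/// external data source.
///
/// A minimal usage sketch (assumes an already-connected [SparkSession] named
/// `spark` and a hypothetical CSV file path):
///
/// ```ignore
/// let df = DataFrameReader::new(spark)
///     .format("csv")
///     .option("header", "true")
///     .load(["/tmp/people.csv"]);
/// ```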
#[derive(Clone, Debug)]
pub struct DataFrameReader {
spark_session: SparkSession,
format: Option<String>,
read_options: HashMap<String, String>,
}
impl DataFrameReader {
pub fn new(spark_session: SparkSession) -> Self {
Self {
spark_session,
format: None,
read_options: HashMap::new(),
}
}
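
    /// Specifies the input data source format, e.g. "csv", "json", or "parquet".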
pub fn format(mut self, format: &str) -> Self {
self.format = Some(format.to_string());
self
}
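
    /// Adds an input option for the underlying data source.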
pub fn option(mut self, key: &str, value: &str) -> Self {
self.read_options.insert(key.to_string(), value.to_string());
self
}
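
    /// Adds a set of input options from an iterator of key/value pairs.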
pub fn options<I, K, V>(mut self, options: I) -> Self
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<str>,
V: AsRef<str>,
{
        // Merge into any options that were already set instead of replacing them.
        self.read_options.extend(
            options
                .into_iter()
                .map(|(k, v)| (k.as_ref().to_string(), v.as_ref().to_string())),
        );
self
}
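
    /// Loads data from the given path(s) and returns the result as a [DataFrame].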
pub fn load<'a, I>(&mut self, paths: I) -> DataFrame
where
I: IntoIterator<Item = &'a str>,
{
let read_type = Some(spark::relation::RelType::Read(spark::Read {
is_streaming: false,
read_type: Some(spark::read::ReadType::DataSource(spark::read::DataSource {
format: self.format.clone(),
schema: None,
options: self.read_options.clone(),
paths: paths.into_iter().map(|p| p.to_string()).collect(),
predicates: vec![],
})),
}));
let relation = spark::Relation {
common: Some(spark::RelationCommon {
source_info: "NA".to_string(),
plan_id: Some(1),
}),
rel_type: read_type,
};
let logical_plan = LogicalPlanBuilder::new(relation);
DataFrame::new(self.spark_session.clone(), logical_plan)
}
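
    /// Returns the specified table as a [DataFrame]. If `options` is `None`, any
    /// options already set on this reader are used instead.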
pub fn table(
&mut self,
table_name: &str,
options: Option<HashMap<String, String>>,
) -> DataFrame {
let read_type = Some(spark::relation::RelType::Read(spark::Read {
is_streaming: false,
read_type: Some(spark::read::ReadType::NamedTable(spark::read::NamedTable {
unparsed_identifier: table_name.to_string(),
                options: options.unwrap_or_else(|| self.read_options.clone()),
})),
}));
let relation = spark::Relation {
common: Some(spark::RelationCommon {
source_info: "NA".to_string(),
plan_id: Some(1),
}),
rel_type: read_type,
};
let logical_plan = LogicalPlanBuilder::new(relation);
DataFrame::new(self.spark_session.clone(), logical_plan)
}
}
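
/// DataFrameWriter provides the API for writing a [DataFrame] out to external
/// storage or to a table. The save mode defaults to [SaveMode::Overwrite].
///
/// A minimal usage sketch (assumes `df` is an existing [DataFrame] and the
/// output path is hypothetical):
///
/// ```ignore
/// DataFrameWriter::new(df)
///     .format("parquet")
///     .mode(SaveMode::Overwrite)
///     .save("/tmp/output")
///     .await?;
/// ```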
pub struct DataFrameWriter {
dataframe: DataFrame,
format: Option<String>,
mode: SaveMode,
bucket_by: Option<spark::write_operation::BucketBy>,
partition_by: Vec<String>,
sort_by: Vec<String>,
write_options: HashMap<String, String>,
}
impl DataFrameWriter {
pub fn new(dataframe: DataFrame) -> Self {
Self {
dataframe,
format: None,
mode: SaveMode::Overwrite,
bucket_by: None,
partition_by: vec![],
sort_by: vec![],
write_options: HashMap::new(),
}
}
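
    /// Specifies the output data source format.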
pub fn format(mut self, format: &str) -> Self {
self.format = Some(format.to_string());
self
}
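
    /// Specifies how to handle data or a table that already exists (see [SaveMode]).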
pub fn mode(mut self, mode: SaveMode) -> Self {
self.mode = mode;
self
}
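
    /// Buckets the output by the given column names.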
#[allow(non_snake_case)]
pub fn bucketBy<'a, I>(mut self, num_buckets: i32, buckets: I) -> Self
where
I: IntoIterator<Item = &'a str>,
{
self.bucket_by = Some(spark::write_operation::BucketBy {
bucket_column_names: buckets.into_iter().map(|b| b.to_string()).collect(),
num_buckets,
});
self
}
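
    /// Sorts the output in each bucket by the given columns.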
#[allow(non_snake_case)]
pub fn sortBy<'a, I>(mut self, cols: I) -> Self
where
I: IntoIterator<Item = &'a str>,
{
self.sort_by = cols.into_iter().map(|col| col.to_string()).collect();
self
}
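
    /// Partitions the output by the given columns.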
#[allow(non_snake_case)]
pub fn partitionBy<'a, I>(mut self, cols: I) -> Self
where
I: IntoIterator<Item = &'a str>,
{
        self.partition_by = cols.into_iter().map(|col| col.to_string()).collect();
self
}
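
    /// Adds an output option for the underlying data source.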
pub fn option(mut self, key: &str, value: &str) -> Self {
self.write_options
.insert(key.to_string(), value.to_string());
self
}
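
    /// Adds a set of output options from an iterator of key/value pairs.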
pub fn options<I, K, V>(mut self, options: I) -> Self
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<str>,
V: AsRef<str>,
{
        // Merge into any options that were already set instead of replacing them.
        self.write_options.extend(
            options
                .into_iter()
                .map(|(k, v)| (k.as_ref().to_string(), v.as_ref().to_string())),
        );
self
}
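
    /// Saves the contents of the [DataFrame] to the given path.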
pub async fn save(&mut self, path: &str) -> Result<(), SparkError> {
let write_command = spark::command::CommandType::WriteOperation(spark::WriteOperation {
input: Some(self.dataframe.logical_plan.relation.clone()),
source: self.format.clone(),
mode: self.mode.into(),
sort_column_names: self.sort_by.clone(),
partitioning_columns: self.partition_by.clone(),
bucket_by: self.bucket_by.clone(),
options: self.write_options.clone(),
save_type: Some(spark::write_operation::SaveType::Path(path.to_string())),
});
let plan = LogicalPlanBuilder::build_plan_cmd(write_command);
self.dataframe
.spark_session
.client
.execute_command(plan)
.await?;
Ok(())
}
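
    /// Shared helper that issues a `WriteOperation` targeting a table with the
    /// given `TableSaveMethod` value.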
async fn save_table(&mut self, table_name: &str, save_method: i32) -> Result<(), SparkError> {
let write_command = spark::command::CommandType::WriteOperation(spark::WriteOperation {
input: Some(self.dataframe.logical_plan.relation.clone()),
source: self.format.clone(),
mode: self.mode.into(),
sort_column_names: self.sort_by.clone(),
partitioning_columns: self.partition_by.clone(),
bucket_by: self.bucket_by.clone(),
options: self.write_options.clone(),
save_type: Some(spark::write_operation::SaveType::Table(
spark::write_operation::SaveTable {
table_name: table_name.to_string(),
save_method,
},
)),
});
let plan = LogicalPlanBuilder::build_plan_cmd(write_command);
self.dataframe
.spark_session
.client
.execute_command(plan)
.await?;
Ok(())
}
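
    /// Saves the [DataFrame] as the specified table.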
#[allow(non_snake_case)]
pub async fn saveAsTable(&mut self, table_name: &str) -> Result<(), SparkError> {
        // save_method 1 == TABLE_SAVE_METHOD_SAVE_AS_TABLE in the Spark Connect protocol
        self.save_table(table_name, 1).await
}
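
    /// Inserts the contents of the [DataFrame] into the specified table.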
#[allow(non_snake_case)]
pub async fn insertInto(&mut self, table_name: &str) -> Result<(), SparkError> {
        // save_method 2 == TABLE_SAVE_METHOD_INSERT_INTO in the Spark Connect protocol
        self.save_table(table_name, 2).await
}
}