use vortex::array::{PrimitiveArray, TemporalArray};
use vortex::compute::unary::{scalar_at, ScalarAtFn};
use vortex::compute::{slice, take, ArrayCompute, SliceFn, TakeFn};
use vortex::validity::ArrayValidity;
use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant};
use vortex_datetime_dtype::{TemporalMetadata, TimeUnit};
use vortex_dtype::{DType, Nullability, PType};
use vortex_error::{vortex_bail, VortexResult, VortexUnwrap as _};
use vortex_scalar::Scalar;

use crate::DateTimePartsArray;
/// Advertises which compute kernels `DateTimePartsArray` implements natively.
///
/// Every kernel returned here is implemented on the array itself (see the
/// `SliceFn`, `TakeFn` and `ScalarAtFn` impls below), so each accessor simply
/// hands back `self` as the corresponding trait object.
impl ArrayCompute for DateTimePartsArray {
    /// Native slicing over the day/second/sub-second component arrays.
    fn slice(&self) -> Option<&dyn SliceFn> {
        Some(self)
    }

    /// Native gather over the day/second/sub-second component arrays.
    fn take(&self) -> Option<&dyn TakeFn> {
        Some(self)
    }

    /// Native single-element access that recombines the three components.
    fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
        Some(self)
    }
}
impl TakeFn for DateTimePartsArray {
    /// Gathers the rows selected by `indices`.
    ///
    /// `take` is applied independently to the day, second and sub-second component
    /// arrays, and the results are reassembled into a new `DateTimePartsArray` that
    /// keeps this array's dtype.
    ///
    /// # Errors
    ///
    /// Propagates any failure from the component `take` calls or from
    /// `DateTimePartsArray::try_new`.
    fn take(&self, indices: &Array) -> VortexResult<Array> {
        let dtype = self.dtype().clone();
        let days = take(&self.days(), indices)?;
        let seconds = take(&self.seconds(), indices)?;
        let subsecond = take(&self.subsecond(), indices)?;

        let gathered = Self::try_new(dtype, days, seconds, subsecond)?;
        Ok(gathered.into_array())
    }
}
impl SliceFn for DateTimePartsArray {
    /// Returns the rows in `start..stop`.
    ///
    /// Each component array (days, seconds, sub-seconds) is sliced with the same
    /// bounds, and the pieces are reassembled under this array's dtype.
    ///
    /// # Errors
    ///
    /// Propagates any failure from the component `slice` calls or from
    /// `DateTimePartsArray::try_new`.
    fn slice(&self, start: usize, stop: usize) -> VortexResult<Array> {
        let dtype = self.dtype().clone();
        let days = slice(&self.days(), start, stop)?;
        let seconds = slice(&self.seconds(), start, stop)?;
        let subsecond = slice(&self.subsecond(), start, stop)?;

        let sliced = Self::try_new(dtype, days, seconds, subsecond)?;
        Ok(sliced.into_array())
    }
}
impl ScalarAtFn for DateTimePartsArray {
    fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
        let DType::Extension(ext, nullability) = self.dtype().clone() else {
            vortex_bail!(
                "DateTimePartsArray must have extension dtype, found {}",
                self.dtype()
            );
        };
        let TemporalMetadata::Timestamp(time_unit, _) = TemporalMetadata::try_from(&ext)? else {
            vortex_bail!("Metadata must be Timestamp, found {}", ext.id());
        };
        let divisor = match time_unit {
            TimeUnit::Ns => 1_000_000_000,
            TimeUnit::Us => 1_000_000,
            TimeUnit::Ms => 1_000,
            TimeUnit::S => 1,
            TimeUnit::D => vortex_bail!("Invalid time unit D"),
        };
        let days: i64 = scalar_at(&self.days(), index)?.try_into()?;
        let seconds: i64 = scalar_at(&self.seconds(), index)?.try_into()?;
        let subseconds: i64 = scalar_at(&self.subsecond(), index)?.try_into()?;
        let scalar = days * 86_400 * divisor + seconds * divisor + subseconds;
        Ok(Scalar::primitive(scalar, nullability))
    }
    fn scalar_at_unchecked(&self, index: usize) -> Scalar {
        <Self as ScalarAtFn>::scalar_at(self, index).vortex_unwrap()
    }
}
pub fn decode_to_temporal(array: &DateTimePartsArray) -> VortexResult<TemporalArray> {
    let DType::Extension(ext, _) = array.dtype().clone() else {
        vortex_bail!(ComputeError: "expected dtype to be DType::Extension variant")
    };
    let Ok(temporal_metadata) = TemporalMetadata::try_from(&ext) else {
        vortex_bail!(ComputeError: "must decode TemporalMetadata from extension metadata");
    };
    let divisor = match temporal_metadata.time_unit() {
        TimeUnit::Ns => 1_000_000_000,
        TimeUnit::Us => 1_000_000,
        TimeUnit::Ms => 1_000,
        TimeUnit::S => 1,
        TimeUnit::D => vortex_bail!(InvalidArgument: "cannot decode into TimeUnit::D"),
    };
    let days_buf = array.days().into_primitive()?;
    let seconds_buf = array.seconds().into_primitive()?;
    let subsecond_buf = array.subsecond().into_primitive()?;
    let values = days_buf
        .maybe_null_slice::<i64>()
        .iter()
        .zip(seconds_buf.maybe_null_slice::<i64>().iter())
        .zip(subsecond_buf.maybe_null_slice::<i64>().iter())
        .map(|((d, s), ss)| d * 86_400 * divisor + s * divisor + ss)
        .collect::<Vec<_>>();
    Ok(TemporalArray::new_timestamp(
        PrimitiveArray::from_vec(values, array.logical_validity().into_validity()).into_array(),
        temporal_metadata.time_unit(),
        temporal_metadata.time_zone().map(ToString::to_string),
    ))
}
#[cfg(test)]
mod test {
    use vortex::array::{PrimitiveArray, TemporalArray};
    use vortex::{IntoArray, IntoArrayVariant};
    use vortex_datetime_dtype::TimeUnit;
    use vortex_dtype::{DType, Nullability};

    use crate::compute::decode_to_temporal;
    use crate::{compress_temporal, DateTimePartsArray};

    /// Splitting millisecond timestamps into parts and decoding them back must
    /// reproduce the original values exactly.
    #[test]
    fn test_roundtrip_datetimeparts() {
        let base = 86_400i64;
        let expected = vec![base, base + 1_000, base + 1_000 + 1];

        let temporal = TemporalArray::new_timestamp(
            PrimitiveArray::from(expected.clone()).into_array(),
            TimeUnit::Ms,
            Some("UTC".to_string()),
        );
        let (days, seconds, subseconds) = compress_temporal(temporal.clone()).unwrap();

        let dtype = DType::Extension(temporal.ext_dtype().clone(), Nullability::NonNullable);
        let parts = DateTimePartsArray::try_new(dtype, days, seconds, subseconds).unwrap();

        let decoded = decode_to_temporal(&parts).unwrap();
        let decoded_values = decoded.temporal_values().into_primitive().unwrap();

        assert_eq!(decoded_values.maybe_null_slice::<i64>(), expected.as_slice());
    }
}