#[cfg(test)]
#[path = "../../../tests/unit/models/problem/jobs_test.rs"]
mod jobs_test;
use crate::models::common::*;
use crate::models::problem::{Costs, Fleet, TransportCost};
use hashbrown::HashMap;
use std::cell::UnsafeCell;
use std::cmp::Ordering::Less;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Weak};
/// A job in the problem: either a single job or a multi job made of several single sub-jobs.
///
/// Cloning a `Job` only clones the inner `Arc`, so clones share the same underlying job;
/// equality and hashing (see the `PartialEq`/`Hash` impls in this file) are by that shared address.
#[derive(Clone)]
pub enum Job {
    /// A job with its own places and dimensions.
    Single(Arc<Single>),
    /// A job composed of several single sub-jobs (see `Multi::jobs`).
    Multi(Arc<Multi>),
}
impl Job {
    /// Returns the wrapped single job, or `None` when this is a multi job.
    pub fn as_single(&self) -> Option<&Arc<Single>> {
        match self {
            Job::Single(single) => Some(single),
            Job::Multi(_) => None,
        }
    }

    /// Returns the wrapped single job.
    ///
    /// # Panics
    ///
    /// Panics if this is a multi job.
    pub fn to_single(&self) -> &Arc<Single> {
        let single = self.as_single();
        single.expect("Unexpected job type: multi")
    }

    /// Returns the wrapped multi job, or `None` when this is a single job.
    pub fn as_multi(&self) -> Option<&Arc<Multi>> {
        match self {
            Job::Multi(multi) => Some(multi),
            Job::Single(_) => None,
        }
    }

    /// Returns the wrapped multi job.
    ///
    /// # Panics
    ///
    /// Panics if this is a single job.
    pub fn to_multi(&self) -> &Arc<Multi> {
        let multi = self.as_multi();
        multi.expect("Unexpected job type: single")
    }

    /// Returns the dimensions of whichever job variant is wrapped.
    pub fn dimens(&self) -> &Dimensions {
        match self {
            Job::Multi(multi) => &multi.dimens,
            Job::Single(single) => &single.dimens,
        }
    }
}
/// A place associated with a single job.
#[derive(Clone)]
pub struct Place {
    /// Location of the place, if any. In the cost helpers of this module a `None` location
    /// contributes `DEFAULT_COST` instead of a transport-based cost.
    pub location: Option<Location>,
    /// Duration associated with the place (presumably the service time there — TODO confirm).
    pub duration: Duration,
    /// Time spans associated with the place (presumably allowed time windows — TODO confirm).
    pub times: Vec<TimeSpan>,
}
/// A single job: a list of places plus its own dimensions.
///
/// NOTE(review): `Multi::bind` transmutes `Arc<Single>` into a private struct that mirrors
/// these fields in the same order; do not reorder or add fields without updating it.
pub struct Single {
    /// Places of the job. Cost helpers in this module take the minimum cost over all of
    /// them (NOTE(review): presumably alternatives — confirm).
    pub places: Vec<Place>,
    /// Key/value dimensions of the job; `Multi::bind` stores a parent back-reference here under `"rf"`.
    pub dimens: Dimensions,
}
/// A job composed of several single sub-jobs whose allowed orderings are defined by a permutator.
pub struct Multi {
    /// Sub-jobs of this multi job; permutations refer to them by index.
    pub jobs: Vec<Arc<Single>>,
    /// Dimensions of the multi job itself.
    pub dimens: Dimensions,
    /// Source of the allowed sub-job index orderings.
    permutator: Box<dyn JobPermutation + Send + Sync>,
}
/// Defines which orderings of a multi job's sub-jobs are allowed.
pub trait JobPermutation {
    /// Returns all allowed orderings, each expressed as a list of sub-job indices.
    fn get(&self) -> Vec<Vec<usize>>;

    /// Returns `true` if `permutation` (a list of sub-job indices) is an allowed ordering.
    fn validate(&self, permutation: &[usize]) -> bool;
}

/// A `JobPermutation` backed by an explicit, fixed list of allowed orderings.
pub struct FixedJobPermutation {
    permutations: Vec<Vec<usize>>,
}

impl FixedJobPermutation {
    /// Creates a permutator that allows exactly the given orderings.
    pub fn new(permutations: Vec<Vec<usize>>) -> Self {
        Self { permutations }
    }
}

impl JobPermutation for FixedJobPermutation {
    fn get(&self) -> Vec<Vec<usize>> {
        self.permutations.clone()
    }

    /// An ordering is valid when it is exactly equal to one of the stored orderings.
    fn validate(&self, permutation: &[usize]) -> bool {
        // Slice equality compares lengths first and then elements, which is exactly the
        // check that was previously hand-rolled with `len()` + `zip().all()`.
        self.permutations.iter().any(|allowed| allowed.as_slice() == permutation)
    }
}
impl Multi {
    /// Creates a multi job whose only allowed ordering of sub-jobs is the given order
    /// (the identity permutation `0..jobs.len()`).
    ///
    /// The result is not yet linked to its sub-jobs; use `Multi::bind` for that.
    pub fn new(jobs: Vec<Arc<Single>>, dimens: Dimensions) -> Self {
        let permutations = vec![(0..jobs.len()).collect()];
        Self { jobs, dimens, permutator: Box::new(FixedJobPermutation::new(permutations)) }
    }
    /// Creates a multi job whose allowed sub-job orderings are provided by `permutator`.
    pub fn new_with_permutator(
        jobs: Vec<Arc<Single>>,
        dimens: Dimensions,
        permutator: Box<dyn JobPermutation + Send + Sync>,
    ) -> Self {
        Self { jobs, dimens, permutator }
    }
    /// Returns every allowed ordering of sub-jobs, resolving the permutator's index lists
    /// into (cloned) `Arc<Single>` handles.
    ///
    /// # Panics
    ///
    /// Panics if the permutator yields an index outside `0..self.jobs.len()`.
    pub fn permutations(&self) -> Vec<Vec<Arc<Single>>> {
        self.permutator
            .get()
            .iter()
            .map(|perm| perm.iter().map(|&i| self.jobs.get(i).unwrap().clone()).collect())
            .collect()
    }
    /// Returns whether `permutations` (a list of sub-job indices) is an ordering accepted
    /// by the permutator.
    pub fn validate(&self, permutations: &[usize]) -> bool {
        self.permutator.validate(permutations)
    }
    /// Wraps `multi` in an `Arc` and stores a `Weak` back-reference to it in every sub-job's
    /// dimensions under the key `"rf"`, so that `Multi::roots` can find the parent later.
    pub fn bind(multi: Self) -> Arc<Self> {
        // Mirror of `Single` with `UnsafeCell` fields, used to mutate a sub-job's dimensions
        // behind a shared `Arc<Single>`.
        struct SingleConstruct {
            pub places: UnsafeCell<Vec<Place>>,
            pub dimens: UnsafeCell<Dimensions>,
        }
        let multi = Arc::new(multi);
        multi.jobs.iter().for_each(|job| {
            let weak_multi = Arc::downgrade(&multi);
            // SAFETY(review): this assumes `SingleConstruct` and `Single` have identical layout.
            // `UnsafeCell<T>` is documented to have the same in-memory representation as `T`,
            // but neither struct is `#[repr(C)]`, and Rust does not guarantee that two distinct
            // default-repr structs lay out their fields the same way. The write below also
            // mutates data that other `Arc<Single>` clones may be reading concurrently.
            // TODO: confirm soundness, or replace, e.g. by setting the back-reference before
            // the sub-jobs are shared, or by giving `Single` an interior-mutable field.
            // The transmuted value is a clone, so dropping it at the end of the closure only
            // releases the extra strong count taken by `job.clone()`.
            let job: Arc<SingleConstruct> = unsafe { std::mem::transmute(job.clone()) };
            let dimens = unsafe { &mut *job.dimens.get() };
            dimens.set_value("rf", weak_multi);
        });
        multi
    }
    /// Returns the multi job that `single` was linked to by `Multi::bind`.
    ///
    /// Returns `None` if no `"rf"` back-reference is found, or if the parent multi job
    /// has already been dropped (the `Weak` fails to upgrade).
    pub fn roots(single: &Single) -> Option<Arc<Multi>> {
        single.dimens.get_value::<Weak<Multi>>("rf").and_then(|w| w.upgrade())
    }
}
/// Per-profile neighbourhood index built by `create_index`. For each job it stores:
/// (other jobs sorted by ascending cost, the same costs keyed by job, cost between the job
/// and the cheapest vehicle start location of the profile).
type JobIndex = HashMap<Job, (Vec<(Job, Cost)>, HashMap<Job, Cost>, Cost)>;
/// All jobs of a problem, together with a precomputed per-profile neighbourhood index.
pub struct Jobs {
    /// All jobs, in the order they were supplied to `Jobs::new`.
    jobs: Vec<Job>,
    /// Neighbourhood index keyed by vehicle profile (built by `create_index`).
    index: HashMap<Profile, JobIndex>,
}
impl Jobs {
    /// Creates a job collection and precomputes its per-profile neighbourhood index.
    pub fn new(fleet: &Fleet, jobs: Vec<Job>, transport: &Arc<dyn TransportCost + Send + Sync>) -> Jobs {
        let index = create_index(fleet, jobs.clone(), transport);
        Jobs { jobs, index }
    }

    /// Returns an iterator over all jobs; each item is a cheap `Arc`-level clone.
    pub fn all<'a>(&'a self) -> impl Iterator<Item = Job> + 'a {
        self.jobs.iter().cloned()
    }

    /// Returns all jobs as a slice, in the order they were supplied to `Jobs::new`.
    pub fn all_as_slice(&self) -> &[Job] {
        self.jobs.as_slice()
    }

    /// Returns the other jobs ordered by ascending cost from `job` for `profile`.
    ///
    /// The timestamp argument is currently ignored.
    ///
    /// # Panics
    ///
    /// Panics if `profile` or `job` was not indexed by `Jobs::new`.
    pub fn neighbors(&self, profile: Profile, job: &Job, _: Timestamp) -> impl Iterator<Item = &(Job, Cost)> {
        let profile_index = self.index.get(&profile).expect("profile is not present in the job index");
        profile_index.get(job).expect("job is not present in the job index").0.iter()
    }

    /// Returns the precomputed cost between `from` and `to` for `profile`.
    ///
    /// The cost of a job to itself is `0`. The index deliberately omits self-pairs, so this
    /// case is answered without a lookup instead of panicking. The timestamp argument is
    /// currently ignored.
    ///
    /// # Panics
    ///
    /// Panics if `profile`, `from`, or a `to` different from `from` was not indexed.
    pub fn distance(&self, profile: Profile, from: &Job, to: &Job, _: Timestamp) -> Cost {
        if from == to {
            return 0.;
        }

        let profile_index = self.index.get(&profile).expect("profile is not present in the job index");
        let entry = profile_index.get(from).expect("job is not present in the job index");
        *entry.1.get(to).expect("target job is not present in the job index")
    }

    /// Returns the cost between `job` and the cheapest vehicle start location of `profile`
    /// (0 when no vehicle of the profile has a start location).
    ///
    /// # Panics
    ///
    /// Panics if `profile` or `job` was not indexed by `Jobs::new`.
    pub fn rank(&self, profile: Profile, job: &Job) -> Cost {
        let profile_index = self.index.get(&profile).expect("profile is not present in the job index");
        profile_index.get(job).expect("job is not present in the job index").2
    }

    /// Returns the number of jobs.
    pub fn size(&self) -> usize {
        self.jobs.len()
    }
}
impl PartialEq<Job> for Job {
    fn eq(&self, other: &Job) -> bool {
        match (&self, other) {
            (Job::Single(_), Job::Multi(_)) => false,
            (Job::Multi(_), Job::Single(_)) => false,
            (Job::Single(lhs), Job::Single(rhs)) => lhs.as_ref() as *const Single == rhs.as_ref() as *const Single,
            (Job::Multi(lhs), Job::Multi(rhs)) => lhs.as_ref() as *const Multi == rhs.as_ref() as *const Multi,
        }
    }
}
impl Eq for Job {}
/// Hashes the address of the wrapped allocation, which keeps hashing consistent with the
/// pointer-identity `PartialEq` implementation.
impl Hash for Job {
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self {
            Job::Single(single) => std::ptr::hash::<Single, _>(&**single, state),
            Job::Multi(multi) => std::ptr::hash::<Multi, _>(&**multi, state),
        }
    }
}
/// Departure time passed to every transport distance/duration query made while building the index.
const DEFAULT_DEPARTURE: Timestamp = 0.;
/// Cost used when a location is missing or there is nothing to take a minimum over.
const DEFAULT_COST: Cost = 0.;
/// Cost assigned to a location pair whose transport distance or duration is negative.
// NOTE(review): `f32::MAX` rather than `f64::MAX` — presumably to leave headroom so sums of
// several such costs stay finite; confirm.
const UNREACHABLE_COST: Cost = std::f32::MAX as f64;
/// Builds the neighbourhood index for every profile in `fleet`.
///
/// For each profile, vehicle costs are averaged (see `get_avg_profile_costs`). For every job,
/// the index then stores:
/// * all other jobs sorted by ascending cost (stable sort, so ties keep the input order),
/// * the same costs keyed by job,
/// * the cost between the job and the cheapest start location among that profile's vehicles
///   (`DEFAULT_COST` when none of them has a start).
///
/// # Panics
///
/// Panics if a profile listed in `fleet.profiles` has no vehicle, because no average costs
/// exist for it.
fn create_index(
    fleet: &Fleet,
    jobs: Vec<Job>,
    transport: &Arc<dyn TransportCost + Send + Sync>,
) -> HashMap<Profile, JobIndex> {
    let avg_profile_costs = get_avg_profile_costs(fleet);
    let transport: &(dyn TransportCost + Send + Sync) = transport.as_ref();

    fleet
        .profiles
        .iter()
        .cloned()
        .map(|profile| {
            let avg_costs =
                avg_profile_costs.get(&profile).expect("fleet profile has no vehicles to average costs over");
            let starts = get_profile_start_locations(fleet, profile);

            let job_index: JobIndex = jobs
                .iter()
                .map(|job| {
                    let mut sorted_job_costs: Vec<(Job, Cost)> = jobs
                        .iter()
                        .filter(|&other| other != job)
                        .map(|other| (other.clone(), get_cost_between_jobs(profile, avg_costs, transport, job, other)))
                        .collect();
                    // Ascending by cost; incomparable values (NaN) are treated as `Less`.
                    sorted_job_costs.sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Less));

                    let fleet_costs = starts
                        .iter()
                        .map(|&start| get_cost_between_job_and_location(profile, avg_costs, transport, job, start))
                        .min_by(|a, b| a.partial_cmp(b).unwrap_or(Less))
                        .unwrap_or(DEFAULT_COST);

                    let job_costs_map = sorted_job_costs.iter().cloned().collect::<HashMap<_, _>>();

                    (job.clone(), (sorted_job_costs, job_costs_map, fleet_costs))
                })
                .collect();

            (profile, job_index)
        })
        .collect()
}

/// Returns the start locations of all vehicle details that belong to `profile`;
/// details without a start are skipped.
fn get_profile_start_locations(fleet: &Fleet, profile: Profile) -> Vec<Location> {
    fleet
        .vehicles
        .iter()
        .filter(|vehicle| vehicle.profile == profile)
        .flat_map(|vehicle| vehicle.details.iter().filter_map(|detail| detail.start.as_ref().map(|start| start.location)))
        .collect()
}
/// Returns the routing cost from `from` to `to` for `profile`, weighting transport distance
/// and duration by the given per-distance and per-driving-time costs.
///
/// A negative distance or duration is treated as "unreachable" and yields `UNREACHABLE_COST`.
fn get_cost_between_locations(
    profile: Profile,
    costs: &Costs,
    transport: &(dyn TransportCost + Send + Sync),
    from: Location,
    to: Location,
) -> f64 {
    let (distance, duration) = (
        transport.distance(profile, from, to, DEFAULT_DEPARTURE),
        transport.duration(profile, from, to, DEFAULT_DEPARTURE),
    );

    let unreachable = distance < 0. || duration < 0.;
    if unreachable {
        return UNREACHABLE_COST;
    }

    let distance_cost = distance * costs.per_distance;
    let driving_cost = duration * costs.per_driving_time;
    distance_cost + driving_cost
}
/// Returns the cheapest cost from any of `job`'s locations to `to`.
///
/// Job places without a location contribute `DEFAULT_COST`; a job with no places at all
/// also yields `DEFAULT_COST`.
fn get_cost_between_job_and_location(
    profile: Profile,
    costs: &Costs,
    transport: &(dyn TransportCost + Send + Sync),
    job: &Job,
    to: Location,
) -> Cost {
    let cost_from = |location: Option<Location>| -> Cost {
        if let Some(from) = location {
            get_cost_between_locations(profile, costs, transport, from, to)
        } else {
            DEFAULT_COST
        }
    };

    get_job_locations(job)
        .map(cost_from)
        .min_by(|left, right| left.partial_cmp(right).unwrap_or(Less))
        .unwrap_or(DEFAULT_COST)
}
/// Returns the cheapest cost over all (location of `lhs`, location of `rhs`) pairs.
///
/// Pairs where either side has no location contribute `DEFAULT_COST`; if either job has no
/// places, the result is `DEFAULT_COST`.
fn get_cost_between_jobs(
    profile: Profile,
    costs: &Costs,
    transport: &(dyn TransportCost + Send + Sync),
    lhs: &Job,
    rhs: &Job,
) -> f64 {
    // Target locations are materialised once because they are scanned for every source location.
    let targets: Vec<Option<Location>> = get_job_locations(rhs).collect();

    get_job_locations(lhs)
        .flat_map(|source| targets.iter().map(move |&target| (source, target)))
        .map(|pair| {
            if let (Some(from), Some(to)) = pair {
                get_cost_between_locations(profile, costs, transport, from, to)
            } else {
                DEFAULT_COST
            }
        })
        .min_by(|left, right| left.partial_cmp(right).unwrap_or(Less))
        .unwrap_or(DEFAULT_COST)
}
/// Yields the (optional) location of every place of `job`; for a multi job, the places of
/// all sub-jobs are yielded in sub-job order.
fn get_job_locations<'a>(job: &'a Job) -> Box<dyn Iterator<Item = Option<Location>> + 'a> {
    match job {
        Job::Multi(multi) => Box::new(
            multi.jobs.iter().flat_map(|single| single.places.iter()).map(|place| place.location),
        ),
        Job::Single(single) => Box::new(single.places.iter().map(|place| place.location)),
    }
}
/// Averages every cost component over the vehicles of each profile present in `fleet`.
///
/// Only profiles that have at least one vehicle appear in the result.
fn get_avg_profile_costs(fleet: &Fleet) -> HashMap<Profile, Costs> {
    /// Arithmetic mean of one cost component over `costs` (non-empty by construction).
    fn average(costs: &[Costs], component: fn(&Costs) -> f64) -> f64 {
        costs.iter().map(component).sum::<f64>() / (costs.len() as f64)
    }

    let mut costs_by_profile = HashMap::new();
    for vehicle in fleet.vehicles.iter() {
        costs_by_profile.entry(vehicle.profile).or_insert_with(Vec::new).push(vehicle.costs.clone());
    }

    costs_by_profile
        .into_iter()
        .map(|(profile, costs): (Profile, Vec<Costs>)| {
            let averaged = Costs {
                fixed: average(&costs, |c| c.fixed),
                per_distance: average(&costs, |c| c.per_distance),
                per_driving_time: average(&costs, |c| c.per_driving_time),
                per_waiting_time: average(&costs, |c| c.per_waiting_time),
                per_service_time: average(&costs, |c| c.per_service_time),
            };
            (profile, averaged)
        })
        .collect()
}