use std::borrow::Borrow;
use std::cell::RefCell;
use std::fmt::{Debug, Formatter};
use std::rc::Rc;
use std::sync::Arc;
use anyhow::Result;
use crate::algorithms::lazy::FstOp2;
use crate::algorithms::randgen::rand_state::RandState;
use crate::algorithms::randgen::tr_sampler::TrSampler;
use crate::algorithms::randgen::TrSelector;
use crate::fst_properties::mutable_properties::rand_gen_properties;
use crate::fst_properties::FstProperties;
use crate::prelude::Fst;
use crate::{Semiring, StateId, Tr, Trs, TrsVec, NO_STATE_ID};
pub struct RandGenFstOp<W, F, B, S>
where
    W: Semiring<Type = f32>,
    F: Fst<W>,
    B: Borrow<F>,
    S: TrSelector,
{
    fst: B,
    sampler: RefCell<TrSampler<W, F, B, S>>,
    npath: usize,
    state_table: RefCell<Vec<Rc<RandState>>>,
    weighted: bool,
    remove_total_weight: bool,
    superfinal: RefCell<StateId>,
}
impl<W, F, B, S> RandGenFstOp<W, F, B, S>
where
    W: Semiring<Type = f32>,
    F: Fst<W>,
    B: Borrow<F>,
    S: TrSelector,
{
    pub fn new(
        fst: B,
        sampler: TrSampler<W, F, B, S>,
        npath: usize,
        weighted: bool,
        remove_total_weight: bool,
    ) -> Self {
        Self {
            fst,
            sampler: RefCell::new(sampler),
            npath,
            state_table: RefCell::new(vec![]),
            weighted,
            remove_total_weight,
            superfinal: RefCell::new(NO_STATE_ID),
        }
    }
}
impl<W, F, B, S> Debug for RandGenFstOp<W, F, B, S>
where
    W: Semiring<Type = f32>,
    F: Fst<W>,
    B: Borrow<F>,
    S: TrSelector,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "RandGenFstOp {{ fst : {:?}, sampler : {:?}, npath : {:?}, state_table : {:?}, weighted : {:?}, remove_total_weight : {:?}, superfinal : {:?} }}",
            self.fst.borrow(),
            self.sampler.borrow(),
            self.npath,
            self.state_table.borrow(),
            self.weighted,
            self.remove_total_weight,
            self.superfinal
        )
    }
}
impl<W, F, B, S> FstOp2<W> for RandGenFstOp<W, F, B, S>
where
    W: Semiring<Type = f32>,
    F: Fst<W>,
    B: Borrow<F>,
    S: TrSelector,
{
    fn compute_start(&self) -> Result<Option<StateId>> {
        if let Some(s) = self.fst.borrow().start() {
            let n = self.state_table.borrow().len();
            self.state_table.borrow_mut().push(Rc::new(
                RandState::new(s)
                    .with_nsamples(self.npath)
                    .with_length(0)
                    .with_select(0)
                    .with_parent(None),
            ));
            Ok(Some(n as StateId))
        } else {
            Ok(None)
        }
    }
    fn compute_trs_and_final_weight(&self, s: StateId) -> Result<(TrsVec<W>, Option<W>)> {
        if s == *self.superfinal.borrow() {
            let result = Ok((TrsVec::default(), Some(W::one())));
            return result;
        }
        let rstate = Rc::clone(self.state_table.borrow().get(s as usize).unwrap());
        self.sampler.borrow_mut().sample(&rstate)?;
        let aiter = self.fst.borrow().get_trs(rstate.state_id)?;
        let trs = aiter.trs();
        let num_trs = trs.len();
        let mut output_trs: Vec<Tr<W>> = vec![];
        let mut output_final_weight = None;
        for (&pos, &count) in self.sampler.borrow().iter() {
            let prob = (count as f32) / (rstate.nsamples as f32);
            if pos < num_trs {
                let tr = &trs[pos];
                let weight = if self.weighted {
                    W::new(-prob.ln())
                } else {
                    W::one()
                };
                output_trs.push(Tr::new(
                    tr.ilabel,
                    tr.olabel,
                    weight,
                    self.state_table.borrow().len() as StateId,
                ));
                let nrstate = RandState::new(tr.nextstate)
                    .with_nsamples(count)
                    .with_length(rstate.length + 1)
                    .with_select(pos)
                    .with_parent(Some(Rc::clone(&rstate)));
                self.state_table.borrow_mut().push(Rc::new(nrstate));
            } else {
                if self.weighted {
                    let weight = if self.remove_total_weight {
                        W::new(-prob.ln())
                    } else {
                        W::new(-(prob * self.npath as f32).ln())
                    };
                    output_final_weight = Some(weight);
                } else {
                    if *self.superfinal.borrow() == NO_STATE_ID {
                        *self.superfinal.borrow_mut() = self.state_table.borrow().len() as StateId;
                        self.state_table.borrow_mut().push(Rc::new(
                            RandState::new(NO_STATE_ID)
                                .with_nsamples(0)
                                .with_length(0)
                                .with_select(0)
                                .with_parent(None),
                        ));
                    }
                    for _ in 0..count {
                        output_trs.push(Tr::new(0, 0, W::one(), *self.superfinal.borrow()));
                    }
                }
            }
        }
        Ok((TrsVec(Arc::new(output_trs)), output_final_weight))
    }
    fn properties(&self) -> FstProperties {
        rand_gen_properties(self.fst.borrow().properties(), self.weighted)
            & FstProperties::copy_properties()
    }
}