use std::fmt::Debug;
use std::collections::BTreeMap;
use logaddexp::LogSumExp;
use crate::*;
use crate::map::Map;
/// Empirical categorical distribution over keys of type `T`.
///
/// For every key, `map` stores how many times it was observed together with its
/// log-probability (`ln(n / count)` when built from counts).
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct Categorical<T> {
    /// Total number of observations across all keys.
    pub count: usize,
    /// Per-key `(observation count, log-probability)` pairs.
    pub map: Map<T, (usize, f64)>,
}
impl<T> Categorical<T> {
    /// Creates a distribution that puts all of its mass on `k`.
    ///
    /// The single entry records one observation with log-probability `ln 1 = 0`.
    pub fn singleton(k: T) -> Categorical<T> {
        let count = 1;
        let map = Map::singleton(k, (1, 0.0));
        Categorical { count, map }
    }

    /// Builds a distribution from `(key, count)` pairs.
    ///
    /// The pairs are wrapped into a `Map` exactly as given (no sorting or
    /// merging), so callers must supply them sorted by key without duplicates,
    /// as [`Categorical::from_data`] does. Every count must be positive; this is
    /// only checked in debug builds. Each key gets log-probability
    /// `ln(n) - ln(total count)`.
    fn from_sorted(counts: Vec<(T, usize)>) -> Self {
        debug_assert!(counts.iter().all(|(_, n)| *n > 0));
        let count: usize = counts.iter().map(|(_, n)| n).sum();
        let ln_count = (count as f64).ln();
        let lps = counts
            .into_iter()
            .map(|(k, n)| (k, (n, (n as f64).ln() - ln_count)))
            .collect::<Vec<_>>();
        let map = Map(lps);
        Categorical { count, map }
    }

    /// Estimates a distribution from samples by counting occurrences.
    ///
    /// Counts are gathered in a `BTreeMap`, so keys end up in ascending `Ord`
    /// order, which is what [`Categorical::from_sorted`] requires.
    pub fn from_data<I>(data: I) -> Self
    where
        T: Ord,
        I: IntoIterator<Item = T>,
    {
        let mut counts = BTreeMap::new();
        for item in data {
            *counts.entry(item).or_insert(0) += 1;
        }
        Self::from_sorted(counts.into_iter().collect())
    }

    /// Number of distinct keys in the distribution.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Log-probability of `key`, or `None` if the key is not in the support.
    pub fn lookup(&self, key: &T) -> Option<f64>
    where
        T: Clone + Ord,
    {
        self.map.lookup(key).copied().map(|p| p.1)
    }

    /// Position of `key` inside the underlying map, if present.
    ///
    /// NOTE(review): presumably the same index accepted by `get_key`; confirm
    /// against `Map::index_of`.
    pub fn index_of(&self, key: &T) -> Option<usize>
    where
        T: Clone + Ord,
    {
        self.map.index_of(key)
    }

    /// Key stored at position `i` of the underlying map (delegates to `Map::get_key`).
    pub fn get_key(&self, i: usize) -> &T {
        self.map.get_key(i)
    }

    /// Log-probability of `key`; `-inf` for keys outside the support.
    pub fn log_pmf(&self, key: &T) -> f64
    where
        T: Clone + Ord,
    {
        self.lookup(key).unwrap_or(f64::NEG_INFINITY)
    }

    /// Shannon entropy in nats: `-Σ (n / count) · ln p`.
    ///
    /// Yields NaN for a distribution built from no observations (`count == 0`).
    pub fn entropy(&self) -> f64 {
        -self.map.values().map(|&(n, lp)| (n as f64) * lp).sum::<f64>()
            / self.count as f64
    }

    /// Kullback–Leibler divergence `D(self || q)` in nats.
    ///
    /// Sums `p(k) · (ln p(k) - ln q(k))` over the support of `self`. If `q`
    /// assigns zero probability to a key that `self` supports, the result is
    /// `+inf`, as the divergence is unbounded there. Keys present only in `q`
    /// contribute nothing.
    pub fn kld(&self, q: &Self) -> f64
    where
        T: Clone + Ord,
    {
        self.map
            .iter()
            .map(|(k, (_, lp))| {
                // `log_pmf` yields -inf for missing keys, giving a +inf term
                // instead of panicking.
                lp.exp() * (*lp - q.log_pmf(k))
            })
            .sum::<f64>()
    }

    /// Log-probability of `val`; `-inf` for values outside the support.
    ///
    /// Equivalent to [`Categorical::log_pmf`].
    pub fn log_probability(&self, val: &T) -> f64
    where
        T: Clone + Ord,
    {
        self.log_pmf(val)
    }
}
impl<T: Debug> UnivariateDistribution for Categorical<T> {
    /// Exposes this distribution as a [`TruncatedCategorical`] over indices
    /// `0..self.len()`, copying the stored log-probabilities in map order.
    ///
    /// The struct is built directly (not via `TruncatedCategorical::new`), so
    /// no trimming or renormalization is applied.
    fn truncated(&self) -> Box<dyn TruncatedDistribution> {
        let ln_ps: Vec<f64> = self.map.values().map(|&(_, lp)| lp).collect();
        let truncated = TruncatedCategorical { lo: 0, ln_ps };
        Box::new(truncated)
    }
}
impl<T: Clone + Ord + Debug + 'static> Model<T> for Categorical<T> {
    /// Maps the sampled index `s` back to the key stored at that position.
    fn push(&mut self, s: i64) -> Option<T> {
        let index = s as usize;
        let key = self.get_key(index);
        Some(key.clone())
    }

    /// The model is stationary: every step yields a copy of the same distribution.
    fn next_distr(&mut self) -> Box<dyn UnivariateDistribution> {
        let snapshot: Categorical<T> = self.clone();
        Box::new(snapshot)
    }
}
/// Categorical distribution over the contiguous integer range
/// `lo..lo + ln_ps.len()`, stored as log-probabilities.
#[derive(Clone, PartialEq, PartialOrd)]
pub struct TruncatedCategorical {
    /// Outcome that `ln_ps[0]` refers to.
    lo: usize,
    /// Log-probability of each outcome, starting at `lo`.
    pub ln_ps: Vec<f64>,
}
impl std::fmt::Debug for TruncatedCategorical {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "TruncatedCategorical {{ lo: {}, ln_ps: {:?} }}",
            self.lo,
            self.ln_ps
        )
    }
}
impl TruncatedCategorical {
    /// Builds a distribution from (possibly unnormalized) log-probabilities.
    ///
    /// Leading and trailing `-inf` entries are removed (advancing `lo` past the
    /// leading ones), then the remainder is normalized so the probabilities sum
    /// to one.
    ///
    /// # Panics
    /// Panics if `ln_ps` is empty or every entry is `-inf`.
    pub fn new(ln_ps: Vec<f64>) -> TruncatedCategorical {
        let mut cat = TruncatedCategorical { lo: 0, ln_ps };
        cat.trim_left();
        cat.trim_right();
        cat.normalize();
        cat
    }

    /// Shifts all log-probabilities so that their probabilities sum to one.
    pub fn normalize(&mut self) {
        let total_ln_p = self.ln_ps.iter().copied().ln_sum_exp();
        for lp in self.ln_ps.iter_mut() {
            *lp -= total_ln_p;
        }
    }

    /// Whether a log-probability denotes an outcome that is not ruled out.
    fn is_possible(lp: f64) -> bool {
        lp != f64::NEG_INFINITY
    }

    /// Drops leading `-inf` entries and advances `lo` so indices stay aligned.
    ///
    /// # Panics
    /// Panics if no entry is possible (including the empty case).
    fn trim_left(&mut self) {
        let first = self
            .ln_ps
            .iter()
            .position(|&lp| Self::is_possible(lp))
            .expect("TruncatedCategorical needs at least one non-zero probability");
        // In-place removal; a no-op when the first entry is already possible.
        self.ln_ps.drain(..first);
        self.lo += first;
    }

    /// Drops trailing `-inf` entries.
    ///
    /// # Panics
    /// Panics if no entry is possible (including the empty case).
    fn trim_right(&mut self) {
        let last = self
            .ln_ps
            .iter()
            .rposition(|&lp| Self::is_possible(lp))
            .expect("TruncatedCategorical needs at least one non-zero probability");
        self.ln_ps.truncate(last + 1);
    }
}
impl TruncatedDistribution for TruncatedCategorical {
    /// Finds the outcome whose cumulative-probability interval contains `cp`.
    ///
    /// The probabilities are rescaled to sum to one, so `ln_ps` need not be
    /// normalized. Returns `(s, s_rem)`: `s` is the outcome (offset by `lo`) and
    /// `s_rem` is the relative position of `cp` inside that outcome's interval.
    /// `cp` is expected to lie in `[0, 1)`; values at or beyond the final
    /// boundary map to the last outcome (with `s_rem >= 1`) instead of panicking.
    fn quantile(&self, cp: f64) -> (i64, f64) {
        let ps = self.ln_ps.iter().map(|lp| lp.exp()).collect::<Vec<_>>();
        let mut running = 0.0;
        let mut cdf = ps
            .iter()
            .map(|&p| {
                running += p;
                running
            })
            .collect::<Vec<_>>();
        let total_p = *cdf.last().expect("TruncatedCategorical has no outcomes");
        let scale = total_p.recip();
        for boundary in cdf.iter_mut() {
            *boundary *= scale;
        }
        // `total_p * total_p.recip()` can round to just below 1.0 (e.g. for 49.0),
        // so a valid `cp` may exceed every boundary; clamp to the last outcome
        // rather than indexing `ps` out of bounds.
        let s = cdf
            .partition_point(|&boundary| boundary <= cp)
            .min(ps.len() - 1);
        let s_lo = if s == 0 { 0.0 } else { cdf[s - 1] };
        let s_rem = (cp - s_lo) / (scale * ps[s]);
        ((s + self.lo) as i64, s_rem)
    }

    /// Restricts the distribution to one side of the split point `cp`, which
    /// lies inside outcome `s` at relative position `s_rem` (as produced by
    /// `quantile`).
    ///
    /// With `bit == true` the outcomes from `s` upward are kept, and `s` retains
    /// only its upper `1 - s_rem` share; the result is divided by `1 - cp`.
    /// With `bit == false` the outcomes up to `s` are kept, and `s` retains only
    /// its lower `s_rem` share; the result is divided by `cp`. The rescaling
    /// renormalizes exactly only if `ln_ps` was normalized beforehand.
    fn truncate(&mut self, cp: f64, s: i64, s_rem: f64, bit: bool) {
        let i = s as usize - self.lo;
        if bit {
            self.lo += i;
            // Drop outcomes below `s` in place instead of copying the tail.
            self.ln_ps.drain(..i);
            self.ln_ps[0] += (1.0 - s_rem).ln();
            let lccp = (1.0 - cp).ln();
            for lp in self.ln_ps.iter_mut() {
                *lp -= lccp;
            }
        } else {
            self.ln_ps.truncate(i + 1);
            *self.ln_ps.last_mut().unwrap() += s_rem.ln();
            // `s_rem == 0` leaves a `-inf` tail entry; remove it.
            self.trim_right();
            let lcp = cp.ln();
            for lp in self.ln_ps.iter_mut() {
                *lp -= lcp;
            }
        }
    }

    /// Smallest outcome still in the support.
    fn lo(&self) -> i64 {
        self.lo as i64
    }

    /// Largest outcome still in the support.
    fn hi(&self) -> i64 {
        (self.lo + self.ln_ps.len() - 1) as i64
    }

    /// True once exactly one outcome remains.
    fn is_resolved(&self) -> bool {
        self.ln_ps.len() == 1
    }
}