use crate::dag::{Dag, DagLike, NoSharing};
use crate::Tmr;
use std::sync::Arc;
use std::{cmp, fmt, hash};
/// The top-level shape of a finalized type: the unit type, or a sum or product
/// of two finalized child types.
///
/// NOTE: the derived `PartialOrd`/`Ord` compare the variant first (in
/// declaration order `Unit` < `Sum` < `Product`) and then the children using
/// `Final`'s own ordering, so the variant order below is significant.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
pub enum CompleteBound {
    /// The unit type (rendered as `1` by `Final`'s `Display` impl).
    Unit,
    /// The sum `left + right` of two finalized types.
    Sum(Arc<Final>, Arc<Final>),
    /// The product `left × right` of two finalized types.
    Product(Arc<Final>, Arc<Final>),
}
/// A type whose structure is completely known, together with values cached at
/// construction time.
///
/// Equality, ordering and hashing of `Final` look only at `tmr`.
#[derive(Clone)]
pub struct Final {
    /// Identifier of the type, built from the children's identifiers via
    /// `Tmr::sum` / `Tmr::product` (presumably the type Merkle root).
    tmr: Tmr,
    /// Bit width computed structurally: 0 for unit, `1 + max` of the children
    /// for sums, and the total of the children for products.
    bit_width: usize,
    /// Shape of this type and its children.
    bound: CompleteBound,
}
/// Two finalized types are equal exactly when their TMRs are equal.
impl PartialEq for Final {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.tmr, &other.tmr);
        lhs == rhs
    }
}
/// TMR equality is reflexive, so `Final` equality is a total equivalence.
impl Eq for Final {}
/// Partial ordering that always agrees with the total `Ord` impl.
impl PartialOrd for Final {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Finalized types are ordered solely by their TMRs.
impl Ord for Final {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        Ord::cmp(&self.tmr, &other.tmr)
    }
}
/// Hashes only the TMR, consistent with the TMR-based `PartialEq` impl.
impl hash::Hash for Final {
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        hash::Hash::hash(&self.tmr, state);
    }
}
/// Debug output: the TMR, the bit width, and the `Display` rendering of the type.
impl fmt::Debug for Final {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!(
            "{{ tmr: {}, bit_width: {}, bound: {} }}",
            self.tmr, self.bit_width, self
        ))
    }
}
/// Renders the type as an algebraic expression.
///
/// - Unit is `1`.
/// - Sums are `A + B`, products are `A × B`, and `1 + A` is shortened to `A?`.
/// - A node whose TMR equals entry `n` of `Tmr::POWERS_OF_TWO` is printed as
///   `2` (for `n == 0`) or `2^E` with `E = 2^n` already evaluated, e.g.
///   `2^1024` for `n == 10`. Its subtree is not expanded.
/// - Sum/product nodes with a nonzero traversal `index` get parentheses.
///   Presumably that is every node except the root; TODO confirm against the
///   iterator's definition of `index`.
impl fmt::Display for Final {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // While `Some(tmr)`, we are inside a subtree that has already been
        // printed (or is deliberately elided). Every item is skipped until the
        // completed visit of the node whose TMR is `tmr`.
        let mut skipping: Option<Tmr> = None;
        for data in self.verbose_pre_order_iter::<NoSharing>() {
            if let Some(skip) = skipping {
                if data.is_complete && data.node.tmr == skip {
                    skipping = None;
                }
                continue;
            }

            // Abbreviate known powers of two instead of expanding them.
            if let Some(n) = Tmr::POWERS_OF_TWO
                .iter()
                .position(|tmr| *tmr == data.node.tmr)
            {
                if n == 0 {
                    f.write_str("2")?;
                } else {
                    // A bare `1 << n` would be an `i32`. That prints a negative
                    // exponent at n = 31 and overflows the shift beyond that.
                    // `u64` is exact for every table index below 64.
                    write!(f, "2^{}", 1u64 << n)?;
                }
                skipping = Some(data.node.tmr);
                continue;
            }

            match (&data.node.bound, data.n_children_yielded) {
                (CompleteBound::Unit, _) => {
                    f.write_str("1")?;
                }
                // Special-case `1 + A` as `A?`. Skip the unit left child,
                // print nothing between the children, then append `?`.
                // These guarded arms must precede the general sum arms.
                (CompleteBound::Sum(ref left, _), 0)
                    if matches!(left.bound, CompleteBound::Unit) =>
                {
                    skipping = Some(Tmr::unit());
                }
                (CompleteBound::Sum(ref left, _), 1)
                    if matches!(left.bound, CompleteBound::Unit) => {}
                (CompleteBound::Sum(ref left, _), 2)
                    if matches!(left.bound, CompleteBound::Unit) =>
                {
                    f.write_str("?")?;
                }
                // Other sums and products: open/close parentheses around
                // non-root nodes, and print the operator between the children.
                (CompleteBound::Sum(..), 0) | (CompleteBound::Product(..), 0) => {
                    if data.index > 0 {
                        f.write_str("(")?;
                    }
                }
                (CompleteBound::Sum(..), 2) | (CompleteBound::Product(..), 2) => {
                    if data.index > 0 {
                        f.write_str(")")?;
                    }
                }
                (CompleteBound::Sum(..), _) => f.write_str(" + ")?,
                (CompleteBound::Product(..), _) => f.write_str(" × ")?,
            }
        }
        Ok(())
    }
}
/// Lets a borrowed `Final` be walked with the generic DAG iterators.
///
/// Unit is a leaf. Sums and products both have two children, given as
/// (left, right).
impl<'a> DagLike for &'a Final {
    type Node = Final;

    fn data(&self) -> &Final {
        *self
    }

    fn as_dag_node(&self) -> Dag<Self> {
        let node: &'a Final = *self;
        match node.bound {
            CompleteBound::Sum(ref lhs, ref rhs) | CompleteBound::Product(ref lhs, ref rhs) => {
                Dag::Binary(lhs, rhs)
            }
            CompleteBound::Unit => Dag::Nullary,
        }
    }
}
impl Final {
pub(super) const fn unit() -> Self {
Final {
bound: CompleteBound::Unit,
bit_width: 0,
tmr: Tmr::unit(),
}
}
pub fn two_two_n(n: usize) -> Arc<Self> {
super::precomputed::nth_power_of_2(n).final_data().unwrap()
}
pub(super) fn sum(left: Arc<Self>, right: Arc<Self>) -> Self {
Final {
tmr: Tmr::sum(left.tmr, right.tmr),
bit_width: 1 + cmp::max(left.bit_width, right.bit_width),
bound: CompleteBound::Sum(left, right),
}
}
pub(super) fn product(left: Arc<Self>, right: Arc<Self>) -> Self {
Final {
tmr: Tmr::product(left.tmr, right.tmr),
bit_width: left.bit_width + right.bit_width,
bound: CompleteBound::Product(left, right),
}
}
pub fn tmr(&self) -> Tmr {
self.tmr
}
pub fn bit_width(&self) -> usize {
self.bit_width
}
pub fn bound(&self) -> &CompleteBound {
&self.bound
}
pub fn is_unit(&self) -> bool {
self.bound == CompleteBound::Unit
}
pub fn split_sum(&self) -> Option<(Arc<Self>, Arc<Self>)> {
match &self.bound {
CompleteBound::Sum(left, right) => Some((left.clone(), right.clone())),
_ => None,
}
}
pub fn split_product(&self) -> Option<(Arc<Self>, Arc<Self>)> {
match &self.bound {
CompleteBound::Product(left, right) => Some((left.clone(), right.clone())),
_ => None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the `Display` rendering of powers of two, sums, products and `1 + A`.
    #[test]
    fn final_stringify() {
        // Each entry is (actual rendering, expected rendering).
        let cases = [
            (Final::two_two_n(10).to_string(), "2^1024"),
            (
                Final::sum(Final::two_two_n(5), Final::two_two_n(10)).to_string(),
                "2^32 + 2^1024",
            ),
            (
                Final::product(Final::two_two_n(5), Final::two_two_n(10)).to_string(),
                "2^32 × 2^1024",
            ),
            (Final::two_two_n(0).to_string(), "2"),
            (
                Final::sum(Arc::new(Final::unit()), Final::two_two_n(2)).to_string(),
                "2^4?",
            ),
        ];
        for (rendered, expected) in cases.iter() {
            assert_eq!(rendered.as_str(), *expected);
        }
    }
}