use std::cmp::{max, Reverse};
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use arrayvec::ArrayVec;
use once_cell::sync::OnceCell;
use rustc_hash::FxHashMap;
use positioned_io::RandomAccessFile;
use shakmaty::{Material, Move, Position, Role};
use crate::errors::{ProbeResultExt as _, SyzygyError, SyzygyResult};
use crate::table::{DtzTable, WdlTable};
use crate::types::{DecisiveWdl, Dtz, Metric, Syzygy, Wdl};
/// Records how the WDL value of a probed position was established.
///
/// The DTZ derivation in `WdlEntry::dtz` branches on this state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ProbeState {
    /// The value is not attributed to a zeroing best move or a threat.
    Normal,
    /// The best move is a zeroing move (for example a winning capture), so
    /// the DTZ value follows directly from the WDL value.
    ZeroingBestMove,
    /// The value was established by a threat move (only produced for
    /// variants with compulsory captures).
    Threat,
}
/// A collection of Syzygy tables, keyed by material signature.
///
/// Every entry pairs the path of a registered table file with a cell that
/// holds the opened table once it has been needed for the first time.
#[derive(Debug)]
pub struct Tablebase<S>
where
    S: Position + Clone + Syzygy,
{
    /// Registered WDL tables.
    wdl: FxHashMap<Material, (PathBuf, OnceCell<WdlTable<S, RandomAccessFile>>)>,
    /// Registered DTZ tables.
    dtz: FxHashMap<Material, (PathBuf, OnceCell<DtzTable<S, RandomAccessFile>>)>,
}
impl<S> Default for Tablebase<S>
where
    S: Position + Clone + Syzygy,
{
    /// Same as [`Tablebase::new`]: an empty collection with no tables registered.
    fn default() -> Self {
        Self::new()
    }
}
impl<S: Position + Clone + Syzygy> Tablebase<S> {
pub fn new() -> Tablebase<S> {
Tablebase {
wdl: FxHashMap::with_capacity_and_hasher(145, Default::default()),
dtz: FxHashMap::with_capacity_and_hasher(145, Default::default()),
}
}
pub fn add_directory<P: AsRef<Path>>(&mut self, path: P) -> io::Result<usize> {
let mut num = 0;
for entry in fs::read_dir(path)? {
if self.add_file(entry?.path()).is_ok() {
num += 1;
}
}
Ok(num)
}
pub fn add_file<P: AsRef<Path>>(&mut self, path: P) -> io::Result<()> {
let path = path.as_ref();
if !path.is_file() {
return Err(io::Error::from(io::ErrorKind::InvalidInput));
}
let (stem, ext) = match (path.file_stem().and_then(|s| s.to_str()), path.extension()) {
(Some(stem), Some(ext)) => (stem, ext),
_ => return Err(io::Error::from(io::ErrorKind::InvalidInput)),
};
let material = match Material::from_str(stem) {
Ok(material) => material,
_ => return Err(io::Error::from(io::ErrorKind::InvalidInput)),
};
if material.count() > S::MAX_PIECES {
return Err(io::Error::from(io::ErrorKind::InvalidInput));
}
if material.white.count() < 1 || material.black.count() < 1 {
return Err(io::Error::from(io::ErrorKind::InvalidInput));
}
if ext == S::TBW.ext || (!material.has_pawns() && S::PAWNLESS_TBW.map_or(false, |t| ext == t.ext)) {
self.wdl.insert(material, (path.to_path_buf(), OnceCell::new()));
} else if ext == S::TBZ.ext || (!material.has_pawns() && S::PAWNLESS_TBZ.map_or(false, |t| ext == t.ext)) {
self.dtz.insert(material, (path.to_path_buf(), OnceCell::new()));
} else {
return Err(io::Error::from(io::ErrorKind::InvalidInput));
}
Ok(())
}
fn wdl_table(&self, key: &Material) -> SyzygyResult<&WdlTable<S, RandomAccessFile>> {
if let Some(&(ref path, ref table)) = self.wdl.get(key).or_else(|| self.wdl.get(&key.clone().into_flipped())) {
table.get_or_try_init(|| WdlTable::open(path, key)).ctx(Metric::Wdl, key)
} else {
Err(SyzygyError::MissingTable {
metric: Metric::Wdl,
material: key.clone().into_normalized(),
})
}
}
fn dtz_table(&self, key: &Material) -> SyzygyResult<&DtzTable<S, RandomAccessFile>> {
if let Some(&(ref path, ref table)) = self.dtz.get(key).or_else(|| self.dtz.get(&key.clone().into_flipped())) {
table.get_or_try_init(|| DtzTable::open(path, key)).ctx(Metric::Dtz, key)
} else {
Err(SyzygyError::MissingTable {
metric: Metric::Dtz,
material: key.clone().into_normalized(),
})
}
}
pub fn probe_wdl(&self, pos: &S) -> SyzygyResult<Wdl> {
self.probe(pos).map(|entry| entry.wdl())
}
pub fn probe_dtz(&self, pos: &S) -> SyzygyResult<Dtz> {
self.probe(pos).and_then(|entry| entry.dtz())
}
pub fn best_move(&self, pos: &S) -> SyzygyResult<Option<(Move, Dtz)>> {
struct WithAfter<S> {
m: Move,
after: S,
}
struct WithWdlEntry<'a, S: Position + Clone + Syzygy> {
m: Move,
entry: WdlEntry<'a, S>,
}
struct WithDtz {
m: Move,
immediate_loss: bool,
zeroing: bool,
dtz: Dtz,
}
let with_after = pos.legal_moves().into_iter().map(|m| {
let mut after = pos.clone();
after.play_unchecked(&m);
WithAfter { m, after }
}).collect::<ArrayVec<[_; 256]>>();
let with_wdl = with_after.iter().map(|e| Ok(WithWdlEntry {
m: e.m.clone(),
entry: self.probe(&e.after)?,
})).collect::<SyzygyResult<ArrayVec<[_; 256]>>>()?;
let best_wdl = with_wdl.iter().map(|a| a.entry.wdl).min().unwrap_or(Wdl::Loss);
itertools::process_results(with_wdl.iter().filter(|a| a.entry.wdl == best_wdl).map(|a| {
let dtz = a.entry.dtz()?;
Ok(WithDtz {
immediate_loss: dtz == Dtz(-1) && (a.entry.pos.is_checkmate() || a.entry.pos.variant_outcome().is_some()),
zeroing: a.m.is_zeroing(),
m: a.m.clone(),
dtz,
})
}), |iter| iter.min_by_key(|m| (
Reverse(m.immediate_loss),
m.zeroing ^ (m.dtz < Dtz(0)),
Reverse(m.dtz),
)).map(|m| (m.m, m.dtz)))
}
fn probe<'a>(&'a self, pos: &'a S) -> SyzygyResult<WdlEntry<'a, S>> {
if pos.board().occupied().count() > S::MAX_PIECES {
return Err(SyzygyError::TooManyPieces);
}
if pos.castles().any() {
return Err(SyzygyError::Castling);
}
if S::CAPTURES_COMPULSORY {
let (v, state) = self.probe_compulsory_captures(pos, Wdl::Loss, Wdl::Win, true)?;
return Ok(WdlEntry {
tablebase: self,
pos,
wdl: v,
state,
});
} else if let Some(outcome) = pos.variant_outcome() {
return Ok(WdlEntry {
tablebase: self,
pos,
wdl: Wdl::from_outcome(outcome, pos.turn()),
state: ProbeState::ZeroingBestMove,
});
}
let mut best_capture = Wdl::Loss;
let mut best_ep = Wdl::Loss;
let legals = pos.legal_moves();
for m in legals.iter().filter(|m| m.is_capture()) {
let mut after = pos.clone();
after.play_unchecked(m);
let v = -self.probe_ab_no_ep(&after, Wdl::Loss, -best_capture)?;
if v == Wdl::Win {
return Ok(WdlEntry {
tablebase: self,
pos,
wdl: v,
state: ProbeState::ZeroingBestMove,
});
}
if m.is_en_passant() {
best_ep = max(best_ep, v);
} else {
best_capture = max(best_capture, v);
}
}
let v = self.probe_wdl_table(pos)?;
if best_ep > max(v, best_capture) {
return Ok(WdlEntry {
tablebase: self,
pos,
wdl: best_ep,
state: ProbeState::ZeroingBestMove,
});
}
best_capture = max(best_capture, best_ep);
if best_capture >= v {
return Ok(WdlEntry {
tablebase: self,
pos,
wdl: best_capture,
state: if best_capture > Wdl::Draw { ProbeState::ZeroingBestMove } else { ProbeState::Normal },
})
}
if v == Wdl::Draw && !legals.is_empty() && legals.iter().all(|m| m.is_en_passant()) {
return Ok(WdlEntry {
tablebase: self,
pos,
wdl: best_ep,
state: ProbeState::ZeroingBestMove,
})
}
Ok(WdlEntry {
tablebase: self,
pos,
wdl: v,
state: ProbeState::Normal,
})
}
fn probe_ab_no_ep(&self, pos: &S, mut alpha: Wdl, beta: Wdl) -> SyzygyResult<Wdl> {
assert!(pos.ep_square().is_none());
for m in pos.capture_moves() {
let mut after = pos.clone();
after.play_unchecked(&m);
let v = -self.probe_ab_no_ep(&after, -beta, -alpha)?;
if v >= beta {
return Ok(v);
}
alpha = max(alpha, v);
}
let v = self.probe_wdl_table(pos)?;
Ok(max(alpha, v))
}
fn probe_compulsory_captures(&self, pos: &S, mut alpha: Wdl, beta: Wdl, threats: bool) -> SyzygyResult<(Wdl, ProbeState)> {
assert!(S::CAPTURES_COMPULSORY);
if let Some(outcome) = pos.variant_outcome() {
return Ok((Wdl::from_outcome(outcome, pos.turn()), ProbeState::ZeroingBestMove));
}
if pos.them().count() > 1 {
if let Some(v) = self.probe_captures(pos, alpha, beta)? {
return Ok((v, ProbeState::ZeroingBestMove));
}
} else {
if !pos.capture_moves().is_empty() {
return Ok((Wdl::Loss, ProbeState::ZeroingBestMove));
}
}
let mut threats_found = false;
if threats || pos.board().occupied().count() >= 6 {
for threat in pos.legal_moves() {
if threat.role() != Role::Pawn {
let mut after = pos.clone();
after.play_unchecked(&threat);
if let Some(v_plus) = self.probe_captures(&after, -beta, -alpha)? {
let v = -v_plus;
if v > alpha {
threats_found = true;
alpha = v;
if alpha >= beta {
return Ok((v, ProbeState::Threat));
}
}
}
}
}
}
let v = self.probe_wdl_table(pos)?;
if v > alpha {
Ok((v, ProbeState::Normal))
} else {
Ok((alpha, if threats_found { ProbeState::Threat } else { ProbeState::Normal }))
}
}
fn probe_captures(&self, pos: &S, mut alpha: Wdl, beta: Wdl) -> SyzygyResult<Option<Wdl>> {
assert!(S::CAPTURES_COMPULSORY);
let captures = pos.capture_moves();
for m in pos.capture_moves() {
let mut after = pos.clone();
after.play_unchecked(&m);
let (v_plus, _) = self.probe_compulsory_captures(&after, -beta, -alpha, false)?;
let v = -v_plus;
alpha = max(v, alpha);
if alpha >= beta {
break;
}
}
Ok(if captures.is_empty() { None } else { Some(alpha) })
}
fn probe_wdl_table(&self, pos: &S) -> SyzygyResult<Wdl> {
if let Some(outcome) = pos.variant_outcome() {
return Ok(Wdl::from_outcome(outcome, pos.turn()));
}
if S::ONE_KING && pos.board().kings() == pos.board().occupied() {
return Ok(Wdl::Draw);
}
let key = pos.board().material();
self.wdl_table(&key)
.and_then(|table| table.probe_wdl(pos).ctx(Metric::Wdl, &key))
}
fn probe_dtz_table(&self, pos: &S, wdl: DecisiveWdl) -> SyzygyResult<Option<Dtz>> {
let key = pos.board().material();
self.dtz_table(&key)
.and_then(|table| table.probe_dtz(pos, wdl).ctx(Metric::Dtz, &key))
}
}
/// Outcome of a WDL probe.
///
/// Keeps the probed position and the tablebase so that the DTZ value can be
/// derived later without probing the WDL value again.
#[derive(Debug)]
struct WdlEntry<'a, S>
where
    S: Position + Clone + Syzygy,
{
    /// Tablebase used for follow-up probes.
    tablebase: &'a Tablebase<S>,
    /// The probed position.
    pos: &'a S,
    /// The probed WDL value.
    wdl: Wdl,
    /// How `wdl` was established.
    state: ProbeState,
}
impl<'a, S: Position + Clone + Syzygy + 'a> WdlEntry<'a, S> {
    /// Returns the probed WDL value.
    fn wdl(&self) -> Wdl {
        self.wdl
    }

    /// Derives the DTZ value of the probed position.
    ///
    /// Strategy, in order:
    /// 1. Non-decisive WDL: `Dtz(0)`.
    /// 2. Zeroing best move, or every piece of the side to move is a pawn
    ///    (so every move is a pawn move): DTZ right before zeroing.
    /// 3. Win established by a threat: one ply more than that.
    /// 4. Winning side with a non-capturing pawn move that keeps the WDL
    ///    value: DTZ right before zeroing.
    /// 5. DTZ table value, if the table has one for this position.
    /// 6. Otherwise, a one-ply search over non-zeroing moves.
    ///
    /// # Errors
    ///
    /// Propagates probe errors. Fails with a table error (via `u!`) if step 6
    /// finds no candidate move.
    fn dtz(&self) -> SyzygyResult<Dtz> {
        let wdl = match self.wdl.decisive() {
            Some(wdl) => wdl,
            None => return Ok(Dtz(0)),
        };
        if self.state == ProbeState::ZeroingBestMove || self.pos.us() == self.pos.our(Role::Pawn) {
            return Ok(Dtz::before_zeroing(wdl));
        }
        if self.state == ProbeState::Threat && wdl >= DecisiveWdl::CursedWin {
            return Ok(Dtz::before_zeroing(wdl).add_plies(1));
        }
        if wdl >= DecisiveWdl::CursedWin {
            // A pawn advance that preserves the result zeroes immediately.
            let mut pawn_advances = self.pos.legal_moves();
            pawn_advances.retain(|m| !m.is_capture() && m.role() == Role::Pawn);
            for m in &pawn_advances {
                let mut after = self.pos.clone();
                after.play_unchecked(m);
                let v = -self.tablebase.probe_wdl(&after)?;
                if v == wdl.into() {
                    return Ok(Dtz::before_zeroing(wdl));
                }
            }
        }
        // `self.pos` is already a reference; pass it through unchanged.
        if let Some(Dtz(dtz)) = self.tablebase.probe_dtz_table(self.pos, wdl)? {
            return Ok(Dtz::before_zeroing(wdl).add_plies(dtz));
        }
        // No table value: search one ply over non-zeroing moves.
        let mut moves = self.pos.legal_moves();
        moves.retain(|m| !m.is_zeroing());
        // Winning: no candidate yet. Losing: start from the zeroing bound.
        let mut best = if wdl >= DecisiveWdl::CursedWin {
            None
        } else {
            Some(Dtz::before_zeroing(wdl))
        };
        for m in &moves {
            let mut after = self.pos.clone();
            after.play_unchecked(m);
            let v = -self.tablebase.probe_dtz(&after)?;
            if v == Dtz(1) && after.is_checkmate() {
                best = Some(Dtz(1));
            } else if wdl >= DecisiveWdl::CursedWin {
                // Winning: keep the smallest positive successor value.
                if v > Dtz(0) && best.map_or(true, |best| v + Dtz(1) < best) {
                    best = Some(v + Dtz(1));
                }
            } else if best.map_or(true, |best| v - Dtz(1) < best) {
                // Losing: keep the minimum of `v - 1`.
                best = Some(v - Dtz(1));
            }
        }
        (|| Ok(u!(best)))().ctx(Metric::Dtz, &self.pos.board().material())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use shakmaty::fen::Fen;
    use shakmaty::{CastlingMode, Chess, Square};

    /// Loads the chess tables from `tables/chess`, then parses `fen` into a
    /// legal position.
    fn setup(fen: &str) -> (Tablebase<Chess>, Chess) {
        let mut tables = Tablebase::new();
        tables.add_directory("tables/chess").expect("read directory");
        let pos = fen
            .parse::<Fen>()
            .expect("valid fen")
            .position(CastlingMode::Chess960)
            .expect("legal position");
        (tables, pos)
    }

    /// The tablebase must be shareable across threads.
    #[test]
    fn test_send_sync() {
        fn require_send<T: Send>(_: T) {}
        fn require_sync<T: Sync>(_: T) {}
        require_send(Tablebase::<Chess>::new());
        require_sync(Tablebase::<Chess>::new());
    }

    /// Black has a mate in one with the rook.
    #[test]
    fn test_mating_best_move() {
        let (tables, pos) = setup("5BrN/8/8/8/8/2k5/8/2K5 b - -");
        let result = tables.best_move(&pos);
        assert!(matches!(result, Ok(Some((Move::Normal {
            role: Role::Rook,
            from: Square::G8,
            capture: None,
            to: Square::G1,
            promotion: None,
        }, Dtz(-1))))));
    }

    /// Black's only escape is promoting to a knight.
    #[test]
    fn test_black_escapes_via_underpromotion() {
        let (tables, pos) = setup("8/6B1/8/8/B7/8/K1pk4/8 b - - 0 1");
        let result = tables.best_move(&pos);
        assert!(matches!(result, Ok(Some((Move::Normal {
            role: Role::Pawn,
            from: Square::C2,
            to: Square::C1,
            capture: None,
            promotion: Some(Role::Knight),
        }, Dtz(109))))));
    }

    /// Requires tables with many pawns, hence ignored by default.
    #[test]
    #[ignore]
    fn test_many_pawns() {
        let (tables, pos) = setup("3k4/5P2/8/8/4K3/2P3P1/PP6/8 w - - 0 1");
        let result = tables.probe_dtz(&pos);
        assert!(matches!(result, Ok(Dtz(1))));
    }
}