stochy/
lib.rs

1#![doc=include_str!("../README.md")]
2#![doc=include_str!("../README_MORE.md")]
3#![deny(
4    future_incompatible,
5    missing_docs,
6    nonstandard_style,
7    unsafe_op_in_unsafe_fn,
8    unused,
9    warnings,
10    clippy::all,
11    clippy::missing_safety_doc,
12    clippy::undocumented_unsafe_blocks,
13    rustdoc::broken_intra_doc_links,
14    rustdoc::missing_crate_level_docs
15)]
16#![allow(clippy::empty_docs)]
17#![cfg_attr(all(docsrs, not(doctest)), feature(doc_cfg, doc_auto_cfg))]
18
19mod common;
20mod rspsa;
21mod spsa;
22
23#[cfg(feature = "argmin")]
24mod rspsa_argmin;
25
26#[cfg(feature = "argmin")]
27mod spsa_argmin;
28
29#[cfg(feature = "argmin")]
30#[cfg_attr(docsrs, doc(cfg(feature = "argmin")))]
31pub use rspsa_argmin::RspsaSolverArgmin;
32
33#[cfg(feature = "argmin")]
34#[cfg_attr(docsrs, doc(cfg(feature = "argmin")))]
35pub use spsa_argmin::SpsaSolverArgmin;
36
37pub use rspsa::{RspsaAlgo, RspsaParams};
38pub use spsa::{SpsaAlgo, SpsaParams};
39
40use std::{error::Error, fmt::Display, sync::Arc};
41
/// The error type for the Stochy library.
///
/// The `Display` output is a short human-readable message; the `Debug` output (derived) shows
/// the variant structure for diagnostics.
#[derive(Debug, Clone)]
pub enum StochyError {
    /// Represents an error caused by an invalid hyperparameter.
    ///
    /// The contained string is appended to the `Display` message.
    InvalidHyperparameter(String),
    /// Represents an error raised while calling the underlying objective function being solved.
    ///
    /// The wrapped error is reference-counted so that `StochyError` can be `Clone`; it is also
    /// returned by [`Error::source`].
    ObjectiveFunction(Arc<dyn Error + Send + Sync + 'static>),
}

// Boxed, thread-safe error: the shape in which objective-function failures are converted below.
type BoxedError = Box<dyn Error + Send + Sync + 'static>;

impl From<BoxedError> for StochyError {
    /// Wraps a boxed error as [`StochyError::ObjectiveFunction`], moving it into an `Arc`.
    fn from(err: BoxedError) -> Self {
        StochyError::ObjectiveFunction(Arc::from(err))
    }
}

impl Display for StochyError {
    /// Writes a user-facing message instead of the `Debug` representation of the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidHyperparameter(message) => {
                write!(f, "invalid hyperparameter: {}", message)
            }
            // The inner error's message stays in the output so that callers who only print
            // this error (without walking `source()`) still see the cause.
            Self::ObjectiveFunction(cause) => write!(f, "objective function error: {}", cause),
        }
    }
}

impl Error for StochyError {
    /// Returns the wrapped objective-function error, or `None` for errors that have no
    /// underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        // Exhaustive on purpose: adding a variant must force a decision about its source.
        match self {
            Self::ObjectiveFunction(e) => Some(e.as_ref()),
            Self::InvalidHyperparameter(_) => None,
        }
    }
}
72
73#[allow(dead_code)]
74struct AssertSendSync<T: Send + Sync>(std::marker::PhantomData<T>);
75const _: AssertSendSync<StochyError> = AssertSendSync(std::marker::PhantomData);
76
// Placeholder compiled only when the `argmin` feature is disabled. It is always ignored, so the
// test summary reports an ignored test as a hint that the feature-gated tests were not run.
// The `armin` spelling (sic, presumably `argmin`) is kept so existing test-name filters match.
#[cfg(not(feature = "argmin"))]
#[test]
#[ignore = "built without the `argmin` feature, so argmin tests were not run"]
fn armin_tests_not_run() {}