pub trait Driver {
    type Algo;
    // Required methods
    fn solve(self) -> Result<(Self::Algo, Step), DriverError>;
    fn iter_step(
        &mut self,
    ) -> Result<Option<(&mut Self::Algo, &Step)>, DriverError>;
    fn on_step<G>(self, g: G) -> impl Driver<Algo = Self::Algo>
       where Self: Sized,
             G: FnMut(&mut Self::Algo, &Step);
    fn try_on_step<F, E>(self, f: F) -> impl Driver<Algo = Self::Algo>
       where Self: Sized,
             F: FnMut(&mut Self::Algo, &Step) -> Result<(), E>,
             E: Error + Sync + Send + 'static;
    fn converge_when<F>(self, pred: F) -> impl Driver<Algo = Self::Algo>
       where Self: Sized,
             F: FnMut(&mut Self::Algo, &Step) -> bool;
    fn fail_if<F>(self, pred: F) -> impl Driver<Algo = Self::Algo>
       where Self: Sized,
             F: FnMut(&mut Self::Algo, &Step) -> bool;
    fn show_progress_bar_after(
        self,
        after: Duration,
    ) -> impl Driver<Algo = Self::Algo>
       where Self: Sized;
    fn set_fixed_iters(
        self,
        fixed_iters: usize,
    ) -> impl Driver<Algo = Self::Algo>
       where Self: Sized;
    fn set_fail_after_iters(
        self,
        max_iters: usize,
    ) -> impl Driver<Algo = Self::Algo>
       where Self: Sized;
    fn set_timeout(self, timeout: Duration) -> impl Driver<Algo = Self::Algo>
       where Self: Sized;
    fn checkpoint(
        self,
        path: impl AsRef<Path>,
        every: Duration,
    ) -> impl Driver<Algo = Self::Algo>
       where Self: Sized;
    fn recovery(
        self,
        path: impl AsRef<Path>,
        action: RecoveryFile,
    ) -> impl Driver<Algo = Self::Algo>
       where Self: Sized;
    fn into_boxed(self) -> BoxedDriver<Self::Algo>
       where Self: Sized + 'static;
    fn into_dyn(self) -> DynDriver
       where Self: Sized + 'static,
             Self::Algo: Algo;
}
An executor for an algorithm, controlling maximum iterations, execution time, convergence criteria, and callbacks.
Required Associated Types
type Algo
Required Methods
fn solve(self) -> Result<(Self::Algo, Step), DriverError>
Runs the algorithm until failure, iteration exhaustion, or convergence.
On success, the algo in its final state and the final iteration Step are returned.
By convention, the variable holding the algo in its final state is named solved, and it offers accessors to retrieve the solution or the best approximation to the solution.
Example
let (solved, step) = fixed_iters(my_algo, 1000)
    .solve()?;
assert_approx_eq!(solved.x(), &vec![1.5, 2.0]);
assert_eq!(step.iteration(), 1000);
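The three stopping conditions and the returned pair can be pictured with a self-contained sketch. It does not use this crate: `Step` and `solve` below are simplified, hypothetical stand-ins, and a real driver also deals with timeouts, callbacks and checkpoints.

```rust
/// Hypothetical stand-in for the crate's `Step`: only tracks the iteration count.
#[derive(Debug, Clone, PartialEq)]
pub struct Step {
    pub iteration: usize,
}

/// Drives `algo` one step at a time, stopping on the first error, when
/// `converged` returns true, or after `max_iters` steps. On success the algo
/// is handed back in its final state together with the final `Step`.
pub fn solve<A, E>(
    mut algo: A,
    max_iters: usize,
    mut step_fn: impl FnMut(&mut A) -> Result<(), E>,
    mut converged: impl FnMut(&A, &Step) -> bool,
) -> Result<(A, Step), E> {
    let mut step = Step { iteration: 0 };
    while step.iteration < max_iters {
        step_fn(&mut algo)?; // failure: propagate the algorithm's error
        step.iteration += 1;
        if converged(&algo, &step) {
            break; // convergence: stop early
        }
    }
    Ok((algo, step)) // convergence or iteration exhaustion
}
```

Here exhaustion counts as success, as with `fixed_iters`; a fail-after policy would turn it into an error instead.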
fn iter_step(&mut self) -> Result<Option<(&mut Self::Algo, &Step)>, DriverError>
fn on_step<G>(self, g: G) -> impl Driver<Algo = Self::Algo>
Invoked after each iteration, allowing printing or debugging of the iteration Step. Any errors encountered by the underlying algorithm will terminate solving, and Self::solve will return the error.
Can be used to capture details of the iteration, or to update hyperparameters on the algorithm for adaptive learning.
Example
let (solved, _step) = fixed_iters(my_algo, 1000)
.on_step(|_algo, step| println!("{step:?}"))
.solve()?;
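A minimal sketch of the adaptive-learning idea, assuming a toy algorithm: `ToyGd` and `run_with_callback` are hypothetical and independent of this crate. The callback receives `&mut` access to the algorithm after each step, so it can decay the learning rate between iterations.

```rust
/// Hypothetical toy algorithm: gradient descent on f(x) = x^2 with a tunable learning rate.
pub struct ToyGd {
    pub x: f64,
    pub lr: f64,
}

impl ToyGd {
    /// One gradient-descent step using the current learning rate.
    fn step(&mut self) {
        let grad = 2.0 * self.x; // derivative of x^2
        self.x -= self.lr * grad;
    }
}

/// Runs `iters` steps, invoking `on_step` after each one with mutable access
/// to the algorithm and the 1-based iteration number.
pub fn run_with_callback(
    mut algo: ToyGd,
    iters: usize,
    mut on_step: impl FnMut(&mut ToyGd, usize),
) -> ToyGd {
    for i in 1..=iters {
        algo.step();
        on_step(&mut algo, i);
    }
    algo
}
```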
fn try_on_step<F, E>(self, f: F) -> impl Driver<Algo = Self::Algo>
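Invoked after each iteration, allowing printing or debugging of the iteration step.
Like Self::on_step, but the closure is fallible: if it returns an Err, iteration is terminated and Self::solve will return an error.
Can be used to capture details of the iteration, for example by writing them to a file where an I/O failure should stop the run. See Self::on_step.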
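As a sketch of how a fallible callback can abort a run, the hypothetical `run_fallible` below stops at the first `Err` and boxes it. The `?` operator performs that conversion because `E: Error + Send + Sync + 'static`, the same bounds `try_on_step` places on `E`. `Diverged` is an illustrative error type, not part of this crate.

```rust
use std::error::Error;
use std::fmt;

/// Hypothetical error raised by a user callback.
#[derive(Debug)]
pub struct Diverged(pub f64);

impl fmt::Display for Diverged {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "value diverged to {}", self.0)
    }
}

impl Error for Diverged {}

/// Runs up to `iters` steps of `x <- x * factor`, calling the fallible
/// callback after each step. The first `Err` ends the run early.
pub fn run_fallible<E>(
    mut x: f64,
    factor: f64,
    iters: usize,
    mut callback: impl FnMut(f64, usize) -> Result<(), E>,
) -> Result<f64, Box<dyn Error + Send + Sync>>
where
    E: Error + Send + Sync + 'static,
{
    for i in 1..=iters {
        x *= factor;
        // `?` boxes E into Box<dyn Error + Send + Sync> via the std `From` impl.
        callback(x, i)?;
    }
    Ok(x)
}
```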
fn converge_when<F>(self, pred: F) -> impl Driver<Algo = Self::Algo>
Used for early stopping.
Common convergence predicates are the step size falling below a small epsilon, or the residuals being near zero.
Some specific convergence criteria are best handled by using metrics within this callback.
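A step-size predicate needs to remember the previous iterate, which is one reason a predicate is naturally an `FnMut`. This hypothetical sketch pairs such a predicate with a Babylonian square-root iteration; neither function belongs to this crate.

```rust
/// Hypothetical convergence predicate: true once successive iterates move by
/// less than `eps`. The previous iterate lives in the closure's captured state.
pub fn step_size_below(eps: f64) -> impl FnMut(f64) -> bool {
    let mut prev: Option<f64> = None;
    move |x| {
        let converged = matches!(prev, Some(p) if (x - p).abs() < eps);
        prev = Some(x);
        converged
    }
}

/// Iterates the Babylonian square-root update for `target`, stopping when
/// `pred` reports convergence or after `max_iters`; returns (estimate, iterations).
pub fn sqrt_until(target: f64, max_iters: usize, mut pred: impl FnMut(f64) -> bool) -> (f64, usize) {
    let mut x = target.max(1.0);
    for i in 1..=max_iters {
        x = 0.5 * (x + target / x);
        if pred(x) {
            return (x, i);
        }
    }
    (x, max_iters)
}
```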
fn fail_if<F>(self, pred: F) -> impl Driver<Algo = Self::Algo>
Decides when to abandon iteration, with Self::solve returning DriverError::FailIfPredicate.
Convergence predicates are tested first.
Common failure predicates are residuals growing in size after an initial set of iterations, or user cancellation, implemented perhaps by the closure predicate checking the value of an AtomicBool.
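The ordering (convergence before failure) and the `AtomicBool` cancellation idea can be sketched independently of this crate. `drive`, `Outcome` and `cancelled` are hypothetical names; in a real program the flag would typically be set from another thread, such as a Ctrl-C handler.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Result of the hypothetical driver loop below.
#[derive(Debug, PartialEq)]
pub enum Outcome {
    Converged(usize),
    Failed(usize),
    Exhausted,
}

pub fn drive(
    max_iters: usize,
    mut converge_when: impl FnMut(usize) -> bool,
    mut fail_if: impl FnMut(usize) -> bool,
) -> Outcome {
    for i in 1..=max_iters {
        // Convergence is tested first, so an iteration that satisfies both
        // predicates counts as converged.
        if converge_when(i) {
            return Outcome::Converged(i);
        }
        if fail_if(i) {
            return Outcome::Failed(i);
        }
    }
    Outcome::Exhausted
}

/// A failure predicate that reports true once the shared flag has been set.
pub fn cancelled(flag: Arc<AtomicBool>) -> impl FnMut(usize) -> bool {
    move |_| flag.load(Ordering::Relaxed)
}
```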
fn show_progress_bar_after(
    self,
    after: Duration,
) -> impl Driver<Algo = Self::Algo>
where
    Self: Sized,
Displays a progress bar on the terminal after the algo has been running for this duration.
Use Duration::MAX to disable the progress bar, and Duration::ZERO to display it immediately.
The progress bar will refresh at most every 50ms.
The progress bar will not display if any of the following hold:
- stdout is not a terminal (perhaps redirected to a utility or file)
- the environment variable NO_COLOR is set
- the elapsed time never reaches the duration set by this function
Example
let (solved, _step) = fixed_iters(my_algo, 1000)
// show a progress bar if the algo is slow to execute
.show_progress_bar_after(Duration::from_millis(250))
.solve()?;
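The display rules above can be expressed as a small decision function. This is a sketch under assumptions, not the crate's implementation: the helper names are hypothetical, and `NO_COLOR` is treated as set whenever the variable exists, following the wording above. `std::io::IsTerminal` (stable since Rust 1.70) answers the terminal question.

```rust
use std::io::IsTerminal;
use std::time::Duration;

/// Shown only on a terminal, only when NO_COLOR is unset, and only once
/// `after` has elapsed. `Duration::MAX` is never reached, so it disables the bar.
pub fn should_show_progress(elapsed: Duration, after: Duration, is_tty: bool, no_color: bool) -> bool {
    is_tty && !no_color && elapsed >= after
}

/// Reads the real environment: (stdout is a terminal, NO_COLOR is set).
pub fn environment_allows_progress() -> (bool, bool) {
    let is_tty = std::io::stdout().is_terminal();
    let no_color = std::env::var_os("NO_COLOR").is_some();
    (is_tty, no_color)
}

/// Throttles redraws so the bar refreshes at most once per 50ms.
pub fn should_redraw(since_last_draw: Duration) -> bool {
    since_last_draw >= Duration::from_millis(50)
}
```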
fn set_fixed_iters(self, fixed_iters: usize) -> impl Driver<Algo = Self::Algo>
where
    Self: Sized,
After this number of iterations, the algo is deemed to have converged on a solution.
If Driver::converge_when is used, iteration may terminate early.
Contrast with Driver::set_fail_after_iters.
fn set_fail_after_iters(
    self,
    max_iters: usize,
) -> impl Driver<Algo = Self::Algo>
where
    Self: Sized,
Unless convergence has occurred, the driver will return an error after the designated iteration count.
Either the algo has internal logic to terminate iteration when a solution is found, or the caller needs to use Driver::converge_when to ensure iteration stops before a max-iterations-exceeded error.
Contrast with Driver::set_fixed_iters.
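The difference between the two limits is what happens when the count runs out. The hypothetical `Limit` and `run` below sketch both policies side by side; the error text is illustrative and is not the crate's `DriverError`.

```rust
/// How a hypothetical driver treats running out of iterations.
#[derive(Clone, Copy)]
pub enum Limit {
    /// Like `set_fixed_iters`: reaching the limit counts as convergence.
    Fixed(usize),
    /// Like `set_fail_after_iters`: reaching the limit without converging is an error.
    FailAfter(usize),
}

/// Runs up to the limit, stopping early if `converged(i)` returns true.
/// Returns the number of iterations run, or an error on `FailAfter` exhaustion.
pub fn run(limit: Limit, mut converged: impl FnMut(usize) -> bool) -> Result<usize, String> {
    let max = match limit {
        Limit::Fixed(n) | Limit::FailAfter(n) => n,
    };
    for i in 1..=max {
        if converged(i) {
            return Ok(i);
        }
    }
    match limit {
        Limit::Fixed(n) => Ok(n),
        Limit::FailAfter(n) => Err(format!("max iterations ({n}) exceeded")),
    }
}
```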
fn set_timeout(self, timeout: Duration) -> impl Driver<Algo = Self::Algo>
where
    Self: Sized,
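Sets a time limit on iteration: unless convergence has occurred, the driver will return an error once the timeout elapses.
Either the algo has internal logic to terminate iteration when a solution is found, or the caller needs to use Driver::converge_when to ensure iteration stops before a timeout error (DriverError::Timeout).
Any previously set timeout is overwritten.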
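A sketch of the "last call wins" behaviour and of a wall-clock check using `std::time::Instant`. `Settings` is a hypothetical builder, not the crate's driver.

```rust
use std::time::{Duration, Instant};

/// Hypothetical driver settings holding at most one timeout.
#[derive(Default)]
pub struct Settings {
    timeout: Option<Duration>,
}

impl Settings {
    /// Replaces any previously configured timeout.
    pub fn set_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// True once a run started at `start` has exceeded the timeout (if any).
    pub fn timed_out(&self, start: Instant, now: Instant) -> bool {
        match self.timeout {
            Some(limit) => now.duration_since(start) > limit,
            None => false,
        }
    }
}
```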
fn checkpoint(
    self,
    path: impl AsRef<Path>,
    every: Duration,
) -> impl Driver<Algo = Self::Algo>
where
    Self: Sized,
Writes checkpoint files at the interval given by every, for possible later recovery.
Checkpoints are useful for very long-running iterations, where restarting from the last checkpoint after a network error or crash saves time.
The path can contain replaceable tokens. Note that if the path is generated by the format! macro, the braces will need to be escaped by doubling them ({{iter}}).
- {iter} : current iteration number, e.g. 0000745
- {elapsed} : current elapsed time in hours, minutes, seconds and milliseconds, e.g. 00h-34m-12s-034ms (see format_duration)
- {pid} : process ID of the running process
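The token substitution can be sketched with `str::replace`. This is a hypothetical helper: the 7-digit zero padding for `{iter}` is inferred from the `0000745` example, and `format_elapsed` only mimics the `00h-34m-12s-034ms` layout rather than calling `format_duration`. The test also shows the `format!` escaping rule.

```rust
use std::time::Duration;

/// Formats a duration like `00h-34m-12s-034ms`.
pub fn format_elapsed(d: Duration) -> String {
    let total_ms = d.as_millis();
    let (h, rem) = (total_ms / 3_600_000, total_ms % 3_600_000);
    let (m, rem) = (rem / 60_000, rem % 60_000);
    let (s, ms) = (rem / 1_000, rem % 1_000);
    format!("{h:02}h-{m:02}m-{s:02}s-{ms:03}ms")
}

/// Hypothetical expansion of the `{iter}`, `{elapsed}` and `{pid}` tokens.
pub fn expand_tokens(template: &str, iter: usize, elapsed: Duration, pid: u32) -> String {
    template
        .replace("{iter}", &format!("{iter:07}")) // zero-padded to 7 digits (assumed)
        .replace("{elapsed}", &format_elapsed(elapsed))
        .replace("{pid}", &pid.to_string())
}
```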
The file format is determined by the file extension of the supplied path.
There will typically be one file written for the algorithm, and a .control file for the Driver state. Both are needed for recovery.
The control file's location is the supplied path and file name, but with the extension set to .control.
Which file formats (if any) are supported is algorithm dependent.
Example:
let my_program_name = "algo-solver";
let every = Duration::from_secs(60);
let (solved, _) = fixed_iters(my_algo, 2)
.checkpoint(format!("/tmp/checkpoint-{my_program_name}-{{iter}}.json"), every)
.solve()?;
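The two path rules can be sketched with `std::path::Path`: `with_extension` swaps the extension for the control file, and `extension` drives format selection. The `Format` enum and its JSON/CSV variants are hypothetical, since the supported formats depend on the algorithm.

```rust
use std::path::{Path, PathBuf};

/// Checkpoint formats a hypothetical algorithm might support.
#[derive(Debug, PartialEq)]
pub enum Format {
    Json,
    Csv,
}

/// Picks the format from the path's extension, or None if unsupported.
pub fn format_for(path: &Path) -> Option<Format> {
    match path.extension().and_then(|e| e.to_str()) {
        Some("json") => Some(Format::Json),
        Some("csv") => Some(Format::Csv),
        _ => None,
    }
}

/// The companion driver-state file: same path, extension replaced by `control`.
pub fn control_path(checkpoint: &Path) -> PathBuf {
    checkpoint.with_extension("control")
}
```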
fn recovery(
    self,
    path: impl AsRef<Path>,
    action: RecoveryFile,
) -> impl Driver<Algo = Self::Algo>
where
    Self: Sized,
For recovery from checkpoint files. Specifies the location of the files used to recover the algorithm and driver state.
Typically you don't want to always recover from the last checkpoint file, only in cases where the iterative process was terminated. This entails having a recovery file location different from the checkpoint location, and some manual process to move or rename files from where they are written to the recovery path location. See RecoveryFile.
There will typically be a file for the algorithm, and a .control file for the Driver state. Both are needed.
Note that metrics outside the driver, such as moving averages, will be re-initialized, even though the algorithm is recovered.
Which file formats (if any) are supported is algorithm dependent.
Example:
let every = Duration::from_secs(60);
let action = if command_line == "recover" {
RecoveryFile::Require
} else {
RecoveryFile::Ignore
};
let (solved, _) = fixed_iters(my_algo, 2)
.recovery("/tmp/checkpoint-to-use.json", action)
.checkpoint("/tmp/checkpoint-{iter}.json", every)
.solve()?;
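The manual "move or rename" step can be a small copy routine. This hypothetical helper copies a chosen checkpoint and its `.control` companion to the recovery location, assuming the control file sits beside the checkpoint with its extension replaced, as described for checkpoint.

```rust
use std::fs;
use std::io;
use std::path::Path;

/// Copies `checkpoint` to `recovery`, and `checkpoint.control` to `recovery.control`.
/// Both files are copied because both are needed to recover.
pub fn promote_checkpoint(checkpoint: &Path, recovery: &Path) -> io::Result<()> {
    fs::copy(checkpoint, recovery)?;
    fs::copy(
        checkpoint.with_extension("control"),
        recovery.with_extension("control"),
    )?;
    Ok(())
}
```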
fn into_boxed(self) -> BoxedDriver<Self::Algo>
where
    Self: Sized + 'static,
Converts the Driver into a sized type (BoxedDriver<A>) that remains generic over the algorithm.
Examples
The following won't compile, as the if/else arms have different types:
let driver = if log_enabled {
fixed_iters(my_algo, 100).on_step(|_, step| println!("{step:?}"))
} else {
fixed_iters(my_algo, 100)
};
Boxing each arm into a BoxedDriver gets around this:
let log_enabled = true;
let driver = if log_enabled {
fixed_iters(my_algo, 100).on_step(|_, step| println!("{step:?}")).into_boxed()
} else {
fixed_iters(my_algo, 100).into_boxed()
};
// driver is type BoxedDriver<MyAlgo>
let (solved, _step) = driver.show_progress_bar_after(Duration::ZERO).solve()?;
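The type mismatch and its fix can be reproduced without this crate. In this hypothetical miniature, `MiniDriver` stands in for `Driver`, and `Box<dyn MiniDriver<Algo = u64>>` plays the role of `BoxedDriver<A>`: both if/else arms coerce to the same boxed trait object, while the algorithm type stays visible as an associated type.

```rust
/// Hypothetical driver trait, generic over its algorithm via an associated type.
pub trait MiniDriver {
    type Algo;
    /// Takes `self: Box<Self>` so the method is callable on a boxed trait object.
    fn solve(self: Box<Self>) -> Self::Algo;
}

/// Runs a counter to a fixed number of iterations.
pub struct Plain {
    pub iters: u64,
}

/// Same as `Plain`, but reports each step through a closure.
pub struct Logged<F: FnMut(u64)> {
    pub iters: u64,
    pub log: F,
}

impl MiniDriver for Plain {
    type Algo = u64;
    fn solve(self: Box<Self>) -> u64 {
        self.iters
    }
}

impl<F: FnMut(u64)> MiniDriver for Logged<F> {
    type Algo = u64;
    fn solve(mut self: Box<Self>) -> u64 {
        for i in 1..=self.iters {
            (self.log)(i);
        }
        self.iters
    }
}

/// Each arm has a different concrete type, but both coerce to the boxed trait object.
pub fn choose(log_enabled: bool, iters: u64) -> Box<dyn MiniDriver<Algo = u64>> {
    if log_enabled {
        Box::new(Logged { iters, log: |i: u64| println!("step {i}") })
    } else {
        Box::new(Plain { iters })
    }
}
```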
fn into_dyn(self) -> DynDriver
Converts the Driver into a sized type without generics (DynDriver).
The preferred way for common routines to accept a Driver as a parameter is impl Driver.
But occasionally a concrete type is required.
Whereas BoxedDriver<A> is generic over the algorithm, DynDriver is a concrete type and can be stored in non-generic structs.
To call methods on the resulting DynAlgo, you will need to downcast to the specific algorithm used, via DynAlgo::downcast_ref or DynAlgo::downcast_mut.
Since closures cannot be named, if the algorithm is generic over closures or functions, you will need a nameable type: either Box the closure, or coerce a non-capturing closure into a function pointer, as in the example below. See the examples folder.
Example:
// note we need a concrete *function* to name and downcast to
type GradFn = fn(&[f64]) -> Vec<f64>;
let lr = 0.1;
let my_closure = |x: &[f64]| vec![2.0 * sphere_grad(x)[0], sphere_grad(x)[1]];
// coerce the closure into a function pointer (only possible because it captures nothing)
let grad: GradFn = my_closure;
let gd = GradientDescent::new(lr, vec![5.55, 5.55], grad);
let (dyn_solved, _step) = fixed_iters(gd.clone(), 500)
.show_progress_bar_after(Duration::ZERO)
.into_dyn()
.solve()
.unwrap();
let gd_algo: &GradientDescent<GradFn> =
dyn_solved.downcast_ref().expect("failed to downcast_ref");
let solution = gd_algo.x();
assert_approx_eq!(solution, &[0.0, 0.0]);
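Downcasting works because the solved algorithm sits behind a type-erased handle. As a sketch of the same mechanics using only `std::any::Any` (the crate's DynAlgo is assumed to behave similarly), `Holder` is a hypothetical non-generic struct and `Gd` a stand-in for GradientDescent. `downcast_ref` succeeds only when the exact concrete type, including the fn-pointer parameter, is named.

```rust
use std::any::Any;

/// A nameable function-pointer type, so `Gd<GradFn>` can be named when downcasting.
pub type GradFn = fn(&[f64]) -> Vec<f64>;

/// Hypothetical stand-in for an algorithm generic over its gradient function.
pub struct Gd<G> {
    pub x: Vec<f64>,
    pub grad: G,
}

/// A non-generic struct that can hold any solved algorithm.
pub struct Holder {
    pub solved: Box<dyn Any>,
}

/// Gradient of the sphere function sum(x_i^2) (hypothetical helper).
pub fn sphere_grad(x: &[f64]) -> Vec<f64> {
    x.iter().map(|v| 2.0 * v).collect()
}
```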