Trait Driver

pub trait Driver {
    type Algo;

    // Required methods
    fn solve(self) -> Result<(Self::Algo, Step), DriverError>;
    fn iter_step(
        &mut self,
    ) -> Result<Option<(&mut Self::Algo, &Step)>, DriverError>;
    fn on_step<G>(self, g: G) -> impl Driver<Algo = Self::Algo>
       where Self: Sized,
             G: FnMut(&mut Self::Algo, &Step);
    fn try_on_step<F, E>(self, f: F) -> impl Driver<Algo = Self::Algo>
       where Self: Sized,
             F: FnMut(&mut Self::Algo, &Step) -> Result<(), E>,
             E: Error + Sync + Send + 'static;
    fn converge_when<F>(self, pred: F) -> impl Driver<Algo = Self::Algo>
       where Self: Sized,
             F: FnMut(&mut Self::Algo, &Step) -> bool;
    fn fail_if<F>(self, pred: F) -> impl Driver<Algo = Self::Algo>
       where Self: Sized,
             F: FnMut(&mut Self::Algo, &Step) -> bool;
    fn show_progress_bar_after(
        self,
        after: Duration,
    ) -> impl Driver<Algo = Self::Algo>
       where Self: Sized;
    fn set_fixed_iters(
        self,
        fixed_iters: usize,
    ) -> impl Driver<Algo = Self::Algo>
       where Self: Sized;
    fn set_fail_after_iters(
        self,
        max_iters: usize,
    ) -> impl Driver<Algo = Self::Algo>
       where Self: Sized;
    fn set_timeout(self, timeout: Duration) -> impl Driver<Algo = Self::Algo>
       where Self: Sized;
    fn checkpoint(
        self,
        path: impl AsRef<Path>,
        every: Duration,
    ) -> impl Driver<Algo = Self::Algo>
       where Self: Sized;
    fn recovery(
        self,
        path: impl AsRef<Path>,
        action: RecoveryFile,
    ) -> impl Driver<Algo = Self::Algo>
       where Self: Sized;
    fn into_boxed(self) -> BoxedDriver<Self::Algo>
       where Self: Sized + 'static;
    fn into_dyn(self) -> DynDriver
       where Self: Sized + 'static,
             Self::Algo: Algo;
}

An executor for an iterative algorithm, controlling maximum iterations, execution time, convergence criteria, and callbacks.

Required Associated Types


type Algo

The underlying algorithm being driven.

Required Methods


fn solve(self) -> Result<(Self::Algo, Step), DriverError>

Runs the algorithm until failure, iteration exhaustion, or convergence.

On success, the algo in its final state and the final iteration Step are returned.

By convention, the variable holding the algo in its final state is named solved; it offers accessors to retrieve the solution, or the best approximation to it.

Example
let (solved, step) = fixed_iters(my_algo, 1000)
    .solve()?;

assert_approx_eq!(solved.x(), &vec![1.5, 2.0]);
assert_eq!(step.iteration(), 1000);

fn iter_step(&mut self) -> Result<Option<(&mut Self::Algo, &Step)>, DriverError>
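The signature of iter_step suggests a pull-style alternative to solve: each call appears to run one iteration and hand back the algo together with its Step, with Ok(None) presumably signalling that iteration has finished. The sketch below models that loop under those assumptions; Step, Halver, MockDriver and run are hypothetical stand-ins, not the crate's types.

```rust
// Hypothetical stand-ins: Step, Halver and MockDriver are illustrative only.
#[derive(Debug)]
struct Step {
    iteration: usize,
}

/// A toy algorithm that halves `x` on every iteration.
struct Halver {
    x: f64,
}

/// A toy driver that runs a fixed number of iterations.
struct MockDriver {
    algo: Halver,
    step: Step,
    max_iters: usize,
}

impl MockDriver {
    /// Same shape as Driver::iter_step: Ok(Some(..)) after each iteration,
    /// and (by assumption) Ok(None) once there is nothing left to do.
    fn iter_step(&mut self) -> Result<Option<(&mut Halver, &Step)>, String> {
        if self.step.iteration >= self.max_iters {
            return Ok(None);
        }
        self.algo.x /= 2.0;
        self.step.iteration += 1;
        Ok(Some((&mut self.algo, &self.step)))
    }
}

/// Drives the toy algorithm by hand, printing each step.
fn run(max_iters: usize) -> Result<(f64, usize), String> {
    let mut driver = MockDriver {
        algo: Halver { x: 8.0 },
        step: Step { iteration: 0 },
        max_iters,
    };
    // The caller owns the loop, so it can inspect, log or break at any point.
    while let Some((algo, step)) = driver.iter_step()? {
        println!("{step:?} x = {}", algo.x);
    }
    Ok((driver.algo.x, driver.step.iteration))
}
```

Because the caller owns the loop, this style suits interleaving iteration with other work, at the cost of handling termination explicitly.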


fn on_step<G>(self, g: G) -> impl Driver<Algo = Self::Algo>
where Self: Sized, G: FnMut(&mut Self::Algo, &Step),

Invoked after each iteration, allowing printing or debugging of the iteration Step. Any error encountered by the underlying algorithm terminates solving, and Self::solve returns that error.

It can also be used to capture details of the iteration, or to update hyperparameters on the algorithm for adaptive learning.

Example
let (solved, _step) = fixed_iters(my_algo, 1000)
    .on_step(|_algo, step| println!("{step:?}"))
    .solve()?;
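As a sketch of the adaptive-learning use mentioned above, the self-contained model below halves a learning rate every 10 iterations from inside an after-step callback. Gd, Step, drive and solve_with_decay are hypothetical; only the closure shape, FnMut(&mut Algo, &Step), mirrors what on_step accepts.

```rust
/// Hypothetical stand-in for the crate's Step.
struct Step {
    iteration: usize,
}

/// Hypothetical gradient descent on f(x) = x^2, whose gradient is 2x.
struct Gd {
    x: f64,
    lr: f64,
}

impl Gd {
    fn step(&mut self) {
        self.x -= self.lr * 2.0 * self.x;
    }
}

/// Hypothetical fixed-iteration driver that calls `on_step`
/// after every iteration, as Driver::on_step is described to do.
fn drive(mut algo: Gd, iters: usize, mut on_step: impl FnMut(&mut Gd, &Step)) -> Gd {
    for i in 1..=iters {
        algo.step();
        on_step(&mut algo, &Step { iteration: i });
    }
    algo
}

/// Adaptive schedule: the callback halves the learning rate every 10 iterations.
fn solve_with_decay(iters: usize) -> Gd {
    drive(Gd { x: 1.0, lr: 0.25 }, iters, |algo, step| {
        if step.iteration % 10 == 0 {
            algo.lr /= 2.0;
        }
    })
}
```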

fn try_on_step<F, E>(self, f: F) -> impl Driver<Algo = Self::Algo>
where Self: Sized, F: FnMut(&mut Self::Algo, &Step) -> Result<(), E>, E: Error + Sync + Send + 'static,

Invoked after each iteration, allowing printing or debugging of the iteration Step, or capturing details of the iteration. This is the fallible counterpart of Self::on_step: the closure returns Result<(), E>. See Self::on_step.
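A sketch of the fallible pattern, assuming (from the Result return type) that the first Err returned by the closure ends iteration and is surfaced to the caller. Step, Diverged, guard and drive are hypothetical; the error type satisfies the documented bound Error + Send + Sync + 'static, which is what allows it to be converted into a boxed error with `?`.

```rust
use std::error::Error;
use std::fmt;

// Hypothetical stand-ins; not the crate's types.
struct Step {
    iteration: usize,
}

/// A custom error meeting the bounds try_on_step requires.
#[derive(Debug)]
struct Diverged {
    iteration: usize,
}

impl fmt::Display for Diverged {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "diverged at iteration {}", self.iteration)
    }
}

impl Error for Diverged {}

/// Toy driver: doubles `x` each iteration, then calls the fallible callback.
/// Assumption: the first Err ends the run and is returned to the caller.
fn drive<F, E>(mut x: f64, iters: usize, mut f: F) -> Result<f64, Box<dyn Error + Send + Sync>>
where
    F: FnMut(&mut f64, &Step) -> Result<(), E>,
    E: Error + Send + Sync + 'static,
{
    for i in 1..=iters {
        x *= 2.0;
        f(&mut x, &Step { iteration: i })?;
    }
    Ok(x)
}

/// Example callback: treat any value above 100 as divergence.
fn guard(x: &mut f64, step: &Step) -> Result<(), Diverged> {
    if *x > 100.0 {
        Err(Diverged { iteration: step.iteration })
    } else {
        Ok(())
    }
}
```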


fn converge_when<F>(self, pred: F) -> impl Driver<Algo = Self::Algo>
where Self: Sized, F: FnMut(&mut Self::Algo, &Step) -> bool,

Used for early stopping: iteration ends successfully as soon as the predicate returns true.

Common convergence predicates are the step size falling below a small epsilon, or residuals falling to near zero. Some specific convergence criteria are best handled by using metrics within this callback.
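The step-size criterion can be modelled with a toy fixed-point iteration, x <- cos(x), that records how far x moved on each step. FixedPoint, drive and solve_cos are hypothetical stand-ins; the predicate closure in solve_cos is the part that would correspond to a converge_when argument.

```rust
/// Hypothetical algorithm: fixed-point iteration x <- cos(x), recording the step size.
struct FixedPoint {
    x: f64,
    last_delta: f64,
}

impl FixedPoint {
    fn step(&mut self) {
        let next = self.x.cos();
        self.last_delta = (next - self.x).abs();
        self.x = next;
    }
}

/// Hypothetical driver: returns early, as converged, the first time `pred`
/// is true; otherwise stops after `max_iters`.
fn drive(
    mut algo: FixedPoint,
    max_iters: usize,
    mut pred: impl FnMut(&mut FixedPoint, usize) -> bool,
) -> (FixedPoint, usize) {
    for i in 1..=max_iters {
        algo.step();
        if pred(&mut algo, i) {
            return (algo, i);
        }
    }
    (algo, max_iters)
}

/// Convergence predicate: step size below a small epsilon.
fn solve_cos(eps: f64) -> (FixedPoint, usize) {
    let start = FixedPoint { x: 1.0, last_delta: f64::INFINITY };
    drive(start, 1_000, |algo, _iter| algo.last_delta < eps)
}
```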


fn fail_if<F>(self, pred: F) -> impl Driver<Algo = Self::Algo>
where Self: Sized, F: FnMut(&mut Self::Algo, &Step) -> bool,

Decides when to abandon iteration, with Self::solve returning DriverError::FailIfPredicate.

Convergence predicates are tested first.

Common failure predicates are residuals growing in size after an initial set of iterations, or user cancellation, perhaps implemented by the predicate closure checking the value of an AtomicBool.
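The ordering rule and the AtomicBool cancellation idea can be sketched with a toy loop. MockError, drive and run are hypothetical; the sketch assumes both predicates are checked once per iteration, convergence first, as described above.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for DriverError::FailIfPredicate.
#[derive(Debug, PartialEq)]
enum MockError {
    FailIfPredicate { iteration: usize },
}

/// Hypothetical driver loop: after each iteration the convergence
/// predicate is tested first, then the failure predicate.
fn drive(
    iters: usize,
    mut converged: impl FnMut(usize) -> bool,
    mut fail_if: impl FnMut(usize) -> bool,
) -> Result<usize, MockError> {
    for i in 1..=iters {
        // ... one iteration of the algorithm would run here ...
        if converged(i) {
            return Ok(i);
        }
        if fail_if(i) {
            return Err(MockError::FailIfPredicate { iteration: i });
        }
    }
    Ok(iters)
}

/// User cancellation: the predicate polls a shared AtomicBool that another
/// thread (a UI or signal handler, say) could set at any time.
fn run(cancel: Arc<AtomicBool>) -> Result<usize, MockError> {
    drive(100, |i| i == 50, move |_| cancel.load(Ordering::Relaxed))
}
```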


fn show_progress_bar_after( self, after: Duration, ) -> impl Driver<Algo = Self::Algo>
where Self: Sized,

Display a progress bar to the terminal after the algo has been running for this duration.

Use Duration::MAX to disable the progress bar, and Duration::ZERO to display it immediately. The progress bar will refresh at most every 50ms.

The progress bar will not display if any of the following hold:

  • stdout is not a terminal (perhaps redirected to a utility or file)
  • the environment variable NO_COLOR is set
  • the elapsed time never reaches the duration set by this function
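The display rules above reduce to a small predicate. This sketch is not the crate's implementation: it uses std::io::IsTerminal (stable since Rust 1.70) for the terminal check and treats NO_COLOR as set whenever the variable exists, as the text describes. should_show, should_show_now, should_redraw and REFRESH are hypothetical names.

```rust
use std::io::IsTerminal;
use std::time::Duration;

/// Minimum interval between redraws, per the documented 50ms refresh limit.
const REFRESH: Duration = Duration::from_millis(50);

/// Pure decision logic mirroring the documented rules.
fn should_show(stdout_is_tty: bool, no_color_set: bool, elapsed: Duration, after: Duration) -> bool {
    // Duration::MAX cannot realistically be reached, so it disables the bar;
    // Duration::ZERO is reached immediately.
    stdout_is_tty && !no_color_set && elapsed >= after
}

/// Reads the real terminal and environment state of the current process.
fn should_show_now(elapsed: Duration, after: Duration) -> bool {
    let tty = std::io::stdout().is_terminal();
    let no_color = std::env::var_os("NO_COLOR").is_some();
    should_show(tty, no_color, elapsed, after)
}

/// Throttle: redraw only if at least REFRESH has passed since the last draw.
fn should_redraw(since_last_draw: Duration) -> bool {
    since_last_draw >= REFRESH
}
```

Keeping the decision in a pure function makes the rules testable without a real terminal.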
Example
let (solved, _step) = fixed_iters(my_algo, 1000)
    // show a progress bar if the algo is slow to execute
    .show_progress_bar_after(Duration::from_millis(250))
    .solve()?;

fn set_fixed_iters(self, fixed_iters: usize) -> impl Driver<Algo = Self::Algo>
where Self: Sized,

After this number of iterations, the algo is deemed to have converged on a solution. If Driver::converge_when is used, iteration may terminate earlier.

Contrast with Driver::set_fail_after_iters.


fn set_fail_after_iters( self, max_iters: usize, ) -> impl Driver<Algo = Self::Algo>
where Self: Sized,

Unless convergence has occurred, the driver will return an error after the designated iteration count. Either the algo has internal logic to terminate iteration when a solution is found, or the caller needs to use Driver::converge_when to ensure iteration stops before a max-iterations-exceeded error.

Contrast with Driver::set_fixed_iters.
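The contrast between the two limits can be sketched with a toy driver loop in which a converge_when-style predicate may stop iteration early. Limit, MaxItersExceeded and drive are hypothetical; only the success/error semantics follow the descriptions above.

```rust
/// Hypothetical stand-in for the max-iterations-exceeded error.
#[derive(Debug, PartialEq)]
struct MaxItersExceeded;

/// The two iteration limits, modelled as a toy enum.
#[derive(Clone, Copy)]
enum Limit {
    FixedIters(usize),
    FailAfterIters(usize),
}

/// Toy driver: `converged` plays the role of a converge_when predicate.
fn drive(limit: Limit, mut converged: impl FnMut(usize) -> bool) -> Result<usize, MaxItersExceeded> {
    let n = match limit {
        Limit::FixedIters(n) | Limit::FailAfterIters(n) => n,
    };
    for i in 1..=n {
        if converged(i) {
            return Ok(i); // early stop is a success under either limit
        }
    }
    match limit {
        // Running out of iterations counts as convergence.
        Limit::FixedIters(n) => Ok(n),
        // Running out of iterations without converging is an error.
        Limit::FailAfterIters(_) => Err(MaxItersExceeded),
    }
}
```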


fn set_timeout(self, timeout: Duration) -> impl Driver<Algo = Self::Algo>
where Self: Sized,

Sets a time limit on solving. Either the algo has internal logic to terminate iteration when a solution is found, or the caller needs to use Driver::converge_when to ensure iteration stops before a timeout error (DriverError::Timeout).

Any previously set timeout is overwritten.
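A minimal sketch of the timeout behaviour, assuming the driver checks elapsed wall-clock time after each iteration; MockDriver and MockError are hypothetical. As documented, calling set_timeout twice keeps only the second value.

```rust
use std::time::{Duration, Instant};

/// Hypothetical stand-in for DriverError::Timeout.
#[derive(Debug, PartialEq)]
enum MockError {
    Timeout,
}

/// Hypothetical driver configuration holding at most one timeout.
#[derive(Default)]
struct MockDriver {
    timeout: Option<Duration>,
}

impl MockDriver {
    /// A later call replaces any earlier timeout.
    fn set_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Runs `step` until it reports convergence or the timeout elapses.
    fn solve(self, mut step: impl FnMut(usize) -> bool) -> Result<usize, MockError> {
        let start = Instant::now();
        let mut i = 0;
        loop {
            i += 1;
            if step(i) {
                return Ok(i);
            }
            if let Some(t) = self.timeout {
                if start.elapsed() >= t {
                    return Err(MockError::Timeout);
                }
            }
        }
    }
}
```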


fn checkpoint( self, path: impl AsRef<Path>, every: Duration, ) -> impl Driver<Algo = Self::Algo>
where Self: Sized,

Writes checkpoint files at the interval given by every, for possible later recovery.

Checkpoints are useful for very long-running iterations, where restarting from the last checkpoint after a network error or crash saves time.

The path can contain the replaceable tokens below. Note that if the path is generated by the format! macro, the braces must be escaped by doubling them, as in {{iter}}.

  • {iter} : current iteration number, e.g. 0000745
  • {elapsed} : current elapsed time in hours, minutes, seconds and milliseconds, e.g. 00h-34m-12s-034ms (see format_duration)
  • {pid} : process ID of the running process

The file format is determined by the file extension of the supplied path.

There will typically be a file written for the algorithm, and a .control file for the Driver state. Both are needed for recovery. The control file’s location is the path and file name supplied, but with the extension set to ‘.control’.

Which file formats (if any) are supported is algorithm dependent.
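The token expansion and the .control naming rule can be sketched with the standard library alone. This is a hypothetical reimplementation, not the crate's code: the 7-digit zero padding of {iter} is inferred from the 0000745 example, and format_elapsed only imitates the documented 00h-34m-12s-034ms layout.

```rust
use std::path::{Path, PathBuf};
use std::time::Duration;

/// Formats a duration in the documented {elapsed} style, e.g. 00h-34m-12s-034ms.
fn format_elapsed(d: Duration) -> String {
    let ms = d.as_millis();
    let (h, m, s, milli) = (ms / 3_600_000, ms / 60_000 % 60, ms / 1000 % 60, ms % 1000);
    format!("{h:02}h-{m:02}m-{s:02}s-{milli:03}ms")
}

/// Expands the {iter}, {elapsed} and {pid} tokens in a checkpoint path template.
fn expand_tokens(template: &str, iter: usize, elapsed: Duration) -> PathBuf {
    let s = template
        .replace("{iter}", &format!("{iter:07}"))
        .replace("{elapsed}", &format_elapsed(elapsed))
        .replace("{pid}", &std::process::id().to_string());
    PathBuf::from(s)
}

/// The companion control file: same path, with the extension set to "control".
fn control_path(checkpoint: &Path) -> PathBuf {
    checkpoint.with_extension("control")
}
```

Note that a template built with format! must double its braces, so format!("...-{{iter}}.json") yields the literal token {iter} for expand_tokens to replace.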

Example
let my_program_name = "algo-solver";
let every = Duration::from_secs(60);
let (solved, _) = fixed_iters(my_algo, 2)
    .checkpoint(format!("/tmp/checkpoint-{my_program_name}-{{iter}}.json"), every)
    .solve()?;

fn recovery( self, path: impl AsRef<Path>, action: RecoveryFile, ) -> impl Driver<Algo = Self::Algo>
where Self: Sized,

For recovery from checkpoint files. Specifies the location of files used to recover algorithm and driver state.

Typically you don't want to always recover from the last checkpoint file, only when the iterative process was terminated. This entails using a recovery file location different from the checkpoint location, plus some manual process to move or rename files from where they are written to the recovery path. See RecoveryFile.

There will typically be a file for the algorithm, and a “.control” file for the Driver state. Both are needed.

Note that metrics outside the driver, such as moving averages, will be re-initialized, even though the algorithm is recovered.

Which file formats (if any) are supported is algorithm dependent.
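The manual step of moving a checkpoint to the recovery location has to carry both files. A minimal sketch using std::fs, assuming the .control file sits beside the algorithm file with the same stem, as described for checkpoint; stage_for_recovery is a hypothetical helper, not part of the crate.

```rust
use std::fs;
use std::io;
use std::path::Path;

/// Copies a checkpoint and its companion .control file to the recovery
/// location, so the algorithm state and driver state travel together.
fn stage_for_recovery(checkpoint: &Path, recovery: &Path) -> io::Result<()> {
    fs::copy(checkpoint, recovery)?;
    fs::copy(
        checkpoint.with_extension("control"),
        recovery.with_extension("control"),
    )?;
    Ok(())
}
```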

Example
let every = Duration::from_secs(60);

let action = if command_line == "recover" {
    RecoveryFile::Require
} else {
    RecoveryFile::Ignore
};
let (solved, _) = fixed_iters(my_algo, 2)
    .recovery("/tmp/checkpoint-to-use.json", action)
    .checkpoint(format!("/tmp/checkpoint-{{iter}}.json"), every)
    .solve()?;

fn into_boxed(self) -> BoxedDriver<Self::Algo>
where Self: Sized + 'static,

Converts the Driver into a sized type (BoxedDriver<A>) that remains generic over the algorithm.

Examples

The following won't compile, as the if/else arms have different types:

let driver = if log_enabled {
    fixed_iters(my_algo, 100).on_step(|_, step| println!("{step:?}"))
} else {
    fixed_iters(my_algo, 100)
};

Boxing each arm into a trait object gets around this:

let log_enabled = true;
let driver = if log_enabled {
    fixed_iters(my_algo, 100).on_step(|_, step| println!("{step:?}")).into_boxed()
} else {
    fixed_iters(my_algo, 100).into_boxed()
};
// driver is type BoxedDriver<MyAlgo>
let (my_algo, _step) = driver.show_progress_bar_after(Duration::ZERO).solve()?;
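Why boxing unifies the two arms can be shown with a self-contained toy: each combinator changes the static type, while a boxed wrapper gives every driver one nameable type. Driver, SolveBoxed, Boxed, Fixed, Logged and pick here are hypothetical miniatures, not the crate's definitions; the 'static bound mirrors the one on into_boxed, because a Box<dyn Trait> defaults to 'static.

```rust
/// Hypothetical minimal driver trait; `solve` consumes the driver.
trait Driver {
    fn solve(self) -> usize;

    /// Erases the concrete type behind a Box (needs 'static, like into_boxed).
    fn into_boxed(self) -> Boxed
    where
        Self: Sized + 'static,
    {
        Boxed(Box::new(self))
    }
}

/// Object-safe helper: a by-value `solve` callable through Box<dyn ..>.
trait SolveBoxed {
    fn solve_boxed(self: Box<Self>) -> usize;
}

impl<D: Driver> SolveBoxed for D {
    fn solve_boxed(self: Box<Self>) -> usize {
        (*self).solve()
    }
}

/// One nameable type for every driver, loosely analogous to BoxedDriver<A>.
struct Boxed(Box<dyn SolveBoxed>);

impl Driver for Boxed {
    fn solve(self) -> usize {
        self.0.solve_boxed()
    }
}

/// Toy driver that "runs" a fixed number of iterations.
struct Fixed {
    iters: usize,
}

impl Driver for Fixed {
    fn solve(self) -> usize {
        self.iters
    }
}

/// Toy decorator in the spirit of on_step: wrapping changes the type.
struct Logged<D> {
    inner: D,
}

impl<D: Driver> Driver for Logged<D> {
    fn solve(self) -> usize {
        let n = self.inner.solve();
        println!("solved after {n} iterations");
        n
    }
}

/// Both arms now have the same type, so the if/else compiles.
fn pick(log_enabled: bool) -> Boxed {
    let base = Fixed { iters: 100 };
    if log_enabled {
        Logged { inner: base }.into_boxed()
    } else {
        base.into_boxed()
    }
}
```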

fn into_dyn(self) -> DynDriver
where Self: Sized + 'static, Self::Algo: Algo,

Converts the Driver into a sized type without generics (DynDriver).

The preferred way for common routines to accept a Driver as a parameter is impl Driver. Occasionally, however, a concrete type is required.

Whereas BoxedDriver<A> is generic over the algorithm, DynDriver is a concrete type and can be stored in non-generic structs. To call methods on DynAlgo, you will need to downcast to the specific algorithm type used, via DynAlgo::downcast_ref or DynAlgo::downcast_mut.

Since closures cannot be named, if the algorithm is generic over closures or functions you will need to box the closure or, if it captures nothing from its environment, coerce it to a function pointer, as the example below does. See the examples folder.

Example
// note we need a concrete *function* to name and downcast to
type GradFn = fn(&[f64]) -> Vec<f64>;

let lr = 0.1;
let my_closure = |x: &[f64]| vec![2.0 * sphere_grad(x)[0], sphere_grad(x)[1]];

// coerce the closure into a function (must not capture from environment)
let grad: GradFn = my_closure;

let gd = GradientDescent::new(lr, vec![5.55, 5.55], grad);

let (dyn_solved, _step) = fixed_iters(gd.clone(), 500)
    .show_progress_bar_after(Duration::ZERO)
    .into_dyn()
    .solve()
    .unwrap();

let gd_algo: &GradientDescent<GradFn> =
    dyn_solved.downcast_ref().expect("failed to downcast_ref");

let solution = gd_algo.x();
assert_approx_eq!(solution, &[0.0, 0.0]);
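The downcasting step is analogous to std::any::Any: a type-erased value can only be recovered by naming its exact concrete type, which is why the example coerces the closure to a nameable fn pointer type. A self-contained sketch under that analogy, with a hypothetical GradientDescent and a Box<dyn Any> standing in for DynAlgo:

```rust
use std::any::Any;

/// Hypothetical algorithm, generic over its gradient function.
struct GradientDescent<G> {
    lr: f64,
    x: Vec<f64>,
    grad: G,
}

impl<G: Fn(&[f64]) -> Vec<f64>> GradientDescent<G> {
    fn step(&mut self) {
        let g = (self.grad)(&self.x);
        for (xi, gi) in self.x.iter_mut().zip(g) {
            *xi -= self.lr * gi;
        }
    }
}

/// A nameable type: closures have no nameable type, fn pointers do.
type GradFn = fn(&[f64]) -> Vec<f64>;

/// Runs the algorithm, then erases its type (standing in for DynAlgo).
fn solve_erased(iters: usize) -> Box<dyn Any> {
    // Non-capturing closure coerced to a fn pointer: gradient of sum of x_i^2.
    let grad: GradFn = |x: &[f64]| -> Vec<f64> { x.iter().map(|&xi| 2.0 * xi).collect() };
    let mut gd = GradientDescent { lr: 0.1, x: vec![5.55, 5.55], grad };
    for _ in 0..iters {
        gd.step();
    }
    Box::new(gd)
}
```

Downcasting with `solved.downcast_ref::<GradientDescent<GradFn>>()` succeeds only because GradFn names the exact type; a closure type could not be written there.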

Implementors