simple_optimization/
simulated_annealing.rs

1use itertools::izip;
2use rand::{distributions::uniform::SampleUniform, thread_rng, Rng};
3use rand_distr::{Distribution, Normal};
4
5use std::{
6    convert::TryInto,
7    f64,
8    ops::{Range, Sub},
9    sync::{
10        atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering},
11        Arc, Mutex,
12    },
13    thread,
14    time::{Duration, Instant},
15};
16
17use crate::util::{poll, update_execution_position, Polling};
18
/// Cooling schedule for simulated annealing.
///
/// In the formulas below $ t_1 $ is the starting temperature and $ t_n $ is the $ n $-th
/// temperature sampled at.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CoolingSchedule {
    /// $ t_n = t_{n-1} \cdot \frac{\ln{2}}{\ln{(n+1)}} $
    Logarithmic,
    /// $ t_n = x \cdot t_{n-1} $
    Exponential(f64),
    /// $ t_n = \frac{t_1}{n} $
    Fast,
}
impl CoolingSchedule {
    /// Returns the temperature to use at step `step`.
    ///
    /// - `t_start`: The starting temperature (only used by [`CoolingSchedule::Fast`]).
    /// - `t_current`: The temperature of the previous step (used by
    ///   [`CoolingSchedule::Logarithmic`] and [`CoolingSchedule::Exponential`]).
    /// - `step`: The 1-based index of the step the returned temperature belongs to.
    fn decay(&self, t_start: f64, t_current: f64, step: u64) -> f64 {
        match self {
            Self::Logarithmic => t_current * (2f64).ln() / ((step + 1) as f64).ln(),
            Self::Exponential(x) => x * t_current,
            Self::Fast => t_start / step as f64,
        }
    }
    /// Given temperature start and temperature min, gives number of steps of decay which will occur
    ///  before temperature start decays to be less than temperature min, then causing program to exit.
    ///
    /// For [`CoolingSchedule::Logarithmic`] the count is obtained by replaying [`Self::decay`],
    ///  since the recurrence it applies has no simple closed form. `u64::MAX` is returned when
    ///  the temperature would never drop below `t_min`.
    ///
    /// The float to integer casts used by the other schedules saturate, so a non-finite or
    ///  out of range estimate ends up as `0` or `u64::MAX` rather than wrapping.
    fn steps(&self, t_start: f64, t_min: f64) -> u64 {
        match self {
            Self::Logarithmic => {
                // Nothing is sampled when the start is already below the minimum (or when either
                //  value is `NaN`, as every comparison with `NaN` is false).
                if t_start.is_nan() || t_min.is_nan() || t_start < t_min {
                    return 0;
                }
                // Repeated scaling by a positive factor never crosses zero and never makes an
                //  infinite temperature finite, in these cases cooling never finishes.
                if t_min <= 0. || t_start.is_infinite() {
                    return u64::MAX;
                }
                // Replays the schedule: the first temperature is step `1` and every following
                //  temperature is produced by `decay` with the incremented step.
                let mut temperature = t_start;
                let mut step = 1u64;
                let mut count = 0u64;
                while temperature >= t_min {
                    count += 1;
                    step += 1;
                    temperature = self.decay(t_start, temperature, step);
                }
                count
            }
            Self::Exponential(x) => ((t_min / t_start).log(*x)).ceil() as u64,
            Self::Fast => (t_start / t_min).ceil() as u64,
        }
    }
}
47
/// Casts all given ranges to `f64` values and calls [`simulated_annealing()`].
///
/// Each range expression is evaluated exactly once. A trailing comma is accepted (but not
/// required) both after the last range and after the last argument.
/// ```
/// use std::sync::Arc;
/// use simple_optimization::{simulated_annealing, Polling};
/// fn simple_function(list: &[f64; 3], _: Option<Arc<()>>) -> f64 {
///  list.iter().sum()
/// }
/// let best = simulated_annealing!(
///     (0f64..10f64, 5u32..15u32, 10i16..20i16), // Value ranges.
///     simple_function, // Evaluation function.
///     None, // No additional evaluation data.
///     // By using `new` this defaults to polling every `10ms`, we don't print progress `false` and exit early if `17.` or less is reached.
///     Some(Polling::new(false,Some(17.))),
///     None, // We don't specify the number of threads.
///     100., // Starting temperature is `100.`.
///     1., // Minimum temperature is `1.`.
///     simple_optimization::CoolingSchedule::Fast, // Use fast cooling schedule.
///     // Take `100` samples per temperature
///     // This is split between threads, so each thread only samples
///     //  `100/n` at each temperature.
///     100,
///     1., // Variance in sampling.
/// );
/// assert!(simple_function(&best, None) < 19.);
/// ```
#[macro_export]
macro_rules! simulated_annealing {
    (
        // Generic
        ($($x:expr),* $(,)?),
        $f: expr,
        $evaluation_data: expr,
        $polling: expr,
        $threads: expr,
        // Specific
        $starting_temperature: expr,
        $minimum_temperature: expr,
        $cooling_schedule: expr,
        $samples_per_temperature: expr,
        $variance: expr $(,)?
    ) => {
        {
            // NOTE(review): this path is resolved in the calling crate, so presumably the caller
            //  needs `num` as a dependency - confirm.
            use num::ToPrimitive;
            let ranges = [
                $(
                    {
                        // Binds the range first so the expression is only evaluated once.
                        let range = $x;
                        range.start.to_f64().unwrap()..range.end.to_f64().unwrap()
                    },
                )*
            ];
            // `$crate` resolves to this crate wherever the macro is expanded.
            $crate::simulated_annealing(
                ranges,
                $f,
                $evaluation_data,
                $polling,
                $threads,
                $starting_temperature,
                $minimum_temperature,
                $cooling_schedule,
                $samples_per_temperature,
                $variance
            )
        }
    };
}
111
112// TODO Multi-thread this
113/// [Simulated annealing](https://en.wikipedia.org/wiki/Simulated_annealing)
114///
115/// Run simulated annealing starting at temperature `100.` decaying with a fast cooling schedule until reach a minimum temperature of `1.`, taking `100` samples at each temperature, with a variance in sampling of `1.`.
116/// ```
117/// use std::sync::Arc;
118/// use simple_optimization::{simulated_annealing, Polling};
119/// fn simple_function(list: &[f64; 3], _: Option<Arc<()>>) -> f64 {
120///  list.iter().sum()
121/// }
122/// let best = simulated_annealing(
123///     [0f64..10f64, 5f64..15f64, 10f64..20f64], // Value ranges.
124///     simple_function, // Evaluation function.
125///     None, // No additional evaluation data.
126///     // By using `new` this defaults to polling every `10ms`, we don't print progress `false` and exit early if `19.` or less is reached.
127///     Some(Polling::new(false,Some(17.))),
128///     None, // We don't specify the number of threads.
129///     100., // Starting temperature is `100.`.
130///     1., // Minimum temperature is `1.`.
131///     simple_optimization::CoolingSchedule::Fast, // Use fast cooling schedule.
132///     // Take `100` samples per temperature
133///     // This is split between threads, so each thread only samples
134///     //  `100/n` at each temperature.
135///     100,
136///     1., // Variance in sampling.
137/// );
138/// assert!(simple_function(&best, None) < 19.);
139/// ```
140pub fn simulated_annealing<
141    A: 'static + Send + Sync,
142    T: 'static
143        + Copy
144        + Send
145        + Sync
146        + Default
147        + SampleUniform
148        + PartialOrd
149        + Sub<Output = T>
150        + num::ToPrimitive
151        + num::FromPrimitive,
152    const N: usize,
153>(
154    // Generic
155    ranges: [Range<T>; N],
156    f: fn(&[T; N], Option<Arc<A>>) -> f64,
157    evaluation_data: Option<Arc<A>>,
158    polling: Option<Polling>,
159    threads: Option<usize>,
160    // Specific
161    starting_temperature: f64,
162    minimum_temperature: f64,
163    cooling_schedule: CoolingSchedule,
164    samples_per_temperature: u64,
165    variance: f64,
166) -> [T; N] {
167    // Gets cpu number
168    let cpus = crate::cpus!(threads);
169    // 1 cpu is used for polling (this one), so we have -1 cpus for searching.
170    let search_cpus = cpus - 1;
171
172    let steps = cooling_schedule.steps(starting_temperature, minimum_temperature);
173    let thread_exit = Arc::new(AtomicBool::new(false));
174    let ranges_arc = Arc::new(ranges);
175
176    let remainder = samples_per_temperature % search_cpus;
177    let per = samples_per_temperature / search_cpus;
178
179    let (best_value, mut best_params) = search(
180        // Generics
181        ranges_arc.clone(),
182        f,
183        evaluation_data.clone(),
184        Arc::new(AtomicU64::new(Default::default())),
185        Arc::new(Mutex::new(Default::default())),
186        Arc::new(AtomicBool::new(false)),
187        Arc::new(AtomicU8::new(0)),
188        Arc::new([
189            Mutex::new((Duration::new(0, 0), 0)),
190            Mutex::new((Duration::new(0, 0), 0)),
191            Mutex::new((Duration::new(0, 0), 0)),
192        ]),
193        // Specifics
194        starting_temperature,
195        minimum_temperature,
196        cooling_schedule,
197        remainder,
198        variance,
199    );
200
201    let (handles, links): (Vec<_>, Vec<_>) = (0..search_cpus)
202        .map(|_| {
203            let ranges_clone = ranges_arc.clone();
204            let counter = Arc::new(AtomicU64::new(0));
205            let thread_best = Arc::new(Mutex::new(f64::MAX));
206            let thread_execution_position = Arc::new(AtomicU8::new(0));
207            let thread_execution_time = Arc::new([
208                Mutex::new((Duration::new(0, 0), 0)),
209                Mutex::new((Duration::new(0, 0), 0)),
210                Mutex::new((Duration::new(0, 0), 0)),
211            ]);
212
213            let counter_clone = counter.clone();
214            let thread_best_clone = thread_best.clone();
215            let thread_exit_clone = thread_exit.clone();
216            let evaluation_data_clone = evaluation_data.clone();
217            let thread_execution_position_clone = thread_execution_position.clone();
218            let thread_execution_time_clone = thread_execution_time.clone();
219            (
220                thread::spawn(move || {
221                    search(
222                        // Generics
223                        ranges_clone,
224                        f,
225                        evaluation_data_clone,
226                        counter_clone,
227                        thread_best_clone,
228                        thread_exit_clone,
229                        thread_execution_position_clone,
230                        thread_execution_time_clone,
231                        // Specifics
232                        starting_temperature,
233                        minimum_temperature,
234                        cooling_schedule,
235                        per,
236                        variance,
237                    )
238                }),
239                (
240                    counter,
241                    (
242                        thread_best,
243                        (thread_execution_position, thread_execution_time),
244                    ),
245                ),
246            )
247        })
248        .unzip();
249    let (counters, links): (Vec<Arc<AtomicU64>>, Vec<_>) = links.into_iter().unzip();
250    let (thread_bests, links): (Vec<Arc<Mutex<f64>>>, Vec<_>) = links.into_iter().unzip();
251    let (thread_execution_positions, thread_execution_times) = links.into_iter().unzip();
252
253    if let Some(poll_data) = polling {
254        poll(
255            poll_data,
256            counters,
257            steps * remainder,
258            steps * samples_per_temperature,
259            thread_bests,
260            thread_exit,
261            thread_execution_positions,
262            thread_execution_times,
263        );
264    }
265
266    // Joins all handles and folds across extracting best value and best points.
267    let (new_best_value, new_best_params) = handles.into_iter().map(|h| h.join().unwrap()).fold(
268        (best_value, best_params),
269        |(bv, bp), (v, p)| {
270            if v < bv {
271                (v, p)
272            } else {
273                (bv, bp)
274            }
275        },
276    );
277    // If the best value from threads is better than the value from remainder
278    if new_best_value > best_value {
279        best_params = new_best_params
280    }
281
282    return best_params;
283
284    fn search<
285        A: 'static + Send + Sync,
286        T: 'static
287            + Copy
288            + Send
289            + Sync
290            + Default
291            + SampleUniform
292            + PartialOrd
293            + Sub<Output = T>
294            + num::ToPrimitive
295            + num::FromPrimitive,
296        const N: usize,
297    >(
298        // Generic
299        ranges: Arc<[Range<T>; N]>,
300        f: fn(&[T; N], Option<Arc<A>>) -> f64,
301        evaluation_data: Option<Arc<A>>,
302        counter: Arc<AtomicU64>,
303        best: Arc<Mutex<f64>>,
304        thread_exit: Arc<AtomicBool>,
305        thread_execution_position: Arc<AtomicU8>,
306        thread_execution_times: Arc<[Mutex<(Duration, u64)>; 3]>,
307        // Specific
308        starting_temperature: f64,
309        minimum_temperature: f64,
310        cooling_schedule: CoolingSchedule,
311        samples_per_temperature: u64,
312        variance: f64,
313    ) -> (f64, [T; N]) {
314        let mut execution_position_timer = Instant::now();
315        let mut rng = thread_rng();
316        // Get initial point
317        let mut current_point = [Default::default(); N];
318        for (p, r) in current_point.iter_mut().zip(ranges.iter()) {
319            *p = rng.gen_range(r.clone());
320        }
321        let mut best_point = current_point;
322
323        let mut current_value = f(&best_point, evaluation_data.clone());
324        let mut best_value = current_value;
325
326        // Gets ranges in f64
327        // Since `Range` doesn't implement copy and array initialization will not clone,
328        //  this bypasses it.
329        let mut float_ranges: [Range<f64>; N] = vec![Default::default(); N].try_into().unwrap();
330        for (float_range, range) in float_ranges.iter_mut().zip(ranges.iter()) {
331            *float_range = range.start.to_f64().unwrap()..range.end.to_f64().unwrap();
332        }
333        // Variances scaled to the different ranges.
334        let mut scaled_variances: [f64; N] = [Default::default(); N];
335        for (scaled_variance, range) in scaled_variances.iter_mut().zip(float_ranges.iter()) {
336            *scaled_variance = (range.end - range.start) * variance
337        }
338
339        let mut step = 1;
340        let mut temperature = starting_temperature;
341        // Iterate while starting temperature has yet to cool to the minimum temperature.
342        while temperature >= minimum_temperature {
343            // Distributions to sample from at this temperature.
344            // `Normal::new(1.,1.).unwrap()` just replacement for `default()` since it doesn't implement trait.
345            let mut distributions: [Normal<f64>; N] = [Normal::new(1., 1.).unwrap(); N];
346            for (distribution, variance, point) in izip!(
347                distributions.iter_mut(),
348                scaled_variances.iter(),
349                current_point.iter()
350            ) {
351                *distribution = Normal::new(point.to_f64().unwrap(), *variance).unwrap()
352            }
353            // Iterate over samples from this temperature
354            for _ in 0..samples_per_temperature {
355                // Update execution position
356                execution_position_timer = update_execution_position(
357                    1,
358                    execution_position_timer,
359                    &thread_execution_position,
360                    &thread_execution_times,
361                );
362
363                // Samples new point
364                let mut point = [Default::default(); N];
365                for (p, r, d) in izip!(point.iter_mut(), float_ranges.iter(), distributions.iter())
366                {
367                    *p = sample_normal(r, d, &mut rng);
368                }
369
370                // Update execution position
371                execution_position_timer = update_execution_position(
372                    2,
373                    execution_position_timer,
374                    &thread_execution_position,
375                    &thread_execution_times,
376                );
377
378                // Evaluates new point
379                let value = f(&point, evaluation_data.clone());
380
381                // Update execution position
382                execution_position_timer = update_execution_position(
383                    3,
384                    execution_position_timer,
385                    &thread_execution_position,
386                    &thread_execution_times,
387                );
388                // Increment counter
389                counter.fetch_add(1, Ordering::SeqCst);
390
391                // Update:
392                // - if there is any progression
393                // - the regression `allow_change` is within a limit `rng.gen_range(0f64..1f64)`
394                let difference = value - current_value;
395                let allow_change = (difference / temperature).exp();
396                if difference < 0. || allow_change < rng.gen_range(0f64..1f64) {
397                    current_point = point;
398                    current_value = value;
399                    // If this value is new best value, update best value
400                    if current_value < best_value {
401                        best_point = current_point;
402                        best_value = current_value;
403                        *best.lock().unwrap() = best_value;
404                    }
405                }
406                if thread_exit.load(Ordering::SeqCst) {
407                    return (best_value, best_point);
408                }
409            }
410            step += 1;
411            temperature = cooling_schedule.decay(starting_temperature, temperature, step);
412        }
413        // Update execution position
414        // 0 represents ended state
415        thread_execution_position.store(0, Ordering::SeqCst);
416        (best_value, best_point)
417    }
418
419    // Samples until value in range
420    fn sample_normal<R: Rng + ?Sized, T: num::FromPrimitive>(
421        range: &Range<f64>,
422        distribution: &Normal<f64>,
423        rng: &mut R,
424    ) -> T {
425        let mut point: f64 = distribution.sample(rng);
426        while !range.contains(&point) {
427            point = distribution.sample(rng);
428        }
429        T::from_f64(point).unwrap()
430    }
431}