1use std::f64::consts::E;
13
/// Parameters of a damped harmonic spring (mass–spring–damper).
///
/// The motion modelled by `Spring::evaluate` starts at a displacement of `1.0`
/// and settles towards `0.0`. Fields are public, so values set directly bypass
/// the clamping performed by the builder methods.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Spring {
    /// Spring constant `k`. The `stiffness` builder keeps it at least `1e-9`.
    pub stiffness: f64,

    /// Damping coefficient `c`. The `damping` builder keeps it non-negative.
    pub damping: f64,

    /// Moving mass `m`. The `mass` builder keeps it at least `1e-9`.
    pub mass: f64,

    /// Velocity of the mass at `t = 0`.
    pub initial_velocity: f64,

    /// The spring counts as at rest once `|position| + |velocity|` drops below
    /// this value. It also determines `estimated_duration`.
    pub rest_threshold: f64,
}
49
50impl Default for Spring {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl Spring {
57 const EPSILON: f64 = 1e-9;
59
60 #[inline]
62 pub fn new() -> Self {
63 Self {
64 stiffness: 100.0,
65 damping: 10.0,
66 mass: 1.0,
67 initial_velocity: 0.0,
68 rest_threshold: 0.001,
69 }
70 }
71
72 #[inline]
74 pub fn stiffness(mut self, s: f64) -> Self {
75 self.stiffness = s.max(Self::EPSILON);
76 self
77 }
78
79 #[inline]
81 pub fn damping(mut self, d: f64) -> Self {
82 self.damping = d.max(0.0);
83 self
84 }
85
86 #[inline]
88 pub fn mass(mut self, m: f64) -> Self {
89 self.mass = m.max(Self::EPSILON);
90 self
91 }
92
93 #[inline]
95 pub fn initial_velocity(mut self, v: f64) -> Self {
96 self.initial_velocity = v;
97 self
98 }
99
100 #[inline]
102 pub fn rest_threshold(mut self, t: f64) -> Self {
103 self.rest_threshold = t.max(0.0);
104 self
105 }
106
107 #[inline]
114 pub fn damping_ratio(&self) -> f64 {
115 self.damping / (2.0 * (self.stiffness * self.mass).sqrt())
116 }
117
118 #[inline]
120 pub fn angular_frequency(&self) -> f64 {
121 (self.stiffness / self.mass).sqrt()
122 }
123
124 pub fn evaluate(&self, t: f64) -> (f64, f64) {
132 if t <= 0.0 {
133 return (1.0, self.initial_velocity);
134 }
135
136 let zeta = self.damping_ratio();
137 let w0 = self.angular_frequency();
138
139 let x0 = 1.0;
141 let v0 = self.initial_velocity;
142
143 if (zeta - 1.0).abs() < 1e-6 {
144 self.evaluate_critical(t, w0, x0, v0)
146 } else if zeta < 1.0 {
147 self.evaluate_underdamped(t, w0, zeta, x0, v0)
149 } else {
150 self.evaluate_overdamped(t, w0, zeta, x0, v0)
152 }
153 }
154
155 #[inline]
157 fn evaluate_underdamped(&self, t: f64, w0: f64, zeta: f64, x0: f64, v0: f64) -> (f64, f64) {
158 let wd = w0 * (1.0 - zeta * zeta).sqrt();
160
161 let a = x0;
163 let b = (v0 + zeta * w0 * x0) / wd;
164
165 let envelope = E.powf(-zeta * w0 * t);
167
168 let cos_term = (wd * t).cos();
170 let sin_term = (wd * t).sin();
171
172 let position = envelope * (a * cos_term + b * sin_term);
174
175 let velocity = envelope * (
177 (-zeta * w0) * (a * cos_term + b * sin_term)
178 + wd * (-a * sin_term + b * cos_term)
179 );
180
181 (position, velocity)
182 }
183
184 #[inline]
186 fn evaluate_critical(&self, t: f64, w0: f64, x0: f64, v0: f64) -> (f64, f64) {
187 let a = x0;
188 let b = v0 + w0 * x0;
189
190 let envelope = E.powf(-w0 * t);
191
192 let position = (a + b * t) * envelope;
194
195 let velocity = envelope * (b - w0 * (a + b * t));
197
198 (position, velocity)
199 }
200
201 #[inline]
203 fn evaluate_overdamped(&self, t: f64, w0: f64, zeta: f64, x0: f64, v0: f64) -> (f64, f64) {
204 let sqrt_term = (zeta * zeta - 1.0).sqrt();
206 let r1 = -w0 * (zeta - sqrt_term);
207 let r2 = -w0 * (zeta + sqrt_term);
208
209 let a = (v0 - r2 * x0) / (r1 - r2);
213 let b = x0 - a;
214
215 let exp1 = E.powf(r1 * t);
216 let exp2 = E.powf(r2 * t);
217
218 let position = a * exp1 + b * exp2;
220
221 let velocity = a * r1 * exp1 + b * r2 * exp2;
223
224 (position, velocity)
225 }
226
227 #[inline]
229 pub fn is_at_rest(&self, t: f64) -> bool {
230 let (position, velocity) = self.evaluate(t);
231 position.abs() + velocity.abs() < self.rest_threshold
232 }
233
234 pub fn estimated_duration(&self) -> f64 {
240 let zeta = self.damping_ratio();
241 let w0 = self.angular_frequency();
242
243 if w0 < Self::EPSILON {
244 return 100.0; }
246
247 let decay_rate = zeta * w0;
253
254 if decay_rate < Self::EPSILON {
255 return 100.0; }
257
258 let duration = -self.rest_threshold.ln() / decay_rate;
260
261 duration.clamp(0.0, 100.0)
263 }
264
265 #[inline]
269 pub fn gentle() -> Self {
270 Self::new().stiffness(120.0).damping(14.0)
271 }
272
273 #[inline]
277 pub fn bouncy() -> Self {
278 Self::new().stiffness(180.0).damping(12.0)
279 }
280
281 #[inline]
285 pub fn stiff() -> Self {
286 Self::new().stiffness(300.0).damping(20.0)
287 }
288
289 #[inline]
293 pub fn slow() -> Self {
294 Self::new().stiffness(60.0).damping(14.0)
295 }
296
297 pub fn as_easing(&self, samples: usize) -> Vec<f64> {
304 if samples == 0 {
305 return vec![];
306 }
307
308 let duration = self.estimated_duration();
309 let dt = duration / samples.max(1) as f64;
310
311 (0..samples)
312 .map(|i| {
313 let t = i as f64 * dt;
314 let (position, _) = self.evaluate(t);
315 1.0 - position.clamp(0.0, 1.0)
317 })
318 .collect()
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn approx_eq(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_damping_ratio() {
        let zeta = Spring::new().damping_ratio();
        assert!(approx_eq(zeta, 0.5, 1e-6));
    }

    #[test]
    fn test_angular_frequency() {
        let omega = Spring::new().angular_frequency();
        assert!(approx_eq(omega, 10.0, 1e-6));
    }

    #[test]
    fn test_initial_conditions() {
        let (position, velocity) = Spring::new().evaluate(0.0);
        assert!(approx_eq(position, 1.0, 1e-6));
        assert!(velocity.abs() < 1e-6);
    }

    #[test]
    fn test_underdamped_oscillation() {
        let spring = Spring::bouncy();
        let trace: Vec<f64> = (0..100)
            .map(|step| spring.evaluate(step as f64 * 0.01).0)
            .collect();

        // Count zero crossings between consecutive samples.
        let crossings = trace
            .windows(2)
            .filter(|pair| pair[0] * pair[1] < 0.0)
            .count();

        assert!(crossings > 0, "Under-damped spring should oscillate");
    }

    #[test]
    fn test_overdamped_no_oscillation() {
        let spring = Spring::new().stiffness(50.0).damping(50.0);
        (0..100)
            .map(|step| spring.evaluate(step as f64 * 0.01).0)
            .for_each(|position| {
                assert!(position >= -1e-6, "Over-damped spring should not overshoot");
            });
    }

    #[test]
    fn test_critically_damped_fast_settle() {
        let spring = Spring::new().stiffness(100.0).damping(20.0);
        assert!(approx_eq(spring.damping_ratio(), 1.0, 1e-3));

        let settle = spring.estimated_duration();
        assert!(settle > 0.0 && settle < 10.0);
    }

    #[test]
    fn test_rest_detection() {
        let spring = Spring::stiff();
        assert!(!spring.is_at_rest(0.0));
        assert!(spring.is_at_rest(5.0));
    }

    #[test]
    fn test_presets() {
        let presets = [
            Spring::gentle(),
            Spring::bouncy(),
            Spring::stiff(),
            Spring::slow(),
        ];
        for preset in presets.iter() {
            assert!(preset.stiffness > 0.0);
        }
    }

    #[test]
    fn test_as_easing() {
        let curve = Spring::gentle().as_easing(100);
        assert_eq!(curve.len(), 100);

        let (first, last) = (curve[0], curve[99]);
        assert!(first < 0.1);
        assert!(last > 0.9);
        assert!(last > first);
    }

    #[test]
    fn test_initial_velocity() {
        let (_, velocity) = Spring::new().initial_velocity(10.0).evaluate(0.0);
        assert!(approx_eq(velocity, 10.0, 1e-6));
    }

    #[test]
    fn test_edge_cases() {
        assert!(Spring::new().stiffness(0.0).stiffness > 0.0);
        assert!(Spring::new().mass(0.0).mass > 0.0);
        assert!(Spring::new().damping(-5.0).damping >= 0.0);
    }

    #[test]
    fn test_energy_conservation() {
        let spring = Spring::bouncy();
        let settle = spring.estimated_duration();

        let early = spring.evaluate(settle * 0.1).0;
        let late = spring.evaluate(settle * 0.9).0;

        assert!(early.abs() <= 2.0);
        assert!(late.abs() < spring.rest_threshold * 10.0);
    }
}