spitfire_input/
lib.rs

1#[cfg(not(target_arch = "wasm32"))]
2use glutin::event::{ElementState, MouseButton, MouseScrollDelta, VirtualKeyCode, WindowEvent};
3use std::{
4    borrow::Cow,
5    cmp::Ordering,
6    collections::HashMap,
7    sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard},
8};
9use typid::ID;
10#[cfg(target_arch = "wasm32")]
11use winit::event::{ElementState, MouseButton, MouseScrollDelta, VirtualKeyCode, WindowEvent};
12
/// Policy deciding whether a mapping stops an event from reaching the mappings below it
/// on an input stack.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum InputConsume {
    /// Never stops propagation (the default).
    #[default]
    None,
    /// Stops propagation only when this mapping actually wrote one of its bindings for the event.
    Hit,
    /// Always stops propagation after this mapping, whether or not anything matched.
    All,
}

21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
22pub enum VirtualAction {
23    KeyButton(VirtualKeyCode),
24    MouseButton(MouseButton),
25    Axis(u32),
26}
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
29pub enum VirtualAxis {
30    KeyButton(VirtualKeyCode),
31    MousePositionX,
32    MousePositionY,
33    MouseWheelX,
34    MouseWheelY,
35    MouseButton(MouseButton),
36    Axis(u32),
37}
38
/// Digital (on/off) input state.
///
/// A held input walks `Idle -> Pressed -> Hold`, and a let-go input walks
/// `Hold -> Released -> Idle`. `Pressed` and `Released` are edge states that
/// [`InputAction::update`] collapses into `Hold` and `Idle`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum InputAction {
    /// Up, with no pending edge (the default).
    #[default]
    Idle,
    /// Just went down; becomes `Hold` on the next `update`.
    Pressed,
    /// Down, with no pending edge.
    Hold,
    /// Just went up; becomes `Idle` on the next `update`.
    Released,
}

impl InputAction {
    /// Returns the state following `self` when the input is reported as held (`hold == true`)
    /// or not held (`hold == false`).
    ///
    /// A held report turns an up state into `Pressed` and keeps a down state at `Hold`; a
    /// not-held report turns a down state into `Released` and keeps an up state at `Idle`.
    pub fn change(self, hold: bool) -> Self {
        if hold {
            match self {
                Self::Idle | Self::Released => Self::Pressed,
                Self::Pressed | Self::Hold => Self::Hold,
            }
        } else {
            match self {
                Self::Pressed | Self::Hold => Self::Released,
                Self::Released | Self::Idle => Self::Idle,
            }
        }
    }

    /// Drops the edge: `Pressed` becomes `Hold`, `Released` becomes `Idle`, others are kept.
    pub fn update(self) -> Self {
        match self {
            Self::Pressed | Self::Hold => Self::Hold,
            Self::Released | Self::Idle => Self::Idle,
        }
    }

    /// `true` only for [`InputAction::Idle`].
    pub fn is_idle(self) -> bool {
        self == Self::Idle
    }

    /// `true` only for [`InputAction::Pressed`].
    pub fn is_pressed(self) -> bool {
        self == Self::Pressed
    }

    /// `true` only for [`InputAction::Hold`].
    pub fn is_hold(self) -> bool {
        self == Self::Hold
    }

    /// `true` only for [`InputAction::Released`].
    pub fn is_released(self) -> bool {
        self == Self::Released
    }

    /// `true` when the input is up (`Idle` or `Released`).
    pub fn is_up(self) -> bool {
        !self.is_down()
    }

    /// `true` when the input is down (`Pressed` or `Hold`).
    pub fn is_down(self) -> bool {
        matches!(self, Self::Pressed | Self::Hold)
    }

    /// `true` for the edge states `Pressed` and `Released`.
    pub fn is_changing(self) -> bool {
        matches!(self, Self::Pressed | Self::Released)
    }

    /// `true` for the steady states `Idle` and `Hold`.
    pub fn is_continuing(self) -> bool {
        !self.is_changing()
    }

    /// Maps the state to `truthy` when down and `falsy` when up.
    pub fn to_scalar(self, falsy: f32, truthy: f32) -> f32 {
        match self {
            Self::Pressed | Self::Hold => truthy,
            Self::Idle | Self::Released => falsy,
        }
    }
}

/// A single analog input value (defaults to `0.0`).
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub struct InputAxis(pub f32);

impl InputAxis {
    /// Returns whether the stored value is at least `value` (inclusive).
    pub fn threshold(self, value: f32) -> bool {
        let Self(current) = self;
        current >= value
    }
}

/// Shared, `RwLock`-protected handle to an input value.
///
/// Clones share the same `Arc`, so a write through one clone is visible through all of them.
/// A poisoned lock is treated as "unavailable": accessors return `None` (or the default value)
/// and `set` does nothing.
#[derive(Debug, Default, Clone)]
pub struct InputRef<T: Default + Clone>(Arc<RwLock<T>>);

impl<T: Default + Clone> InputRef<T> {
    /// Wraps `data` in a fresh shared cell.
    pub fn new(data: T) -> Self {
        let cell = RwLock::new(data);
        Self(Arc::new(cell))
    }

    /// Acquires a read guard, or `None` if the lock is poisoned.
    pub fn read(&self) -> Option<RwLockReadGuard<'_, T>> {
        match self.0.read() {
            Ok(guard) => Some(guard),
            Err(_) => None,
        }
    }

    /// Acquires a write guard, or `None` if the lock is poisoned.
    pub fn write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        match self.0.write() {
            Ok(guard) => Some(guard),
            Err(_) => None,
        }
    }

    /// Returns a copy of the current value, or `T::default()` if the lock is poisoned.
    pub fn get(&self) -> T {
        match self.read() {
            Some(guard) => (*guard).clone(),
            None => T::default(),
        }
    }

    /// Replaces the current value; silently skipped if the lock is poisoned.
    pub fn set(&self, value: T) {
        if let Some(mut slot) = self.write() {
            *slot = value;
        }
    }
}

140pub type InputActionRef = InputRef<InputAction>;
141pub type InputAxisRef = InputRef<InputAxis>;
142pub type InputCharactersRef = InputRef<InputCharacters>;
143pub type InputMappingRef = InputRef<InputMapping>;
144
145#[derive(Debug, Default, Clone)]
146pub enum InputActionOrAxisRef {
147    #[default]
148    None,
149    Action(InputActionRef),
150    Axis(InputAxisRef),
151}
152
153impl InputActionOrAxisRef {
154    pub fn is_none(&self) -> bool {
155        matches!(self, Self::None)
156    }
157
158    pub fn is_some(&self) -> bool {
159        !self.is_none()
160    }
161
162    pub fn get_scalar(&self, falsy: f32, truthy: f32) -> f32 {
163        match self {
164            Self::None => falsy,
165            Self::Action(action) => action.get().to_scalar(falsy, truthy),
166            Self::Axis(axis) => axis.get().0,
167        }
168    }
169
170    pub fn threshold(&self, value: f32) -> bool {
171        match self {
172            Self::None => false,
173            Self::Action(action) => action.get().is_down(),
174            Self::Axis(axis) => axis.get().threshold(value),
175        }
176    }
177}
178
179impl From<InputActionRef> for InputActionOrAxisRef {
180    fn from(value: InputActionRef) -> Self {
181        Self::Action(value)
182    }
183}
184
185impl From<InputAxisRef> for InputActionOrAxisRef {
186    fn from(value: InputAxisRef) -> Self {
187        Self::Axis(value)
188    }
189}
190
/// Derived input value: stores a closure and re-evaluates it on every [`InputCombinator::get`].
pub struct InputCombinator<T> {
    mapper: Box<dyn Fn() -> T>,
}

impl<T: Default> Default for InputCombinator<T> {
    /// A combinator that always yields `T::default()`.
    fn default() -> Self {
        Self::new(|| Default::default())
    }
}

impl<T> InputCombinator<T> {
    /// Wraps `mapper`; it is called once per `get`, never cached.
    pub fn new(mapper: impl Fn() -> T + 'static) -> Self {
        let mapper: Box<dyn Fn() -> T> = Box::new(mapper);
        Self { mapper }
    }

    /// Evaluates the stored closure and returns its result.
    pub fn get(&self) -> T {
        let mapper = &self.mapper;
        mapper()
    }
}

213#[derive(Default)]
214pub struct CardinalInputCombinator(InputCombinator<[f32; 2]>);
215
216impl CardinalInputCombinator {
217    pub fn new(
218        left: impl Into<InputActionOrAxisRef>,
219        right: impl Into<InputActionOrAxisRef>,
220        up: impl Into<InputActionOrAxisRef>,
221        down: impl Into<InputActionOrAxisRef>,
222    ) -> Self {
223        let left = left.into();
224        let right = right.into();
225        let up = up.into();
226        let down = down.into();
227        Self(InputCombinator::new(move || {
228            let left = left.get_scalar(0.0, -1.0);
229            let right = right.get_scalar(0.0, 1.0);
230            let up = up.get_scalar(0.0, -1.0);
231            let down = down.get_scalar(0.0, 1.0);
232            [left + right, up + down]
233        }))
234    }
235
236    pub fn get(&self) -> [f32; 2] {
237        self.0.get()
238    }
239}
240
241#[derive(Default)]
242pub struct DualInputCombinator(InputCombinator<f32>);
243
244impl DualInputCombinator {
245    pub fn new(
246        negative: impl Into<InputActionOrAxisRef>,
247        positive: impl Into<InputActionOrAxisRef>,
248    ) -> Self {
249        let negative = negative.into();
250        let positive = positive.into();
251        Self(InputCombinator::new(move || {
252            let negative = negative.get_scalar(0.0, -1.0);
253            let positive = positive.get_scalar(0.0, 1.0);
254            negative + positive
255        }))
256    }
257
258    pub fn get(&self) -> f32 {
259        self.0.get()
260    }
261}
262
263pub struct ArrayInputCombinator<const N: usize>(InputCombinator<[f32; N]>);
264
265impl<const N: usize> Default for ArrayInputCombinator<N> {
266    fn default() -> Self {
267        Self(InputCombinator::new(|| {
268            std::array::from_fn(|_| Default::default())
269        }))
270    }
271}
272
273impl<const N: usize> ArrayInputCombinator<N> {
274    pub fn new(inputs: [impl Into<InputActionOrAxisRef>; N]) -> Self {
275        let mut items = std::array::from_fn::<InputActionOrAxisRef, N, _>(|_| Default::default());
276        for (index, input) in inputs.into_iter().enumerate() {
277            items[index] = input.into();
278        }
279        Self(InputCombinator::new(move || {
280            std::array::from_fn(|index| items[index].get_scalar(0.0, 1.0))
281        }))
282    }
283
284    pub fn get(&self) -> [f32; N] {
285        self.0.get()
286    }
287}
288
/// Text buffer that accumulates characters received from the window.
#[derive(Debug, Default, Clone)]
pub struct InputCharacters {
    characters: String,
}

impl InputCharacters {
    /// Borrows the buffered text.
    pub fn read(&self) -> &str {
        self.characters.as_str()
    }

    /// Borrows the buffer mutably.
    pub fn write(&mut self) -> &mut String {
        &mut self.characters
    }

    /// Moves the buffered text out, leaving the buffer empty.
    pub fn take(&mut self) -> String {
        let buffer = self.write();
        std::mem::take(buffer)
    }
}

308#[derive(Debug, Default, Clone)]
309pub struct InputMapping {
310    pub actions: HashMap<VirtualAction, InputActionRef>,
311    pub axes: HashMap<VirtualAxis, InputAxisRef>,
312    pub consume: InputConsume,
313    pub layer: isize,
314    pub name: Cow<'static, str>,
315}
316
317impl InputMapping {
318    pub fn action(mut self, id: VirtualAction, action: InputActionRef) -> Self {
319        self.actions.insert(id, action);
320        self
321    }
322
323    pub fn axis(mut self, id: VirtualAxis, axis: InputAxisRef) -> Self {
324        self.axes.insert(id, axis);
325        self
326    }
327
328    pub fn consume(mut self, consume: InputConsume) -> Self {
329        self.consume = consume;
330        self
331    }
332
333    pub fn layer(mut self, value: isize) -> Self {
334        self.layer = value;
335        self
336    }
337
338    pub fn name(mut self, value: impl Into<Cow<'static, str>>) -> Self {
339        self.name = value.into();
340        self
341    }
342}
343
344impl From<InputMapping> for InputMappingRef {
345    fn from(value: InputMapping) -> Self {
346        Self::new(value)
347    }
348}
349
350#[derive(Debug, Clone)]
351pub struct InputContext {
352    pub mouse_wheel_line_scale: f32,
353    /// [(id, mapping)]
354    mappings_stack: Vec<(ID<InputMapping>, InputMappingRef)>,
355    characters: InputCharactersRef,
356}
357
358impl Default for InputContext {
359    fn default() -> Self {
360        Self {
361            mouse_wheel_line_scale: Self::default_mouse_wheel_line_scale(),
362            mappings_stack: Default::default(),
363            characters: Default::default(),
364        }
365    }
366}
367
368impl InputContext {
369    fn default_mouse_wheel_line_scale() -> f32 {
370        10.0
371    }
372
373    pub fn push_mapping(&mut self, mapping: impl Into<InputMappingRef>) -> ID<InputMapping> {
374        let mapping = mapping.into();
375        let id = ID::default();
376        let layer = mapping.read().unwrap().layer;
377        let index = self
378            .mappings_stack
379            .binary_search_by(|(_, mapping)| {
380                mapping
381                    .read()
382                    .unwrap()
383                    .layer
384                    .cmp(&layer)
385                    .then(Ordering::Less)
386            })
387            .unwrap_or_else(|index| index);
388        self.mappings_stack.insert(index, (id, mapping));
389        id
390    }
391
392    pub fn pop_mapping(&mut self) -> Option<InputMappingRef> {
393        self.mappings_stack.pop().map(|(_, mapping)| mapping)
394    }
395
396    pub fn top_mapping(&self) -> Option<&InputMappingRef> {
397        self.mappings_stack.last().map(|(_, mapping)| mapping)
398    }
399
400    pub fn remove_mapping(&mut self, id: ID<InputMapping>) -> Option<InputMappingRef> {
401        self.mappings_stack
402            .iter()
403            .position(|(mid, _)| mid == &id)
404            .map(|index| self.mappings_stack.remove(index).1)
405    }
406
407    pub fn mapping(&self, id: ID<InputMapping>) -> Option<RwLockReadGuard<InputMapping>> {
408        self.mappings_stack
409            .iter()
410            .find(|(mid, _)| mid == &id)
411            .and_then(|(_, mapping)| mapping.read())
412    }
413
414    pub fn stack(&self) -> impl Iterator<Item = &InputMappingRef> {
415        self.mappings_stack.iter().map(|(_, mapping)| mapping)
416    }
417
418    pub fn characters(&self) -> InputCharactersRef {
419        self.characters.clone()
420    }
421
422    pub fn maintain(&mut self) {
423        for (_, mapping) in &mut self.mappings_stack {
424            if let Some(mut mapping) = mapping.write() {
425                for action in mapping.actions.values_mut() {
426                    if let Some(mut action) = action.write() {
427                        *action = action.update();
428                    }
429                }
430                for (id, axis) in &mut mapping.axes {
431                    if let VirtualAxis::MouseWheelX | VirtualAxis::MouseWheelY = id {
432                        if let Some(mut axis) = axis.write() {
433                            axis.0 = 0.0;
434                        }
435                    }
436                }
437            }
438        }
439    }
440
441    pub fn on_event(&mut self, event: &WindowEvent) {
442        match event {
443            WindowEvent::ReceivedCharacter(character) => {
444                if let Some(mut characters) = self.characters.write() {
445                    characters.characters.push(*character);
446                }
447            }
448            WindowEvent::KeyboardInput { input, .. } => {
449                if let Some(key) = input.virtual_keycode {
450                    for (_, mapping) in self.mappings_stack.iter().rev() {
451                        if let Some(mapping) = mapping.read() {
452                            let mut consume = mapping.consume == InputConsume::All;
453                            for (id, data) in &mapping.actions {
454                                if let VirtualAction::KeyButton(button) = id {
455                                    if *button == key {
456                                        if let Some(mut data) = data.write() {
457                                            *data =
458                                                data.change(input.state == ElementState::Pressed);
459                                            if mapping.consume == InputConsume::Hit {
460                                                consume = true;
461                                            }
462                                        }
463                                    }
464                                }
465                            }
466                            for (id, data) in &mapping.axes {
467                                if let VirtualAxis::KeyButton(button) = id {
468                                    if *button == key {
469                                        if let Some(mut data) = data.write() {
470                                            data.0 = if input.state == ElementState::Pressed {
471                                                1.0
472                                            } else {
473                                                0.0
474                                            };
475                                            if mapping.consume == InputConsume::Hit {
476                                                consume = true;
477                                            }
478                                        }
479                                    }
480                                }
481                            }
482                            if consume {
483                                break;
484                            }
485                        }
486                    }
487                }
488            }
489            WindowEvent::CursorMoved { position, .. } => {
490                for (_, mapping) in self.mappings_stack.iter().rev() {
491                    if let Some(mapping) = mapping.read() {
492                        let mut consume = mapping.consume == InputConsume::All;
493                        for (id, data) in &mapping.axes {
494                            match id {
495                                VirtualAxis::MousePositionX => {
496                                    if let Some(mut data) = data.write() {
497                                        data.0 = position.x as _;
498                                        if mapping.consume == InputConsume::Hit {
499                                            consume = true;
500                                        }
501                                    }
502                                }
503                                VirtualAxis::MousePositionY => {
504                                    if let Some(mut data) = data.write() {
505                                        data.0 = position.y as _;
506                                        if mapping.consume == InputConsume::Hit {
507                                            consume = true;
508                                        }
509                                    }
510                                }
511                                _ => {}
512                            }
513                        }
514                        if consume {
515                            break;
516                        }
517                    }
518                }
519            }
520            WindowEvent::MouseWheel { delta, .. } => {
521                for (_, mapping) in self.mappings_stack.iter().rev() {
522                    if let Some(mapping) = mapping.read() {
523                        let mut consume = mapping.consume == InputConsume::All;
524                        for (id, data) in &mapping.axes {
525                            match id {
526                                VirtualAxis::MouseWheelX => {
527                                    if let Some(mut data) = data.write() {
528                                        data.0 = match delta {
529                                            MouseScrollDelta::LineDelta(x, _) => *x,
530                                            MouseScrollDelta::PixelDelta(pos) => pos.x as _,
531                                        };
532                                        if mapping.consume == InputConsume::Hit {
533                                            consume = true;
534                                        }
535                                    }
536                                }
537                                VirtualAxis::MouseWheelY => {
538                                    if let Some(mut data) = data.write() {
539                                        data.0 = match delta {
540                                            MouseScrollDelta::LineDelta(_, y) => *y,
541                                            MouseScrollDelta::PixelDelta(pos) => pos.y as _,
542                                        };
543                                        if mapping.consume == InputConsume::Hit {
544                                            consume = true;
545                                        }
546                                    }
547                                }
548                                _ => {}
549                            }
550                        }
551                        if consume {
552                            break;
553                        }
554                    }
555                }
556            }
557            WindowEvent::MouseInput { state, button, .. } => {
558                for (_, mapping) in self.mappings_stack.iter().rev() {
559                    if let Some(mapping) = mapping.read() {
560                        let mut consume = mapping.consume == InputConsume::All;
561                        for (id, data) in &mapping.actions {
562                            if let VirtualAction::MouseButton(btn) = id {
563                                if button == btn {
564                                    if let Some(mut data) = data.write() {
565                                        *data = data.change(*state == ElementState::Pressed);
566                                        if mapping.consume == InputConsume::Hit {
567                                            consume = true;
568                                        }
569                                    }
570                                }
571                            }
572                        }
573                        for (id, data) in &mapping.axes {
574                            if let VirtualAxis::MouseButton(btn) = id {
575                                if button == btn {
576                                    if let Some(mut data) = data.write() {
577                                        data.0 = if *state == ElementState::Pressed {
578                                            1.0
579                                        } else {
580                                            0.0
581                                        };
582                                        if mapping.consume == InputConsume::Hit {
583                                            consume = true;
584                                        }
585                                    }
586                                }
587                            }
588                        }
589                        if consume {
590                            break;
591                        }
592                    }
593                }
594            }
595            WindowEvent::AxisMotion { axis, value, .. } => {
596                for (_, mapping) in self.mappings_stack.iter().rev() {
597                    if let Some(mapping) = mapping.read() {
598                        let mut consume = mapping.consume == InputConsume::All;
599                        for (id, data) in &mapping.actions {
600                            if let VirtualAction::Axis(index) = id {
601                                if axis == index {
602                                    if let Some(mut data) = data.write() {
603                                        *data = data.change(value.abs() > 0.5);
604                                        if mapping.consume == InputConsume::Hit {
605                                            consume = true;
606                                        }
607                                    }
608                                }
609                            }
610                        }
611                        for (id, data) in &mapping.axes {
612                            if let VirtualAxis::Axis(index) = id {
613                                if axis == index {
614                                    if let Some(mut data) = data.write() {
615                                        data.0 = *value as _;
616                                        if mapping.consume == InputConsume::Hit {
617                                            consume = true;
618                                        }
619                                    }
620                                }
621                            }
622                        }
623                        if consume {
624                            break;
625                        }
626                    }
627                }
628            }
629            _ => {}
630        }
631    }
632}
633
#[cfg(test)]
mod tests {
    use crate::{InputContext, InputMapping};

    /// Pushed mappings must end up ordered by ascending layer, with mappings that share a layer
    /// kept in the order they were pushed.
    #[test]
    fn test_stack() {
        let pushed: [(&str, isize); 11] = [
            ("a", 0),
            ("b", 0),
            ("c", 0),
            ("d", -1),
            ("e", 1),
            ("f", -1),
            ("g", 1),
            ("h", -2),
            ("i", -2),
            ("j", 2),
            ("k", 2),
        ];
        let mut context = InputContext::default();
        for (name, layer) in pushed {
            context.push_mapping(InputMapping::default().name(name).layer(layer));
        }

        let expected: Vec<(String, isize)> = [
            ("h", -2),
            ("i", -2),
            ("d", -1),
            ("f", -1),
            ("a", 0),
            ("b", 0),
            ("c", 0),
            ("e", 1),
            ("g", 1),
            ("j", 2),
            ("k", 2),
        ]
        .iter()
        .map(|(name, layer)| (name.to_string(), *layer))
        .collect();

        let mut provided = Vec::new();
        for mapping in context.stack() {
            let mapping = mapping.read().unwrap();
            provided.push((mapping.name.to_string(), mapping.layer));
        }
        assert_eq!(provided, expected);
    }
}