tui_dispatch_core/
keybindings.rs

1//! Keybindings system with context-aware key parsing and lookup
2
3use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::collections::HashMap;
6use std::hash::Hash;
7
/// A set of named input contexts used to scope keybindings.
///
/// Implement this by hand for your own context enum, or derive it with
/// `#[derive(BindingContext)]` from the `tui-dispatch-macros` crate.
///
/// # Example
/// ```ignore
/// #[derive(BindingContext, Clone, Copy, PartialEq, Eq, Hash)]
/// pub enum MyContext {
///     Default,
///     Search,
///     Modal,
/// }
/// ```
pub trait BindingContext
where
    Self: Clone + Copy + Eq + Hash,
{
    /// The name of this context, used as its section key in keybinding config files.
    fn name(&self) -> &'static str;

    /// Look up the context whose [`BindingContext::name`] is `name`.
    ///
    /// Returns `None` when no context has that name.
    fn from_name(name: &str) -> Option<Self>;

    /// Every value of the context type (for iterating over contexts, e.g. when loading config).
    fn all() -> &'static [Self];
}
32
33/// Keybindings configuration with context support
34///
35/// Generic over the context type `C` which must implement `BindingContext`.
36#[derive(Debug, Clone)]
37pub struct Keybindings<C: BindingContext> {
38    /// Global keybindings - checked as fallback for all contexts
39    global: HashMap<String, Vec<String>>,
40    /// Context-specific keybindings
41    contexts: HashMap<C, HashMap<String, Vec<String>>>,
42}
43
44impl<C: BindingContext> Default for Keybindings<C> {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl<C: BindingContext> Serialize for Keybindings<C> {
51    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
52    where
53        S: Serializer,
54    {
55        use serde::ser::SerializeMap;
56
57        // Count total entries: global + all contexts
58        let mut map = serializer.serialize_map(Some(1 + self.contexts.len()))?;
59
60        // Serialize global bindings
61        map.serialize_entry("global", &self.global)?;
62
63        // Serialize context-specific bindings using context names
64        for (context, bindings) in &self.contexts {
65            map.serialize_entry(context.name(), bindings)?;
66        }
67
68        map.end()
69    }
70}
71
72impl<'de, C: BindingContext> Deserialize<'de> for Keybindings<C> {
73    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
74    where
75        D: Deserializer<'de>,
76    {
77        // Deserialize as a map of string -> bindings
78        let raw: HashMap<String, HashMap<String, Vec<String>>> =
79            HashMap::deserialize(deserializer)?;
80
81        let mut keybindings = Keybindings::new();
82
83        for (context_name, bindings) in raw {
84            if context_name == "global" {
85                keybindings.global = bindings;
86            } else if let Some(context) = C::from_name(&context_name) {
87                keybindings.contexts.insert(context, bindings);
88            }
89            // Silently ignore unknown contexts (allows forward compatibility)
90        }
91
92        Ok(keybindings)
93    }
94}
95
96impl<C: BindingContext> Keybindings<C> {
97    /// Create a new empty keybindings configuration
98    pub fn new() -> Self {
99        Self {
100            global: HashMap::new(),
101            contexts: HashMap::new(),
102        }
103    }
104
105    /// Add a global keybinding
106    pub fn add_global(&mut self, command: impl Into<String>, keys: Vec<String>) {
107        self.global.insert(command.into(), keys);
108    }
109
110    /// Add a context-specific keybinding
111    pub fn add(&mut self, context: C, command: impl Into<String>, keys: Vec<String>) {
112        self.contexts
113            .entry(context)
114            .or_default()
115            .insert(command.into(), keys);
116    }
117
118    /// Get bindings for a specific context
119    pub fn get_context_bindings(&self, context: C) -> Option<&HashMap<String, Vec<String>>> {
120        self.contexts.get(&context)
121    }
122
123    /// Get global bindings
124    pub fn global_bindings(&self) -> &HashMap<String, Vec<String>> {
125        &self.global
126    }
127
128    /// Get command name for a key event in the given context
129    ///
130    /// First checks context-specific bindings, then falls back to global
131    pub fn get_command(&self, key: KeyEvent, context: C) -> Option<String> {
132        // First try context-specific bindings
133        if let Some(context_bindings) = self.contexts.get(&context) {
134            if let Some(cmd) = self.match_key_in_bindings(key, context_bindings) {
135                return Some(cmd);
136            }
137        }
138
139        // Fall back to global bindings
140        self.match_key_in_bindings(key, &self.global)
141    }
142
143    /// Helper to match a key against a set of bindings
144    fn match_key_in_bindings(
145        &self,
146        key: KeyEvent,
147        bindings: &HashMap<String, Vec<String>>,
148    ) -> Option<String> {
149        for (command, keys) in bindings {
150            for key_str in keys {
151                if let Some(parsed_key) = parse_key_string(key_str) {
152                    // Compare code and modifiers (ignore kind and state)
153                    // For character keys, compare case-insensitively
154                    let codes_match = match (&parsed_key.code, &key.code) {
155                        (KeyCode::Char(c1), KeyCode::Char(c2)) => {
156                            c1.to_lowercase().to_string() == c2.to_lowercase().to_string()
157                        }
158                        _ => parsed_key.code == key.code,
159                    };
160
161                    if codes_match && parsed_key.modifiers == key.modifiers {
162                        return Some(command.clone());
163                    }
164                }
165            }
166        }
167        None
168    }
169
170    /// Get the first keybinding string for a command in the given context
171    ///
172    /// First checks context-specific bindings, then falls back to global
173    pub fn get_first_keybinding(&self, command: &str, context: C) -> Option<String> {
174        if let Some(context_bindings) = self.contexts.get(&context) {
175            if let Some(keys) = context_bindings.get(command) {
176                if let Some(first) = keys.first() {
177                    return Some(first.clone());
178                }
179            }
180        }
181
182        self.global
183            .get(command)
184            .and_then(|keys| keys.first().cloned())
185    }
186
187    /// Merge user config onto defaults - user config overrides defaults
188    pub fn merge(mut defaults: Self, user: Self) -> Self {
189        // Merge global
190        for (key, value) in user.global {
191            defaults.global.insert(key, value);
192        }
193
194        // Merge contexts
195        for (context, bindings) in user.contexts {
196            let entry = defaults.contexts.entry(context).or_default();
197            for (key, value) in bindings {
198                entry.insert(key, value);
199            }
200        }
201
202        defaults
203    }
204}
205
206/// Parse a key string like "q", "esc", "ctrl+p", "shift+tab" into a KeyEvent
207pub fn parse_key_string(key_str: &str) -> Option<KeyEvent> {
208    let key_str = key_str.trim().to_lowercase();
209
210    if key_str.is_empty() {
211        return None;
212    }
213
214    // Special case: shift+tab should be BackTab
215    if key_str == "shift+tab" || key_str == "backtab" {
216        return Some(KeyEvent {
217            code: KeyCode::BackTab,
218            modifiers: KeyModifiers::SHIFT,
219            kind: crossterm::event::KeyEventKind::Press,
220            state: crossterm::event::KeyEventState::empty(),
221        });
222    }
223
224    // Check for modifiers
225    let parts: Vec<&str> = key_str.split('+').collect();
226    let mut modifiers = KeyModifiers::empty();
227    let key_part = parts.last()?.trim();
228
229    if parts.len() > 1 {
230        for part in &parts[..parts.len() - 1] {
231            match part.trim() {
232                "ctrl" | "control" => modifiers |= KeyModifiers::CONTROL,
233                "shift" => modifiers |= KeyModifiers::SHIFT,
234                "alt" => modifiers |= KeyModifiers::ALT,
235                _ => {}
236            }
237        }
238    }
239
240    // Parse the key code
241    let code = match key_part {
242        "esc" | "escape" => KeyCode::Esc,
243        "enter" | "return" => KeyCode::Enter,
244        "tab" => KeyCode::Tab,
245        "backtab" => {
246            if modifiers.is_empty() {
247                modifiers |= KeyModifiers::SHIFT;
248            }
249            KeyCode::BackTab
250        }
251        "backspace" => KeyCode::Backspace,
252        "up" => KeyCode::Up,
253        "down" => KeyCode::Down,
254        "left" => KeyCode::Left,
255        "right" => KeyCode::Right,
256        "home" => KeyCode::Home,
257        "end" => KeyCode::End,
258        "pageup" => KeyCode::PageUp,
259        "pagedown" => KeyCode::PageDown,
260        "delete" => KeyCode::Delete,
261        "insert" => KeyCode::Insert,
262        "f1" => KeyCode::F(1),
263        "f2" => KeyCode::F(2),
264        "f3" => KeyCode::F(3),
265        "f4" => KeyCode::F(4),
266        "f5" => KeyCode::F(5),
267        "f6" => KeyCode::F(6),
268        "f7" => KeyCode::F(7),
269        "f8" => KeyCode::F(8),
270        "f9" => KeyCode::F(9),
271        "f10" => KeyCode::F(10),
272        "f11" => KeyCode::F(11),
273        "f12" => KeyCode::F(12),
274        "space" => KeyCode::Char(' '),
275        // Single character
276        c if c.len() == 1 => {
277            let ch = c.chars().next()?;
278            KeyCode::Char(ch)
279        }
280        _ => return None,
281    };
282
283    Some(KeyEvent {
284        code,
285        modifiers,
286        kind: crossterm::event::KeyEventKind::Press,
287        state: crossterm::event::KeyEventState::empty(),
288    })
289}
290
/// Format a key string for display, e.g. `"ctrl+p"` -> `"^P"`, `"q"` -> `"Q"`, `"tab"` -> `"Tab"`.
///
/// The input is trimmed and matched case-insensitively, with modifiers joined to the key
/// by `+` (spaces around `+` are allowed). The `+` key is written `"+"`, or with
/// modifiers as e.g. `"ctrl++"`.
///
/// Rendering rules:
/// - Control is shown as a `^` prefix; Shift and Alt as `Shift+` and `Alt+`.
/// - Unrecognized modifiers are omitted.
/// - Single alphabetic characters (including non-ASCII letters) are upper-cased.
/// - Unrecognized multi-character key names are shown as written (trimmed, lower-cased).
/// - `"shift+tab"` and `"backtab"` both display as `"Shift+Tab"`.
pub fn format_key_for_display(key_str: &str) -> String {
    let key_str = key_str.trim().to_lowercase();

    // Handle special cases first
    if key_str == "shift+tab" || key_str == "backtab" {
        return "Shift+Tab".to_string();
    }

    // Separate the modifier prefix from the final key token. A lone "+" or a trailing
    // "++" denotes the `+` key itself, which a plain split on '+' would lose.
    let (modifier_str, key_part) = if key_str == "+" {
        ("", "+")
    } else if let Some(prefix) = key_str.strip_suffix("++") {
        (prefix, "+")
    } else {
        key_str.rsplit_once('+').unwrap_or(("", key_str.as_str()))
    };
    let key_part = key_part.trim();

    let key_display = match key_part {
        "esc" | "escape" => "Esc".to_string(),
        "enter" | "return" => "Enter".to_string(),
        "tab" => "Tab".to_string(),
        "backspace" => "Backspace".to_string(),
        "up" => "Up".to_string(),
        "down" => "Down".to_string(),
        "left" => "Left".to_string(),
        "right" => "Right".to_string(),
        "home" => "Home".to_string(),
        "end" => "End".to_string(),
        "pageup" => "PgUp".to_string(),
        "pagedown" => "PgDn".to_string(),
        "delete" => "Del".to_string(),
        "insert" => "Ins".to_string(),
        "space" => "Space".to_string(),
        "f1" => "F1".to_string(),
        "f2" => "F2".to_string(),
        "f3" => "F3".to_string(),
        "f4" => "F4".to_string(),
        "f5" => "F5".to_string(),
        "f6" => "F6".to_string(),
        "f7" => "F7".to_string(),
        "f8" => "F8".to_string(),
        "f9" => "F9".to_string(),
        "f10" => "F10".to_string(),
        "f11" => "F11".to_string(),
        "f12" => "F12".to_string(),
        // A single character (counted in chars, not bytes): capitalize letters and keep
        // symbols as-is. Anything longer is shown verbatim.
        other => {
            let mut chars = other.chars();
            match (chars.next(), chars.next()) {
                (Some(ch), None) if ch.is_alphabetic() => ch.to_uppercase().collect(),
                _ => other.to_string(),
            }
        }
    };

    // Prefix the recognized modifiers in the order they were written.
    let mut display = String::new();
    for part in modifier_str.split('+') {
        display.push_str(match part.trim() {
            "ctrl" | "control" => "^",
            "shift" => "Shift+",
            "alt" => "Alt+",
            _ => "",
        });
    }
    display.push_str(&key_display);
    display
}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crossterm::event::{KeyCode, KeyEvent, KeyEventKind, KeyEventState, KeyModifiers};

    /// Two-variant context used by the keybinding tests below.
    #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
    enum TestContext {
        Default,
        Search,
    }

    impl BindingContext for TestContext {
        fn name(&self) -> &'static str {
            match *self {
                TestContext::Default => "default",
                TestContext::Search => "search",
            }
        }

        fn from_name(name: &str) -> Option<Self> {
            Self::all().iter().copied().find(|ctx| ctx.name() == name)
        }

        fn all() -> &'static [Self] {
            &[TestContext::Default, TestContext::Search]
        }
    }

    /// A plain key press with no modifiers.
    fn press(code: KeyCode) -> KeyEvent {
        KeyEvent {
            code,
            modifiers: KeyModifiers::empty(),
            kind: KeyEventKind::Press,
            state: KeyEventState::empty(),
        }
    }

    /// Owned key-string list built from string literals.
    fn key_list(keys: &[&str]) -> Vec<String> {
        keys.iter().map(|key| key.to_string()).collect()
    }

    #[test]
    fn test_parse_simple_key() {
        let parsed = parse_key_string("q").expect("\"q\" should parse");
        assert_eq!(parsed.code, KeyCode::Char('q'));
        assert_eq!(parsed.modifiers, KeyModifiers::empty());
    }

    #[test]
    fn test_parse_esc() {
        assert_eq!(parse_key_string("esc").map(|k| k.code), Some(KeyCode::Esc));
    }

    #[test]
    fn test_parse_ctrl_key() {
        let parsed = parse_key_string("ctrl+p").expect("\"ctrl+p\" should parse");
        assert_eq!(parsed.code, KeyCode::Char('p'));
        assert!(parsed.modifiers.contains(KeyModifiers::CONTROL));
    }

    #[test]
    fn test_parse_shift_tab() {
        let parsed = parse_key_string("shift+tab").expect("\"shift+tab\" should parse");
        assert_eq!(parsed.code, KeyCode::BackTab);
        assert!(parsed.modifiers.contains(KeyModifiers::SHIFT));
    }

    #[test]
    fn test_parse_backtab() {
        let parsed = parse_key_string("backtab").expect("\"backtab\" should parse");
        assert_eq!(parsed.code, KeyCode::BackTab);
        assert!(parsed.modifiers.contains(KeyModifiers::SHIFT));
    }

    #[test]
    fn test_parse_arrow_keys() {
        for (input, expected) in [("up", KeyCode::Up), ("down", KeyCode::Down)] {
            assert_eq!(parse_key_string(input).map(|k| k.code), Some(expected));
        }
    }

    #[test]
    fn test_get_command() {
        let mut bindings: Keybindings<TestContext> = Keybindings::new();
        bindings.add_global("quit", key_list(&["q"]));
        bindings.add(TestContext::Search, "clear", key_list(&["esc"]));

        // A global binding resolves in every context.
        let q = press(KeyCode::Char('q'));
        for context in [TestContext::Default, TestContext::Search] {
            assert_eq!(bindings.get_command(q, context).as_deref(), Some("quit"));
        }

        // A context binding resolves only inside its own context.
        let esc = press(KeyCode::Esc);
        assert_eq!(
            bindings.get_command(esc, TestContext::Search).as_deref(),
            Some("clear")
        );
        assert_eq!(bindings.get_command(esc, TestContext::Default), None);
    }

    #[test]
    fn test_merge() {
        let mut defaults: Keybindings<TestContext> = Keybindings::new();
        defaults.add_global("quit", key_list(&["q"]));
        defaults.add_global("help", key_list(&["?"]));

        let mut user: Keybindings<TestContext> = Keybindings::new();
        user.add_global("quit", key_list(&["x"]));

        let merged = Keybindings::merge(defaults, user);
        let global = merged.global_bindings();

        // The user's entry replaces the default one...
        assert_eq!(global.get("quit"), Some(&key_list(&["x"])));
        // ...while untouched defaults survive.
        assert_eq!(global.get("help"), Some(&key_list(&["?"])));
    }

    #[test]
    fn test_format_key_for_display() {
        let cases = [
            ("q", "Q"),
            ("ctrl+p", "^P"),
            ("esc", "Esc"),
            ("shift+tab", "Shift+Tab"),
        ];
        for (input, expected) in cases {
            assert_eq!(format_key_for_display(input), expected);
        }
    }
}