Skip to main content

tui_dispatch_core/
keybindings.rs

1//! Keybindings system with context-aware key parsing and lookup
2
3use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
4use std::collections::HashMap;
5use std::hash::Hash;
6use std::sync::OnceLock;
7
8/// Trait for user-defined keybinding contexts
9///
10/// Implement this trait for your own context enum, or use `#[derive(BindingContext)]`
11/// from `tui-dispatch-macros` to auto-generate the implementation.
12///
13/// # Example
14/// ```ignore
15/// #[derive(BindingContext, Clone, Copy, PartialEq, Eq, Hash)]
16/// pub enum MyContext {
17///     Default,
18///     Search,
19///     Modal,
20/// }
21/// ```
22pub trait BindingContext: Clone + Copy + Eq + Hash {
23    /// Get the context name as a string (for config file lookup)
24    fn name(&self) -> &'static str;
25
26    /// Parse a context from its name
27    fn from_name(name: &str) -> Option<Self>;
28
29    /// Get all possible context values (for iteration/config loading)
30    fn all() -> &'static [Self];
31}
32
33/// Keybindings configuration with context support
34///
35/// Generic over the context type `C` which must implement `BindingContext`.
36#[derive(Debug)]
37pub struct Keybindings<C: BindingContext> {
38    /// Global keybindings - checked as fallback for all contexts
39    global: HashMap<String, Vec<String>>,
40    /// Context-specific keybindings
41    contexts: HashMap<C, HashMap<String, Vec<String>>>,
42    /// Cached parsed bindings for fast lookup
43    compiled: OnceLock<CompiledKeybindings<C>>,
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Hash)]
47struct ParsedKey {
48    code: KeyCode,
49    modifiers: KeyModifiers,
50}
51
52#[derive(Debug, Clone)]
53struct CompiledKeybindings<C: BindingContext> {
54    global: HashMap<ParsedKey, String>,
55    contexts: HashMap<C, HashMap<ParsedKey, String>>,
56}
57
58impl<C: BindingContext> Default for Keybindings<C> {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl<C: BindingContext> Clone for Keybindings<C> {
65    fn clone(&self) -> Self {
66        Self {
67            global: self.global.clone(),
68            contexts: self.contexts.clone(),
69            compiled: OnceLock::new(),
70        }
71    }
72}
73
#[cfg(feature = "serde")]
impl<C: BindingContext> serde::Serialize for Keybindings<C> {
    /// Serializes as a map of scope name → (command name → key strings).
    ///
    /// The `"global"` scope is written first, followed by each context under its
    /// [`BindingContext::name`]. Contexts are written sorted by name, and commands
    /// sorted by command name within each scope. Both live in `HashMap`s, whose
    /// iteration order is randomized per process. Without sorting, the same bindings
    /// would serialize differently from run to run, making generated config files
    /// churn in diffs.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeMap;
        use std::collections::BTreeMap;

        /// Borrowed view of a command table ordered by command name.
        fn sorted(bindings: &HashMap<String, Vec<String>>) -> BTreeMap<&str, &[String]> {
            bindings
                .iter()
                .map(|(command, keys)| (command.as_str(), keys.as_slice()))
                .collect()
        }

        let mut contexts: Vec<(&'static str, &HashMap<String, Vec<String>>)> = self
            .contexts
            .iter()
            .map(|(context, bindings)| (context.name(), bindings))
            .collect();
        contexts.sort_unstable_by_key(|&(name, _)| name);

        // One entry for "global" plus one per context.
        let mut map = serializer.serialize_map(Some(1 + contexts.len()))?;
        map.serialize_entry("global", &sorted(&self.global))?;
        for (name, bindings) in contexts {
            map.serialize_entry(name, &sorted(bindings))?;
        }
        map.end()
    }
}
96
#[cfg(feature = "serde")]
impl<'de, C: BindingContext> serde::Deserialize<'de> for Keybindings<C> {
    /// Reads a map of scope name → (command name → key strings).
    ///
    /// The `"global"` entry becomes the global scope. Every other entry is resolved
    /// through [`BindingContext::from_name`]. Entries whose name matches no context
    /// are dropped without error, so configs that mention contexts this build does
    /// not know about still load.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let mut raw: HashMap<String, HashMap<String, Vec<String>>> =
            serde::Deserialize::deserialize(deserializer)?;

        let mut keybindings = Keybindings::new();
        if let Some(global) = raw.remove("global") {
            keybindings.global = global;
        }
        keybindings.contexts = raw
            .into_iter()
            .filter_map(|(name, bindings)| C::from_name(&name).map(|context| (context, bindings)))
            .collect();

        Ok(keybindings)
    }
}
121
122impl<C: BindingContext> Keybindings<C> {
123    /// Create a new empty keybindings configuration
124    pub fn new() -> Self {
125        Self {
126            global: HashMap::new(),
127            contexts: HashMap::new(),
128            compiled: OnceLock::new(),
129        }
130    }
131
132    /// Add a global keybinding
133    pub fn add_global(&mut self, command: impl Into<String>, keys: Vec<String>) {
134        self.global.insert(command.into(), keys);
135        self.invalidate_cache();
136    }
137
138    /// Add a context-specific keybinding
139    pub fn add(&mut self, context: C, command: impl Into<String>, keys: Vec<String>) {
140        self.contexts
141            .entry(context)
142            .or_default()
143            .insert(command.into(), keys);
144        self.invalidate_cache();
145    }
146
147    /// Get bindings for a specific context
148    pub fn get_context_bindings(&self, context: C) -> Option<&HashMap<String, Vec<String>>> {
149        self.contexts.get(&context)
150    }
151
152    /// Get global bindings
153    pub fn global_bindings(&self) -> &HashMap<String, Vec<String>> {
154        &self.global
155    }
156
157    /// Get command name for a key event in the given context
158    ///
159    /// First checks context-specific bindings, then falls back to global
160    pub fn get_command(&self, key: KeyEvent, context: C) -> Option<String> {
161        self.get_command_ref(&key, context).map(str::to_string)
162    }
163
164    /// Get command name for a key event by reference
165    ///
166    /// Returns a borrowed `&str` to avoid allocations on hot paths.
167    pub fn get_command_ref(&self, key: &KeyEvent, context: C) -> Option<&str> {
168        let parsed = ParsedKey::from_key_event(key)?;
169        let compiled = self.compiled();
170
171        if let Some(context_bindings) = compiled.contexts.get(&context) {
172            if let Some(cmd) = context_bindings.get(&parsed) {
173                return Some(cmd.as_str());
174            }
175        }
176
177        compiled.global.get(&parsed).map(String::as_str)
178    }
179
180    /// Get the first keybinding string for a command in the given context
181    ///
182    /// First checks context-specific bindings, then falls back to global
183    pub fn get_first_keybinding(&self, command: &str, context: C) -> Option<String> {
184        if let Some(context_bindings) = self.contexts.get(&context) {
185            if let Some(keys) = context_bindings.get(command) {
186                if let Some(first) = keys.first() {
187                    return Some(first.clone());
188                }
189            }
190        }
191
192        self.global
193            .get(command)
194            .and_then(|keys| keys.first().cloned())
195    }
196
197    /// Merge user config onto defaults - user config overrides defaults
198    pub fn merge(mut defaults: Self, user: Self) -> Self {
199        // Merge global
200        for (key, value) in user.global {
201            defaults.global.insert(key, value);
202        }
203
204        // Merge contexts
205        for (context, bindings) in user.contexts {
206            let entry = defaults.contexts.entry(context).or_default();
207            for (key, value) in bindings {
208                entry.insert(key, value);
209            }
210        }
211
212        defaults.invalidate_cache();
213        defaults
214    }
215
216    fn compiled(&self) -> &CompiledKeybindings<C> {
217        self.compiled
218            .get_or_init(|| CompiledKeybindings::build(self))
219    }
220
221    fn invalidate_cache(&mut self) {
222        self.compiled = OnceLock::new();
223    }
224}
225
226impl ParsedKey {
227    fn from_key_event(key: &KeyEvent) -> Option<Self> {
228        Some(Self {
229            code: normalize_code(key.code)?,
230            modifiers: key.modifiers,
231        })
232    }
233
234    fn from_key_string(key_str: &str) -> Option<Self> {
235        let key = parse_key_string(key_str)?;
236        Self::from_key_event(&key)
237    }
238}
239
240impl<C: BindingContext> CompiledKeybindings<C> {
241    fn build(bindings: &Keybindings<C>) -> Self {
242        let mut global = HashMap::new();
243        for (command, keys) in &bindings.global {
244            insert_bindings(&mut global, command, keys);
245        }
246
247        let mut contexts = HashMap::new();
248        for (context, bindings) in &bindings.contexts {
249            let entry = contexts.entry(*context).or_insert_with(HashMap::new);
250            for (command, keys) in bindings {
251                insert_bindings(entry, command, keys);
252            }
253        }
254
255        Self { global, contexts }
256    }
257}
258
259fn insert_bindings(target: &mut HashMap<ParsedKey, String>, command: &str, keys: &[String]) {
260    for key_str in keys {
261        if let Some(parsed) = ParsedKey::from_key_string(key_str) {
262            target.entry(parsed).or_insert_with(|| command.to_string());
263        }
264    }
265}
266
267fn normalize_code(code: KeyCode) -> Option<KeyCode> {
268    match code {
269        KeyCode::Char(c) => normalize_char(c).map(KeyCode::Char),
270        other => Some(other),
271    }
272}
273
/// Lower-cases `c` for lookup purposes.
///
/// ASCII uses the fast ASCII mapping. Non-ASCII characters use full Unicode
/// lowercasing, and return `None` when that expands to more than one character
/// (e.g. `'İ'` → `"i\u{307}"`), since such a key cannot be represented as a single
/// `char`.
fn normalize_char(c: char) -> Option<char> {
    if !c.is_ascii() {
        let mut lowered = c.to_lowercase();
        return match (lowered.next(), lowered.next()) {
            (Some(single), None) => Some(single),
            _ => None,
        };
    }
    Some(c.to_ascii_lowercase())
}
286
287/// Parse a key string like "q", "esc", "ctrl+p", "shift+tab" into a KeyEvent
288pub fn parse_key_string(key_str: &str) -> Option<KeyEvent> {
289    let key_str = key_str.trim().to_lowercase();
290
291    if key_str.is_empty() {
292        return None;
293    }
294
295    // Special case: shift+tab should be BackTab
296    if key_str == "shift+tab" || key_str == "backtab" {
297        return Some(KeyEvent {
298            code: KeyCode::BackTab,
299            modifiers: KeyModifiers::SHIFT,
300            kind: crossterm::event::KeyEventKind::Press,
301            state: crossterm::event::KeyEventState::empty(),
302        });
303    }
304
305    // Check for modifiers
306    let parts: Vec<&str> = key_str.split('+').collect();
307    let mut modifiers = KeyModifiers::empty();
308    let key_part = parts.last()?.trim();
309
310    if parts.len() > 1 {
311        for part in &parts[..parts.len() - 1] {
312            match part.trim() {
313                "ctrl" | "control" => modifiers |= KeyModifiers::CONTROL,
314                "shift" => modifiers |= KeyModifiers::SHIFT,
315                "alt" => modifiers |= KeyModifiers::ALT,
316                _ => {}
317            }
318        }
319    }
320
321    // Parse the key code
322    let code = match key_part {
323        "esc" | "escape" => KeyCode::Esc,
324        "enter" | "return" => KeyCode::Enter,
325        "tab" => KeyCode::Tab,
326        "backtab" => {
327            if modifiers.is_empty() {
328                modifiers |= KeyModifiers::SHIFT;
329            }
330            KeyCode::BackTab
331        }
332        "backspace" => KeyCode::Backspace,
333        "up" => KeyCode::Up,
334        "down" => KeyCode::Down,
335        "left" => KeyCode::Left,
336        "right" => KeyCode::Right,
337        "home" => KeyCode::Home,
338        "end" => KeyCode::End,
339        "pageup" => KeyCode::PageUp,
340        "pagedown" => KeyCode::PageDown,
341        "delete" => KeyCode::Delete,
342        "insert" => KeyCode::Insert,
343        "f1" => KeyCode::F(1),
344        "f2" => KeyCode::F(2),
345        "f3" => KeyCode::F(3),
346        "f4" => KeyCode::F(4),
347        "f5" => KeyCode::F(5),
348        "f6" => KeyCode::F(6),
349        "f7" => KeyCode::F(7),
350        "f8" => KeyCode::F(8),
351        "f9" => KeyCode::F(9),
352        "f10" => KeyCode::F(10),
353        "f11" => KeyCode::F(11),
354        "f12" => KeyCode::F(12),
355        "space" => KeyCode::Char(' '),
356        // Single character
357        c if c.len() == 1 => {
358            let ch = c.chars().next()?;
359            KeyCode::Char(ch)
360        }
361        _ => return None,
362    };
363
364    Some(KeyEvent {
365        code,
366        modifiers,
367        kind: crossterm::event::KeyEventKind::Press,
368        state: crossterm::event::KeyEventState::empty(),
369    })
370}
371
/// Format a key string for display (e.g., "ctrl+p" -> "^P", "q" -> "Q", "tab" -> "Tab")
///
/// The input follows the same grammar as [`parse_key_string`]: case-insensitive,
/// `+`-separated modifiers with optional surrounding whitespace, and a trailing `+`
/// standing for the plus key (`"ctrl++"` -> `"^+"`).
///
/// - Control renders as a `^` prefix, Shift as `Shift+`, Alt as `Alt+`, in the order
///   written. Unrecognized modifiers are omitted.
/// - `shift+tab` and `backtab` render as `Shift+Tab`.
/// - Named keys get short labels (`Esc`, `PgUp`, `Del`, ...), and function keys render
///   as `F1`, `F2`, ....
/// - A single alphabetic character is upper-cased, non-ASCII letters included
///   (`"é"` -> `"É"`).
/// - Anything else is returned lower-cased but otherwise unchanged.
pub fn format_key_for_display(key_str: &str) -> String {
    /// Splits `"mods+key"` into `("mods", "key")`; a trailing `+` is the plus key.
    fn split_key_string(s: &str) -> (&str, &str) {
        if let Some(rest) = s.strip_suffix('+') {
            let rest = rest.trim_end();
            if rest.is_empty() {
                return ("", "+");
            }
            if let Some(mods) = rest.strip_suffix('+') {
                return (mods, "+");
            }
        }
        match s.rsplit_once('+') {
            Some((mods, key)) => (mods, key.trim()),
            None => ("", s),
        }
    }

    /// `"f<n>"` with `n` in `1..=255` (no leading zeros) → `Some(n)`.
    fn function_key_number(s: &str) -> Option<u8> {
        let digits = s.strip_prefix('f')?;
        let well_formed = !digits.is_empty()
            && !digits.starts_with('0')
            && digits.bytes().all(|b| b.is_ascii_digit());
        if well_formed {
            digits.parse().ok()
        } else {
            None
        }
    }

    /// Display label for the (lower-cased, trimmed) key part.
    fn key_label(key: &str) -> String {
        let named = match key {
            "esc" | "escape" => Some("Esc"),
            "enter" | "return" => Some("Enter"),
            "tab" => Some("Tab"),
            "backspace" => Some("Backspace"),
            "up" => Some("Up"),
            "down" => Some("Down"),
            "left" => Some("Left"),
            "right" => Some("Right"),
            "home" => Some("Home"),
            "end" => Some("End"),
            "pageup" => Some("PgUp"),
            "pagedown" => Some("PgDn"),
            "delete" => Some("Del"),
            "insert" => Some("Ins"),
            "space" => Some("Space"),
            _ => None,
        };
        if let Some(name) = named {
            return name.to_string();
        }
        if let Some(n) = function_key_number(key) {
            return format!("F{n}");
        }
        // Count chars, not bytes, so non-ASCII single letters are capitalized too.
        let mut chars = key.chars();
        match (chars.next(), chars.next()) {
            (Some(c), None) if c.is_alphabetic() => c.to_uppercase().collect(),
            _ => key.to_string(),
        }
    }

    let key_str = key_str.trim().to_lowercase();

    // Handle special cases first
    if key_str == "shift+tab" || key_str == "backtab" {
        return "Shift+Tab".to_string();
    }

    let (modifier_str, key_part) = split_key_string(&key_str);

    let mut display = String::new();
    for part in modifier_str.split('+') {
        match part.trim() {
            "ctrl" | "control" => display.push('^'),
            "shift" => display.push_str("Shift+"),
            "alt" => display.push_str("Alt+"),
            _ => {}
        }
    }
    display.push_str(&key_label(key_part));
    display
}
446
#[cfg(test)]
mod tests {
    use super::*;
    use crossterm::event::{KeyCode, KeyModifiers};

    /// Minimal two-variant context for exercising context-specific lookups.
    #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
    enum TestContext {
        Default,
        Search,
    }

    impl BindingContext for TestContext {
        fn name(&self) -> &'static str {
            match self {
                Self::Default => "default",
                Self::Search => "search",
            }
        }

        fn from_name(name: &str) -> Option<Self> {
            Self::all().iter().copied().find(|ctx| ctx.name() == name)
        }

        fn all() -> &'static [Self] {
            &[Self::Default, Self::Search]
        }
    }

    /// Builds an unmodified key-press event for `code`.
    fn press(code: KeyCode) -> KeyEvent {
        KeyEvent {
            code,
            modifiers: KeyModifiers::empty(),
            kind: crossterm::event::KeyEventKind::Press,
            state: crossterm::event::KeyEventState::empty(),
        }
    }

    /// Builds an owned key list from string literals.
    fn keys(list: &[&str]) -> Vec<String> {
        list.iter().copied().map(String::from).collect()
    }

    #[test]
    fn test_parse_simple_key() {
        let parsed = parse_key_string("q").expect("`q` should parse");
        assert_eq!(parsed.code, KeyCode::Char('q'));
        assert_eq!(parsed.modifiers, KeyModifiers::empty());
    }

    #[test]
    fn test_parse_esc() {
        assert_eq!(parse_key_string("esc").map(|k| k.code), Some(KeyCode::Esc));
    }

    #[test]
    fn test_parse_ctrl_key() {
        let parsed = parse_key_string("ctrl+p").expect("`ctrl+p` should parse");
        assert_eq!(parsed.code, KeyCode::Char('p'));
        assert!(parsed.modifiers.contains(KeyModifiers::CONTROL));
    }

    #[test]
    fn test_parse_shift_tab() {
        let parsed = parse_key_string("shift+tab").expect("`shift+tab` should parse");
        assert_eq!(parsed.code, KeyCode::BackTab);
        assert!(parsed.modifiers.contains(KeyModifiers::SHIFT));
    }

    #[test]
    fn test_parse_backtab() {
        let parsed = parse_key_string("backtab").expect("`backtab` should parse");
        assert_eq!(parsed.code, KeyCode::BackTab);
        assert!(parsed.modifiers.contains(KeyModifiers::SHIFT));
    }

    #[test]
    fn test_parse_arrow_keys() {
        for (input, expected) in [("up", KeyCode::Up), ("down", KeyCode::Down)] {
            assert_eq!(parse_key_string(input).map(|k| k.code), Some(expected));
        }
    }

    #[test]
    fn test_get_command() {
        let mut bindings: Keybindings<TestContext> = Keybindings::new();
        bindings.add_global("quit", keys(&["q"]));
        bindings.add(TestContext::Search, "clear", keys(&["esc"]));

        // Global bindings apply in every context.
        let key_q = press(KeyCode::Char('q'));
        for context in [TestContext::Default, TestContext::Search] {
            assert_eq!(bindings.get_command(key_q, context), Some("quit".to_string()));
        }

        // Context bindings apply only in their own context.
        let key_esc = press(KeyCode::Esc);
        assert_eq!(
            bindings.get_command(key_esc, TestContext::Search),
            Some("clear".to_string())
        );
        assert_eq!(bindings.get_command(key_esc, TestContext::Default), None);
    }

    #[test]
    fn test_merge() {
        let mut defaults: Keybindings<TestContext> = Keybindings::new();
        defaults.add_global("quit", keys(&["q"]));
        defaults.add_global("help", keys(&["?"]));

        let mut user: Keybindings<TestContext> = Keybindings::new();
        user.add_global("quit", keys(&["x"])); // overrides the default

        let merged = Keybindings::merge(defaults, user);
        let global = merged.global_bindings();

        // The user's override replaces the default key list...
        assert_eq!(global.get("quit"), Some(&keys(&["x"])));
        // ...while untouched defaults survive.
        assert_eq!(global.get("help"), Some(&keys(&["?"])));
    }

    #[test]
    fn test_format_key_for_display() {
        let cases = [
            ("q", "Q"),
            ("ctrl+p", "^P"),
            ("esc", "Esc"),
            ("shift+tab", "Shift+Tab"),
        ];
        for (input, expected) in cases {
            assert_eq!(format_key_for_display(input), expected);
        }
    }
}