1use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
4use std::collections::HashMap;
5use std::hash::Hash;
6use std::sync::OnceLock;
7
/// A UI context (mode, screen, panel, ...) that can own its own set of keybindings.
///
/// Implementors are used as hash-map keys and are copied by value throughout
/// [`Keybindings`], hence the `Copy + Eq + Hash` requirements.
pub trait BindingContext: Clone + Copy + Eq + Hash {
    /// Returns the stable textual name of this context.
    ///
    /// The name is used as the map key when keybindings are serialized.
    fn name(&self) -> &'static str;

    /// Looks a context up by the name produced by [`BindingContext::name`].
    ///
    /// Returns `None` when `name` does not denote a known context.
    fn from_name(name: &str) -> Option<Self>;

    /// Returns every context value of the implementing type.
    fn all() -> &'static [Self];
}
32
33#[derive(Debug)]
37pub struct Keybindings<C: BindingContext> {
38 global: HashMap<String, Vec<String>>,
40 contexts: HashMap<C, HashMap<String, Vec<String>>>,
42 compiled: OnceLock<CompiledKeybindings<C>>,
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Hash)]
47struct ParsedKey {
48 code: KeyCode,
49 modifiers: KeyModifiers,
50}
51
52#[derive(Debug, Clone)]
53struct CompiledKeybindings<C: BindingContext> {
54 global: HashMap<ParsedKey, String>,
55 contexts: HashMap<C, HashMap<ParsedKey, String>>,
56}
57
58impl<C: BindingContext> Default for Keybindings<C> {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl<C: BindingContext> Clone for Keybindings<C> {
65 fn clone(&self) -> Self {
66 Self {
67 global: self.global.clone(),
68 contexts: self.contexts.clone(),
69 compiled: OnceLock::new(),
70 }
71 }
72}
73
#[cfg(feature = "serde")]
impl<C> serde::Serialize for Keybindings<C>
where
    C: BindingContext,
{
    /// Serializes the bindings as a single map: a `"global"` entry followed by
    /// one entry per context, keyed by [`BindingContext::name`].
    ///
    /// NOTE(review): a context whose name is `"global"` would emit a duplicate
    /// key; nothing here guards against that.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeMap;

        // One entry for the global scope plus one per context.
        let entry_count = self.contexts.len() + 1;
        let mut out = serializer.serialize_map(Some(entry_count))?;

        out.serialize_entry("global", &self.global)?;
        self.contexts
            .iter()
            .try_for_each(|(context, bindings)| out.serialize_entry(context.name(), bindings))?;

        out.end()
    }
}
96
#[cfg(feature = "serde")]
impl<'de, C> serde::Deserialize<'de> for Keybindings<C>
where
    C: BindingContext,
{
    /// Reads the map layout written by the `Serialize` impl.
    ///
    /// The `"global"` section fills the global scope; every other section is
    /// resolved through [`BindingContext::from_name`]. Sections whose name is
    /// not a known context are skipped without an error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let sections: HashMap<String, HashMap<String, Vec<String>>> =
            serde::Deserialize::deserialize(deserializer)?;

        let mut keybindings = Self::new();

        for (section, bindings) in sections {
            match section.as_str() {
                "global" => keybindings.global = bindings,
                name => {
                    if let Some(context) = C::from_name(name) {
                        keybindings.contexts.insert(context, bindings);
                    }
                }
            }
        }

        Ok(keybindings)
    }
}
121
122impl<C: BindingContext> Keybindings<C> {
123 pub fn new() -> Self {
125 Self {
126 global: HashMap::new(),
127 contexts: HashMap::new(),
128 compiled: OnceLock::new(),
129 }
130 }
131
132 pub fn add_global(&mut self, command: impl Into<String>, keys: Vec<String>) {
134 self.global.insert(command.into(), keys);
135 self.invalidate_cache();
136 }
137
138 pub fn add(&mut self, context: C, command: impl Into<String>, keys: Vec<String>) {
140 self.contexts
141 .entry(context)
142 .or_default()
143 .insert(command.into(), keys);
144 self.invalidate_cache();
145 }
146
147 pub fn get_context_bindings(&self, context: C) -> Option<&HashMap<String, Vec<String>>> {
149 self.contexts.get(&context)
150 }
151
152 pub fn global_bindings(&self) -> &HashMap<String, Vec<String>> {
154 &self.global
155 }
156
157 pub fn get_command(&self, key: KeyEvent, context: C) -> Option<String> {
161 self.get_command_ref(&key, context).map(str::to_string)
162 }
163
164 pub fn get_command_ref(&self, key: &KeyEvent, context: C) -> Option<&str> {
168 let parsed = ParsedKey::from_key_event(key)?;
169 let compiled = self.compiled();
170
171 if let Some(context_bindings) = compiled.contexts.get(&context) {
172 if let Some(cmd) = context_bindings.get(&parsed) {
173 return Some(cmd.as_str());
174 }
175 }
176
177 compiled.global.get(&parsed).map(String::as_str)
178 }
179
180 pub fn get_first_keybinding(&self, command: &str, context: C) -> Option<String> {
184 if let Some(context_bindings) = self.contexts.get(&context) {
185 if let Some(keys) = context_bindings.get(command) {
186 if let Some(first) = keys.first() {
187 return Some(first.clone());
188 }
189 }
190 }
191
192 self.global
193 .get(command)
194 .and_then(|keys| keys.first().cloned())
195 }
196
197 pub fn merge(mut defaults: Self, user: Self) -> Self {
199 for (key, value) in user.global {
201 defaults.global.insert(key, value);
202 }
203
204 for (context, bindings) in user.contexts {
206 let entry = defaults.contexts.entry(context).or_default();
207 for (key, value) in bindings {
208 entry.insert(key, value);
209 }
210 }
211
212 defaults.invalidate_cache();
213 defaults
214 }
215
216 fn compiled(&self) -> &CompiledKeybindings<C> {
217 self.compiled
218 .get_or_init(|| CompiledKeybindings::build(self))
219 }
220
221 fn invalidate_cache(&mut self) {
222 self.compiled = OnceLock::new();
223 }
224}
225
226impl ParsedKey {
227 fn from_key_event(key: &KeyEvent) -> Option<Self> {
228 Some(Self {
229 code: normalize_code(key.code)?,
230 modifiers: key.modifiers,
231 })
232 }
233
234 fn from_key_string(key_str: &str) -> Option<Self> {
235 let key = parse_key_string(key_str)?;
236 Self::from_key_event(&key)
237 }
238}
239
240impl<C: BindingContext> CompiledKeybindings<C> {
241 fn build(bindings: &Keybindings<C>) -> Self {
242 let mut global = HashMap::new();
243 for (command, keys) in &bindings.global {
244 insert_bindings(&mut global, command, keys);
245 }
246
247 let mut contexts = HashMap::new();
248 for (context, bindings) in &bindings.contexts {
249 let entry = contexts.entry(*context).or_insert_with(HashMap::new);
250 for (command, keys) in bindings {
251 insert_bindings(entry, command, keys);
252 }
253 }
254
255 Self { global, contexts }
256 }
257}
258
259fn insert_bindings(target: &mut HashMap<ParsedKey, String>, command: &str, keys: &[String]) {
260 for key_str in keys {
261 if let Some(parsed) = ParsedKey::from_key_string(key_str) {
262 target.entry(parsed).or_insert_with(|| command.to_string());
263 }
264 }
265}
266
267fn normalize_code(code: KeyCode) -> Option<KeyCode> {
268 match code {
269 KeyCode::Char(c) => normalize_char(c).map(KeyCode::Char),
270 other => Some(other),
271 }
272}
273
/// Lowercases `c` so that bindings match regardless of letter case.
///
/// ASCII characters are lowercased directly. Other characters use Unicode
/// lowercasing and yield `None` when the lowercase form is not exactly one
/// character.
fn normalize_char(c: char) -> Option<char> {
    if c.is_ascii() {
        return Some(c.to_ascii_lowercase());
    }

    let mut lowered = c.to_lowercase();
    match (lowered.next(), lowered.next()) {
        (Some(only), None) => Some(only),
        _ => None,
    }
}
286
287pub fn parse_key_string(key_str: &str) -> Option<KeyEvent> {
289 let key_str = key_str.trim().to_lowercase();
290
291 if key_str.is_empty() {
292 return None;
293 }
294
295 if key_str == "shift+tab" || key_str == "backtab" {
297 return Some(KeyEvent {
298 code: KeyCode::BackTab,
299 modifiers: KeyModifiers::SHIFT,
300 kind: crossterm::event::KeyEventKind::Press,
301 state: crossterm::event::KeyEventState::empty(),
302 });
303 }
304
305 let parts: Vec<&str> = key_str.split('+').collect();
307 let mut modifiers = KeyModifiers::empty();
308 let key_part = parts.last()?.trim();
309
310 if parts.len() > 1 {
311 for part in &parts[..parts.len() - 1] {
312 match part.trim() {
313 "ctrl" | "control" => modifiers |= KeyModifiers::CONTROL,
314 "shift" => modifiers |= KeyModifiers::SHIFT,
315 "alt" => modifiers |= KeyModifiers::ALT,
316 _ => {}
317 }
318 }
319 }
320
321 let code = match key_part {
323 "esc" | "escape" => KeyCode::Esc,
324 "enter" | "return" => KeyCode::Enter,
325 "tab" => KeyCode::Tab,
326 "backtab" => {
327 if modifiers.is_empty() {
328 modifiers |= KeyModifiers::SHIFT;
329 }
330 KeyCode::BackTab
331 }
332 "backspace" => KeyCode::Backspace,
333 "up" => KeyCode::Up,
334 "down" => KeyCode::Down,
335 "left" => KeyCode::Left,
336 "right" => KeyCode::Right,
337 "home" => KeyCode::Home,
338 "end" => KeyCode::End,
339 "pageup" => KeyCode::PageUp,
340 "pagedown" => KeyCode::PageDown,
341 "delete" => KeyCode::Delete,
342 "insert" => KeyCode::Insert,
343 "f1" => KeyCode::F(1),
344 "f2" => KeyCode::F(2),
345 "f3" => KeyCode::F(3),
346 "f4" => KeyCode::F(4),
347 "f5" => KeyCode::F(5),
348 "f6" => KeyCode::F(6),
349 "f7" => KeyCode::F(7),
350 "f8" => KeyCode::F(8),
351 "f9" => KeyCode::F(9),
352 "f10" => KeyCode::F(10),
353 "f11" => KeyCode::F(11),
354 "f12" => KeyCode::F(12),
355 "space" => KeyCode::Char(' '),
356 c if c.len() == 1 => {
358 let ch = c.chars().next()?;
359 KeyCode::Char(ch)
360 }
361 _ => return None,
362 };
363
364 Some(KeyEvent {
365 code,
366 modifiers,
367 kind: crossterm::event::KeyEventKind::Press,
368 state: crossterm::event::KeyEventState::empty(),
369 })
370}
371
/// Formats a key string (the same syntax `parse_key_string` accepts) for
/// display in help text, e.g. `"ctrl+p"` -> `"^P"`, `"shift+tab"` -> `"Shift+Tab"`.
///
/// The input is trimmed and lowercased, and whitespace around each
/// `+`-separated part is ignored. Modifiers render as `^` (ctrl), `Shift+` and
/// `Alt+` in the order written; unrecognized modifier names are dropped.
/// Named keys get a short capitalized label, single letters are uppercased,
/// and anything unrecognized is shown as written (lowercased).
///
/// A trailing `+` denotes the literal plus key (`"ctrl++"` -> `"^+"`), and
/// `backtab` is always shown as `Shift+Tab`.
pub fn format_key_for_display(key_str: &str) -> String {
    let normalized = key_str.trim().to_lowercase();

    // Split into "modifiers" and "key"; a trailing '+' is the plus key itself.
    let (modifier_part, key_part): (&str, &str) = match normalized.strip_suffix('+') {
        Some(rest) => {
            let rest = rest.trim_end();
            match rest.strip_suffix('+') {
                Some(mods) => (mods, "+"),
                None if rest.is_empty() => ("", "+"),
                // Dangling separator ("ctrl+"): show the modifiers only.
                None => (rest, ""),
            }
        }
        None => match normalized.rsplit_once('+') {
            Some((mods, key)) => (mods, key.trim()),
            None => ("", normalized.as_str()),
        },
    };

    let mut display = String::new();
    let mut has_shift = false;
    for name in modifier_part.split('+') {
        match name.trim() {
            "ctrl" | "control" => display.push('^'),
            "shift" => {
                has_shift = true;
                display.push_str("Shift+");
            }
            "alt" => display.push_str("Alt+"),
            _ => {}
        }
    }

    let label = match key_part {
        "esc" | "escape" => "Esc",
        "enter" | "return" => "Enter",
        "tab" => "Tab",
        "backtab" => {
            // BackTab is Shift+Tab; avoid printing "Shift+" twice.
            if !has_shift {
                display.push_str("Shift+");
            }
            "Tab"
        }
        "backspace" => "Backspace",
        "up" => "Up",
        "down" => "Down",
        "left" => "Left",
        "right" => "Right",
        "home" => "Home",
        "end" => "End",
        "pageup" => "PgUp",
        "pagedown" => "PgDn",
        "delete" => "Del",
        "insert" => "Ins",
        "space" => "Space",
        other => {
            display.push_str(&display_unnamed_key(other));
            return display;
        }
    };

    display.push_str(label);
    display
}

/// Renders a key that has no dedicated label: `f1`..`f12`, a single
/// character, or (as a fallback) the text exactly as given.
fn display_unnamed_key(key: &str) -> String {
    if let Some(digits) = key.strip_prefix('f') {
        let well_formed = !digits.starts_with('0') && digits.bytes().all(|b| b.is_ascii_digit());
        if well_formed && matches!(digits.parse::<u8>(), Ok(1..=12)) {
            return format!("F{digits}");
        }
    }

    // Count characters, not bytes, so non-ASCII letters are uppercased too.
    let mut chars = key.chars();
    if let (Some(ch), None) = (chars.next(), chars.next()) {
        if ch.is_alphabetic() {
            // Only use the uppercase form when it is a single character
            // (e.g. keep U+00DF as is instead of expanding it to "SS").
            let mut upper = ch.to_uppercase();
            if let (Some(up), None) = (upper.next(), upper.next()) {
                return up.to_string();
            }
        }
        return ch.to_string();
    }

    key.to_string()
}
446
#[cfg(test)]
mod tests {
    use super::*;
    use crossterm::event::{KeyCode, KeyModifiers};

    /// Minimal two-variant context used to exercise context-specific lookups.
    #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
    enum TestContext {
        Default,
        Search,
    }

    impl BindingContext for TestContext {
        fn name(&self) -> &'static str {
            match *self {
                Self::Default => "default",
                Self::Search => "search",
            }
        }

        fn from_name(name: &str) -> Option<Self> {
            Self::all()
                .iter()
                .copied()
                .find(|context| context.name() == name)
        }

        fn all() -> &'static [Self] {
            &[Self::Default, Self::Search]
        }
    }

    /// Builds an unmodified key-press event for `code`.
    fn press(code: KeyCode) -> KeyEvent {
        KeyEvent {
            code,
            modifiers: KeyModifiers::empty(),
            kind: crossterm::event::KeyEventKind::Press,
            state: crossterm::event::KeyEventState::empty(),
        }
    }

    /// Converts string literals into the owned key list the API expects.
    fn keys(specs: &[&str]) -> Vec<String> {
        specs.iter().map(|spec| spec.to_string()).collect()
    }

    #[test]
    fn test_parse_simple_key() {
        let event = parse_key_string("q").unwrap();
        assert_eq!(event.code, KeyCode::Char('q'));
        assert_eq!(event.modifiers, KeyModifiers::empty());
    }

    #[test]
    fn test_parse_esc() {
        let event = parse_key_string("esc").unwrap();
        assert_eq!(event.code, KeyCode::Esc);
    }

    #[test]
    fn test_parse_ctrl_key() {
        let event = parse_key_string("ctrl+p").unwrap();
        assert_eq!(event.code, KeyCode::Char('p'));
        assert!(event.modifiers.contains(KeyModifiers::CONTROL));
    }

    #[test]
    fn test_parse_shift_tab() {
        let event = parse_key_string("shift+tab").unwrap();
        assert_eq!(event.code, KeyCode::BackTab);
        assert!(event.modifiers.contains(KeyModifiers::SHIFT));
    }

    #[test]
    fn test_parse_backtab() {
        let event = parse_key_string("backtab").unwrap();
        assert_eq!(event.code, KeyCode::BackTab);
        assert!(event.modifiers.contains(KeyModifiers::SHIFT));
    }

    #[test]
    fn test_parse_arrow_keys() {
        for (spec, expected) in [("up", KeyCode::Up), ("down", KeyCode::Down)] {
            let event = parse_key_string(spec).unwrap();
            assert_eq!(event.code, expected);
        }
    }

    #[test]
    fn test_get_command() {
        let mut bindings: Keybindings<TestContext> = Keybindings::new();
        bindings.add_global("quit", keys(&["q"]));
        bindings.add(TestContext::Search, "clear", keys(&["esc"]));

        // A global binding resolves in every context.
        let key_q = press(KeyCode::Char('q'));
        assert_eq!(
            bindings.get_command(key_q, TestContext::Default),
            Some("quit".to_string())
        );
        assert_eq!(
            bindings.get_command(key_q, TestContext::Search),
            Some("quit".to_string())
        );

        // A context binding resolves only in its own context.
        let key_esc = press(KeyCode::Esc);
        assert_eq!(
            bindings.get_command(key_esc, TestContext::Search),
            Some("clear".to_string())
        );
        assert_eq!(bindings.get_command(key_esc, TestContext::Default), None);
    }

    #[test]
    fn test_merge() {
        let mut defaults: Keybindings<TestContext> = Keybindings::new();
        defaults.add_global("quit", keys(&["q"]));
        defaults.add_global("help", keys(&["?"]));

        let mut user: Keybindings<TestContext> = Keybindings::new();
        user.add_global("quit", keys(&["x"]));

        let merged = Keybindings::merge(defaults, user);

        // The user's binding replaces the default one...
        assert_eq!(
            merged.global_bindings().get("quit"),
            Some(&vec!["x".to_string()])
        );

        // ...while untouched defaults survive.
        assert_eq!(
            merged.global_bindings().get("help"),
            Some(&vec!["?".to_string()])
        );
    }

    #[test]
    fn test_format_key_for_display() {
        let cases = [
            ("q", "Q"),
            ("ctrl+p", "^P"),
            ("esc", "Esc"),
            ("shift+tab", "Shift+Tab"),
        ];
        for (input, expected) in cases {
            assert_eq!(format_key_for_display(input), expected);
        }
    }
}