1use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::collections::HashMap;
6use std::hash::Hash;
7
/// A UI context that can own its own set of key bindings.
///
/// Implementors are small, copyable, hashable values; [`Keybindings`] uses
/// them as map keys and uses [`BindingContext::name`] /
/// [`BindingContext::from_name`] to convert to and from the section names of
/// its serialized form.
pub trait BindingContext: Clone + Copy + Eq + Hash {
    /// Returns the textual name of this context.
    ///
    /// This is the map key written for the context when a [`Keybindings`]
    /// value is serialized.
    fn name(&self) -> &'static str;

    /// Looks up a context by its textual name.
    ///
    /// Returns `None` when `name` does not denote a context; deserialization
    /// of [`Keybindings`] skips sections for which this returns `None`.
    fn from_name(name: &str) -> Option<Self>;

    /// Returns a static slice of context values.
    ///
    /// NOTE(review): presumably this lists every context the implementor
    /// defines; nothing in this file calls it, so confirm against callers.
    fn all() -> &'static [Self];
}
32
33#[derive(Debug, Clone)]
37pub struct Keybindings<C: BindingContext> {
38 global: HashMap<String, Vec<String>>,
40 contexts: HashMap<C, HashMap<String, Vec<String>>>,
42}
43
44impl<C: BindingContext> Default for Keybindings<C> {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl<C: BindingContext> Serialize for Keybindings<C> {
51 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
52 where
53 S: Serializer,
54 {
55 use serde::ser::SerializeMap;
56
57 let mut map = serializer.serialize_map(Some(1 + self.contexts.len()))?;
59
60 map.serialize_entry("global", &self.global)?;
62
63 for (context, bindings) in &self.contexts {
65 map.serialize_entry(context.name(), bindings)?;
66 }
67
68 map.end()
69 }
70}
71
72impl<'de, C: BindingContext> Deserialize<'de> for Keybindings<C> {
73 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
74 where
75 D: Deserializer<'de>,
76 {
77 let raw: HashMap<String, HashMap<String, Vec<String>>> =
79 HashMap::deserialize(deserializer)?;
80
81 let mut keybindings = Keybindings::new();
82
83 for (context_name, bindings) in raw {
84 if context_name == "global" {
85 keybindings.global = bindings;
86 } else if let Some(context) = C::from_name(&context_name) {
87 keybindings.contexts.insert(context, bindings);
88 }
89 }
91
92 Ok(keybindings)
93 }
94}
95
96impl<C: BindingContext> Keybindings<C> {
97 pub fn new() -> Self {
99 Self {
100 global: HashMap::new(),
101 contexts: HashMap::new(),
102 }
103 }
104
105 pub fn add_global(&mut self, command: impl Into<String>, keys: Vec<String>) {
107 self.global.insert(command.into(), keys);
108 }
109
110 pub fn add(&mut self, context: C, command: impl Into<String>, keys: Vec<String>) {
112 self.contexts
113 .entry(context)
114 .or_default()
115 .insert(command.into(), keys);
116 }
117
118 pub fn get_context_bindings(&self, context: C) -> Option<&HashMap<String, Vec<String>>> {
120 self.contexts.get(&context)
121 }
122
123 pub fn global_bindings(&self) -> &HashMap<String, Vec<String>> {
125 &self.global
126 }
127
128 pub fn get_command(&self, key: KeyEvent, context: C) -> Option<String> {
132 if let Some(context_bindings) = self.contexts.get(&context) {
134 if let Some(cmd) = self.match_key_in_bindings(key, context_bindings) {
135 return Some(cmd);
136 }
137 }
138
139 self.match_key_in_bindings(key, &self.global)
141 }
142
143 fn match_key_in_bindings(
145 &self,
146 key: KeyEvent,
147 bindings: &HashMap<String, Vec<String>>,
148 ) -> Option<String> {
149 for (command, keys) in bindings {
150 for key_str in keys {
151 if let Some(parsed_key) = parse_key_string(key_str) {
152 let codes_match = match (&parsed_key.code, &key.code) {
155 (KeyCode::Char(c1), KeyCode::Char(c2)) => {
156 c1.to_lowercase().to_string() == c2.to_lowercase().to_string()
157 }
158 _ => parsed_key.code == key.code,
159 };
160
161 if codes_match && parsed_key.modifiers == key.modifiers {
162 return Some(command.clone());
163 }
164 }
165 }
166 }
167 None
168 }
169
170 pub fn get_first_keybinding(&self, command: &str, context: C) -> Option<String> {
174 if let Some(context_bindings) = self.contexts.get(&context) {
175 if let Some(keys) = context_bindings.get(command) {
176 if let Some(first) = keys.first() {
177 return Some(first.clone());
178 }
179 }
180 }
181
182 self.global
183 .get(command)
184 .and_then(|keys| keys.first().cloned())
185 }
186
187 pub fn merge(mut defaults: Self, user: Self) -> Self {
189 for (key, value) in user.global {
191 defaults.global.insert(key, value);
192 }
193
194 for (context, bindings) in user.contexts {
196 let entry = defaults.contexts.entry(context).or_default();
197 for (key, value) in bindings {
198 entry.insert(key, value);
199 }
200 }
201
202 defaults
203 }
204}
205
206pub fn parse_key_string(key_str: &str) -> Option<KeyEvent> {
208 let key_str = key_str.trim().to_lowercase();
209
210 if key_str.is_empty() {
211 return None;
212 }
213
214 if key_str == "shift+tab" || key_str == "backtab" {
216 return Some(KeyEvent {
217 code: KeyCode::BackTab,
218 modifiers: KeyModifiers::SHIFT,
219 kind: crossterm::event::KeyEventKind::Press,
220 state: crossterm::event::KeyEventState::empty(),
221 });
222 }
223
224 let parts: Vec<&str> = key_str.split('+').collect();
226 let mut modifiers = KeyModifiers::empty();
227 let key_part = parts.last()?.trim();
228
229 if parts.len() > 1 {
230 for part in &parts[..parts.len() - 1] {
231 match part.trim() {
232 "ctrl" | "control" => modifiers |= KeyModifiers::CONTROL,
233 "shift" => modifiers |= KeyModifiers::SHIFT,
234 "alt" => modifiers |= KeyModifiers::ALT,
235 _ => {}
236 }
237 }
238 }
239
240 let code = match key_part {
242 "esc" | "escape" => KeyCode::Esc,
243 "enter" | "return" => KeyCode::Enter,
244 "tab" => KeyCode::Tab,
245 "backtab" => {
246 if modifiers.is_empty() {
247 modifiers |= KeyModifiers::SHIFT;
248 }
249 KeyCode::BackTab
250 }
251 "backspace" => KeyCode::Backspace,
252 "up" => KeyCode::Up,
253 "down" => KeyCode::Down,
254 "left" => KeyCode::Left,
255 "right" => KeyCode::Right,
256 "home" => KeyCode::Home,
257 "end" => KeyCode::End,
258 "pageup" => KeyCode::PageUp,
259 "pagedown" => KeyCode::PageDown,
260 "delete" => KeyCode::Delete,
261 "insert" => KeyCode::Insert,
262 "f1" => KeyCode::F(1),
263 "f2" => KeyCode::F(2),
264 "f3" => KeyCode::F(3),
265 "f4" => KeyCode::F(4),
266 "f5" => KeyCode::F(5),
267 "f6" => KeyCode::F(6),
268 "f7" => KeyCode::F(7),
269 "f8" => KeyCode::F(8),
270 "f9" => KeyCode::F(9),
271 "f10" => KeyCode::F(10),
272 "f11" => KeyCode::F(11),
273 "f12" => KeyCode::F(12),
274 "space" => KeyCode::Char(' '),
275 c if c.len() == 1 => {
277 let ch = c.chars().next()?;
278 KeyCode::Char(ch)
279 }
280 _ => return None,
281 };
282
283 Some(KeyEvent {
284 code,
285 modifiers,
286 kind: crossterm::event::KeyEventKind::Press,
287 state: crossterm::event::KeyEventState::empty(),
288 })
289}
290
/// Formats a binding string such as `"ctrl+p"` as a short label for display.
///
/// The input is handled case-insensitively, and whitespace around the whole
/// string and around each `+`-separated part is ignored. Modifiers are
/// rendered in the order given: `ctrl`/`control` as `^`, `shift` as `Shift+`,
/// `alt` as `Alt+`; unrecognised modifier names are dropped. Named keys get a
/// capitalised or abbreviated label (`Esc`, `PgUp`, `Del`, `F5`, ...), single
/// alphabetic characters are upper-cased, and anything else is shown as
/// written (lower-cased).
///
/// Special cases:
/// - A trailing `+` denotes the plus key (`"ctrl++"` -> `"^+"`).
/// - `backtab` is shown as `Shift+Tab`, without repeating `Shift+` when a
///   shift modifier was already given.
pub fn format_key_for_display(key_str: &str) -> String {
    let spec = key_str.trim().to_lowercase();

    // Separate the key name from the modifier prefix. A plain split on '+'
    // cannot represent the plus key, so a trailing '+' is examined first.
    let (modifier_part, key_part) = match spec.strip_suffix('+') {
        Some(rest) => {
            let rest = rest.trim_end();
            if rest.is_empty() {
                ("", "+")
            } else if let Some(mods) = rest.strip_suffix('+') {
                (mods, "+")
            } else {
                // Dangling separator such as "ctrl+": modifiers but no key.
                (rest, "")
            }
        }
        None => match spec.rsplit_once('+') {
            Some((mods, key)) => (mods, key.trim()),
            None => ("", spec.as_str()),
        },
    };

    let mut label = String::new();
    let mut has_shift = false;
    for modifier in modifier_part.split('+') {
        match modifier.trim() {
            "ctrl" | "control" => label.push('^'),
            "shift" => {
                has_shift = true;
                label.push_str("Shift+");
            }
            "alt" => label.push_str("Alt+"),
            _ => {}
        }
    }

    match key_part {
        "esc" | "escape" => label.push_str("Esc"),
        "enter" | "return" => label.push_str("Enter"),
        "tab" => label.push_str("Tab"),
        "backtab" => {
            if !has_shift {
                label.push_str("Shift+");
            }
            label.push_str("Tab");
        }
        "backspace" => label.push_str("Backspace"),
        "up" => label.push_str("Up"),
        "down" => label.push_str("Down"),
        "left" => label.push_str("Left"),
        "right" => label.push_str("Right"),
        "home" => label.push_str("Home"),
        "end" => label.push_str("End"),
        "pageup" => label.push_str("PgUp"),
        "pagedown" => label.push_str("PgDn"),
        "delete" => label.push_str("Del"),
        "insert" => label.push_str("Ins"),
        "space" => label.push_str("Space"),
        "f1" | "f2" | "f3" | "f4" | "f5" | "f6" | "f7" | "f8" | "f9" | "f10" | "f11" | "f12" => {
            label.push_str(&key_part.to_uppercase());
        }
        other => {
            // Count characters rather than bytes so that non-ASCII letters
            // are capitalised like ASCII ones.
            let mut chars = other.chars();
            match (chars.next(), chars.next()) {
                (Some(ch), None) if ch.is_alphabetic() => label.extend(ch.to_uppercase()),
                _ => label.push_str(other),
            }
        }
    }

    label
}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crossterm::event::{KeyCode, KeyModifiers};

    /// Two-variant context used to exercise context-specific bindings.
    #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
    enum TestContext {
        Default,
        Search,
    }

    impl BindingContext for TestContext {
        fn name(&self) -> &'static str {
            match self {
                TestContext::Default => "default",
                TestContext::Search => "search",
            }
        }

        fn from_name(name: &str) -> Option<Self> {
            Self::all()
                .iter()
                .copied()
                .find(|context| context.name() == name)
        }

        fn all() -> &'static [Self] {
            &[TestContext::Default, TestContext::Search]
        }
    }

    /// Builds an unmodified key-press event for `code`.
    fn press(code: KeyCode) -> KeyEvent {
        KeyEvent {
            code,
            modifiers: KeyModifiers::empty(),
            kind: crossterm::event::KeyEventKind::Press,
            state: crossterm::event::KeyEventState::empty(),
        }
    }

    /// Builds a one-element key list from a string literal.
    fn one_key(key: &str) -> Vec<String> {
        vec![key.to_string()]
    }

    #[test]
    fn test_parse_simple_key() {
        let parsed = parse_key_string("q").unwrap();
        assert_eq!(parsed.code, KeyCode::Char('q'));
        assert_eq!(parsed.modifiers, KeyModifiers::empty());
    }

    #[test]
    fn test_parse_esc() {
        let parsed = parse_key_string("esc").unwrap();
        assert_eq!(parsed.code, KeyCode::Esc);
    }

    #[test]
    fn test_parse_ctrl_key() {
        let parsed = parse_key_string("ctrl+p").unwrap();
        assert_eq!(parsed.code, KeyCode::Char('p'));
        assert!(parsed.modifiers.contains(KeyModifiers::CONTROL));
    }

    #[test]
    fn test_parse_shift_tab() {
        let parsed = parse_key_string("shift+tab").unwrap();
        assert_eq!(parsed.code, KeyCode::BackTab);
        assert!(parsed.modifiers.contains(KeyModifiers::SHIFT));
    }

    #[test]
    fn test_parse_backtab() {
        let parsed = parse_key_string("backtab").unwrap();
        assert_eq!(parsed.code, KeyCode::BackTab);
        assert!(parsed.modifiers.contains(KeyModifiers::SHIFT));
    }

    #[test]
    fn test_parse_arrow_keys() {
        let up = parse_key_string("up").unwrap();
        assert_eq!(up.code, KeyCode::Up);

        let down = parse_key_string("down").unwrap();
        assert_eq!(down.code, KeyCode::Down);
    }

    #[test]
    fn test_get_command() {
        let mut bindings: Keybindings<TestContext> = Keybindings::new();
        bindings.add_global("quit", one_key("q"));
        bindings.add(TestContext::Search, "clear", one_key("esc"));

        // A global binding resolves in every context.
        let key_q = press(KeyCode::Char('q'));
        assert_eq!(
            bindings.get_command(key_q, TestContext::Default),
            Some("quit".to_string())
        );
        assert_eq!(
            bindings.get_command(key_q, TestContext::Search),
            Some("quit".to_string())
        );

        // A context binding resolves only in its own context.
        let key_esc = press(KeyCode::Esc);
        assert_eq!(
            bindings.get_command(key_esc, TestContext::Search),
            Some("clear".to_string())
        );
        assert_eq!(bindings.get_command(key_esc, TestContext::Default), None);
    }

    #[test]
    fn test_merge() {
        let mut defaults: Keybindings<TestContext> = Keybindings::new();
        defaults.add_global("quit", one_key("q"));
        defaults.add_global("help", one_key("?"));

        let mut user: Keybindings<TestContext> = Keybindings::new();
        user.add_global("quit", one_key("x"));

        let merged = Keybindings::merge(defaults, user);

        // The user's key list replaces the default for the same command.
        assert_eq!(
            merged.global_bindings().get("quit"),
            Some(&vec!["x".to_string()])
        );

        // Commands the user did not override keep their defaults.
        assert_eq!(
            merged.global_bindings().get("help"),
            Some(&vec!["?".to_string()])
        );
    }

    #[test]
    fn test_format_key_for_display() {
        let cases = [
            ("q", "Q"),
            ("ctrl+p", "^P"),
            ("esc", "Esc"),
            ("shift+tab", "Shift+Tab"),
        ];

        for (input, expected) in cases {
            assert_eq!(format_key_for_display(input), expected);
        }
    }
}