1#![allow(clippy::type_complexity)]
2use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11use std::time::Duration;
12
13use serde::{Deserialize, Serialize};
14use tokio::sync::oneshot;
15use tokio::task::JoinHandle;
16
17use crate::generated::approval_request::ApprovalRequest;
18
19#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(rename_all = "snake_case")]
21pub enum ApprovalDecision {
22 Approve,
23 Deny,
24}
25
26#[derive(Debug, thiserror::Error)]
27pub enum ApprovalError {
28 #[error("approval queue full ({0} pending)")]
29 QueueFull(usize),
30 #[error("approval result channel closed")]
31 ChannelClosed,
32}
33
34#[derive(Clone, Debug)]
35pub struct ApprovalResult {
36 pub decision: ApprovalDecision,
37 pub note: Option<String>,
38}
39
40struct PendingRecord {
41 request: ApprovalRequest,
42 responder: Option<oneshot::Sender<ApprovalResult>>,
43 timer: Option<JoinHandle<()>>,
44}
45
46pub struct ApprovalQueue {
47 pending: Arc<Mutex<HashMap<String, PendingRecord>>>,
48 max_pending: usize,
49 default_timeout: Duration,
50 on_push: Option<Arc<dyn Fn(&ApprovalRequest) + Send + Sync>>,
51 on_resolve: Option<Arc<dyn Fn(&ApprovalRequest, &ApprovalResult) + Send + Sync>>,
52}
53
54impl ApprovalQueue {
55 pub fn new(max_pending: usize, default_timeout: Duration) -> Self {
56 ApprovalQueue {
57 pending: Arc::new(Mutex::new(HashMap::new())),
58 max_pending,
59 default_timeout,
60 on_push: None,
61 on_resolve: None,
62 }
63 }
64
65 pub fn on_push<F>(mut self, f: F) -> Self
66 where
67 F: Fn(&ApprovalRequest) + Send + Sync + 'static,
68 {
69 self.on_push = Some(Arc::new(f));
70 self
71 }
72
73 pub fn on_resolve<F>(mut self, f: F) -> Self
74 where
75 F: Fn(&ApprovalRequest, &ApprovalResult) + Send + Sync + 'static,
76 {
77 self.on_resolve = Some(Arc::new(f));
78 self
79 }
80
81 pub fn size(&self) -> usize {
82 self.pending.lock().unwrap().len()
83 }
84
85 pub fn list(&self) -> Vec<ApprovalRequest> {
86 self.pending
87 .lock()
88 .unwrap()
89 .values()
90 .map(|r| r.request.clone())
91 .collect()
92 }
93
94 pub async fn push(&self, request: ApprovalRequest) -> Result<ApprovalResult, ApprovalError> {
97 let (tx, rx) = oneshot::channel::<ApprovalResult>();
98 {
99 let mut map = self.pending.lock().unwrap();
100 if map.len() >= self.max_pending {
101 return Err(ApprovalError::QueueFull(map.len()));
102 }
103 let id = request.id.clone();
104 let pending = self.pending.clone();
105 let timeout_id = id.clone();
106 let default_timeout = self.default_timeout;
107 let on_resolve = self.on_resolve.clone();
108 let request_for_timer = request.clone();
109 let timer = tokio::spawn(async move {
110 tokio::time::sleep(default_timeout).await;
111 let sender = {
112 let mut map = pending.lock().unwrap();
113 map.remove(&timeout_id).and_then(|r| r.responder)
114 };
115 if let Some(tx) = sender {
116 let result = ApprovalResult {
117 decision: ApprovalDecision::Deny,
118 note: Some("timeout".to_string()),
119 };
120 if let Some(cb) = &on_resolve {
121 cb(&request_for_timer, &result);
122 }
123 let _ = tx.send(result);
124 }
125 });
126 map.insert(
127 id,
128 PendingRecord {
129 request: request.clone(),
130 responder: Some(tx),
131 timer: Some(timer),
132 },
133 );
134 }
135 if let Some(cb) = &self.on_push {
136 cb(&request);
137 }
138 rx.await.map_err(|_| ApprovalError::ChannelClosed)
139 }
140
141 pub fn respond(
144 &self,
145 request_id: &str,
146 decision: ApprovalDecision,
147 note: Option<String>,
148 ) -> bool {
149 let (responder, request, timer) = {
150 let mut map = self.pending.lock().unwrap();
151 match map.remove(request_id) {
152 Some(r) => (r.responder, r.request, r.timer),
153 None => return false,
154 }
155 };
156 if let Some(handle) = timer {
157 handle.abort();
158 }
159 let result = ApprovalResult { decision, note };
160 if let Some(cb) = &self.on_resolve {
161 cb(&request, &result);
162 }
163 if let Some(tx) = responder {
164 let _ = tx.send(result);
165 }
166 true
167 }
168
169 pub fn drain_deny(&self, reason: &str) {
171 let items: Vec<_> = {
172 let mut map = self.pending.lock().unwrap();
173 map.drain().map(|(_, r)| r).collect()
174 };
175 for record in items {
176 if let Some(handle) = record.timer {
177 handle.abort();
178 }
179 let result = ApprovalResult {
180 decision: ApprovalDecision::Deny,
181 note: Some(reason.to_string()),
182 };
183 if let Some(cb) = &self.on_resolve {
184 cb(&record.request, &result);
185 }
186 if let Some(tx) = record.responder {
187 let _ = tx.send(result);
188 }
189 }
190 }
191}