Skip to main content

tf_types/
approval.rs

1#![allow(clippy::type_complexity)]
2//! ApprovalQueue — Rust mirror of
3//! `tools/tf-types-ts/src/core/approval.ts`.
4//!
5//! A FIFO of pending ApprovalRequests where the daemon side awaits a
6//! resolution. Uses tokio oneshot channels for the awaited responses and a
7//! tokio timer for default-deny timeouts.
8
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use std::time::Duration;

use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
use tokio::task::JoinHandle;

use crate::generated::approval_request::ApprovalRequest;
18
19#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(rename_all = "snake_case")]
21pub enum ApprovalDecision {
22    Approve,
23    Deny,
24}
25
26#[derive(Debug, thiserror::Error)]
27pub enum ApprovalError {
28    #[error("approval queue full ({0} pending)")]
29    QueueFull(usize),
30    #[error("approval result channel closed")]
31    ChannelClosed,
32}
33
34#[derive(Clone, Debug)]
35pub struct ApprovalResult {
36    pub decision: ApprovalDecision,
37    pub note: Option<String>,
38}
39
40struct PendingRecord {
41    request: ApprovalRequest,
42    responder: Option<oneshot::Sender<ApprovalResult>>,
43    timer: Option<JoinHandle<()>>,
44}
45
46pub struct ApprovalQueue {
47    pending: Arc<Mutex<HashMap<String, PendingRecord>>>,
48    max_pending: usize,
49    default_timeout: Duration,
50    on_push: Option<Arc<dyn Fn(&ApprovalRequest) + Send + Sync>>,
51    on_resolve: Option<Arc<dyn Fn(&ApprovalRequest, &ApprovalResult) + Send + Sync>>,
52}
53
54impl ApprovalQueue {
55    pub fn new(max_pending: usize, default_timeout: Duration) -> Self {
56        ApprovalQueue {
57            pending: Arc::new(Mutex::new(HashMap::new())),
58            max_pending,
59            default_timeout,
60            on_push: None,
61            on_resolve: None,
62        }
63    }
64
65    pub fn on_push<F>(mut self, f: F) -> Self
66    where
67        F: Fn(&ApprovalRequest) + Send + Sync + 'static,
68    {
69        self.on_push = Some(Arc::new(f));
70        self
71    }
72
73    pub fn on_resolve<F>(mut self, f: F) -> Self
74    where
75        F: Fn(&ApprovalRequest, &ApprovalResult) + Send + Sync + 'static,
76    {
77        self.on_resolve = Some(Arc::new(f));
78        self
79    }
80
81    pub fn size(&self) -> usize {
82        self.pending.lock().unwrap().len()
83    }
84
85    pub fn list(&self) -> Vec<ApprovalRequest> {
86        self.pending
87            .lock()
88            .unwrap()
89            .values()
90            .map(|r| r.request.clone())
91            .collect()
92    }
93
94    /// Enqueue a request and await a decision. Resolves with Deny if the
95    /// default timeout elapses first.
96    pub async fn push(&self, request: ApprovalRequest) -> Result<ApprovalResult, ApprovalError> {
97        let (tx, rx) = oneshot::channel::<ApprovalResult>();
98        {
99            let mut map = self.pending.lock().unwrap();
100            if map.len() >= self.max_pending {
101                return Err(ApprovalError::QueueFull(map.len()));
102            }
103            let id = request.id.clone();
104            let pending = self.pending.clone();
105            let timeout_id = id.clone();
106            let default_timeout = self.default_timeout;
107            let on_resolve = self.on_resolve.clone();
108            let request_for_timer = request.clone();
109            let timer = tokio::spawn(async move {
110                tokio::time::sleep(default_timeout).await;
111                let sender = {
112                    let mut map = pending.lock().unwrap();
113                    map.remove(&timeout_id).and_then(|r| r.responder)
114                };
115                if let Some(tx) = sender {
116                    let result = ApprovalResult {
117                        decision: ApprovalDecision::Deny,
118                        note: Some("timeout".to_string()),
119                    };
120                    if let Some(cb) = &on_resolve {
121                        cb(&request_for_timer, &result);
122                    }
123                    let _ = tx.send(result);
124                }
125            });
126            map.insert(
127                id,
128                PendingRecord {
129                    request: request.clone(),
130                    responder: Some(tx),
131                    timer: Some(timer),
132                },
133            );
134        }
135        if let Some(cb) = &self.on_push {
136            cb(&request);
137        }
138        rx.await.map_err(|_| ApprovalError::ChannelClosed)
139    }
140
141    /// Resolve a pending request. Returns true if a matching request was
142    /// found and resolved.
143    pub fn respond(
144        &self,
145        request_id: &str,
146        decision: ApprovalDecision,
147        note: Option<String>,
148    ) -> bool {
149        let (responder, request, timer) = {
150            let mut map = self.pending.lock().unwrap();
151            match map.remove(request_id) {
152                Some(r) => (r.responder, r.request, r.timer),
153                None => return false,
154            }
155        };
156        if let Some(handle) = timer {
157            handle.abort();
158        }
159        let result = ApprovalResult { decision, note };
160        if let Some(cb) = &self.on_resolve {
161            cb(&request, &result);
162        }
163        if let Some(tx) = responder {
164            let _ = tx.send(result);
165        }
166        true
167    }
168
169    /// Cancel every outstanding approval with Deny + `reason`.
170    pub fn drain_deny(&self, reason: &str) {
171        let items: Vec<_> = {
172            let mut map = self.pending.lock().unwrap();
173            map.drain().map(|(_, r)| r).collect()
174        };
175        for record in items {
176            if let Some(handle) = record.timer {
177                handle.abort();
178            }
179            let result = ApprovalResult {
180                decision: ApprovalDecision::Deny,
181                note: Some(reason.to_string()),
182            };
183            if let Some(cb) = &self.on_resolve {
184                cb(&record.request, &result);
185            }
186            if let Some(tx) = record.responder {
187                let _ = tx.send(result);
188            }
189        }
190    }
191}