// valkey_module/context/blocked.rs
use crate::redismodule::AUTH_HANDLED;
use crate::{raw, Context, ValkeyError, ValkeyString};
use std::os::raw::{c_int, c_void};
4
/// Reply callback stored on a [`BlockedClient`] and handed to the server,
/// together with any private data, when the handle is dropped.
///
/// Only authentication is supported at the moment.
#[derive(Debug)]
pub enum ReplyCallback<T> {
    /// Authentication reply callback, invoked from `auth_reply_wrapper` with the
    /// context, the username, the password and the optional private data set via
    /// [`BlockedClient::set_blocked_private_data`]. Returning `Ok(code)` passes
    /// `code` straight back to the server; returning `Err` makes the wrapper
    /// report the error message and return `AUTH_HANDLED`.
    Auth(fn(&Context, ValkeyString, ValkeyString, Option<&T>) -> Result<c_int, ValkeyError>),
}
11
/// Heap-allocated bundle passed to the server as the blocked client's private
/// data.
///
/// It is leaked with `Box::into_raw` in `BlockedClient::drop`, read (borrowed)
/// by `auth_reply_wrapper`, and reclaimed with `Box::from_raw` in
/// `free_callback_wrapper`. All three sites must agree on the same `T`.
#[derive(Debug)]
struct BlockedClientPrivateData<T: 'static> {
    /// Callback run when the server asks for the blocked client's reply.
    reply_callback: Option<ReplyCallback<T>>,
    /// User cleanup callback; receives ownership of `data` when both are set.
    free_callback: Option<FreePrivateDataCallback<T>>,
    /// Optional user data made available to the callbacks.
    data: Option<Box<T>>,
}
18
/// User cleanup callback: receives the context and takes ownership of the
/// private data (by value) so it can release whatever the data holds.
type FreePrivateDataCallback<T> = fn(&Context, T);
21
/// Handle to a client blocked by the module.
///
/// Dropping the handle unblocks the client via `RedisModule_UnblockClient`.
/// If a reply or free callback is configured, the callbacks and any private
/// data are moved into a `BlockedClientPrivateData` that is passed along at that
/// point. [`BlockedClient::abort`] instead discards the callbacks and data and
/// calls `RedisModule_AbortBlock`.
pub struct BlockedClient<T: 'static = ()> {
    /// Raw server handle; reset to null after a successful `abort` so `Drop`
    /// does not unblock the client a second time.
    pub(crate) inner: *mut raw::RedisModuleBlockedClient,
    reply_callback: Option<ReplyCallback<T>>,
    free_callback: Option<FreePrivateDataCallback<T>>,
    data: Option<Box<T>>,
}
28
29#[allow(dead_code)]
30unsafe extern "C" fn auth_reply_wrapper<T: 'static>(
31 ctx: *mut raw::RedisModuleCtx,
32 username: *mut raw::RedisModuleString,
33 password: *mut raw::RedisModuleString,
34 err: *mut *mut raw::RedisModuleString,
35) -> c_int {
36 let context = Context::new(ctx);
37 let ctx_ptr = std::ptr::NonNull::new_unchecked(ctx);
38 let username = ValkeyString::new(Some(ctx_ptr), username);
39 let password = ValkeyString::new(Some(ctx_ptr), password);
40
41 let module_private_data = context.get_blocked_client_private_data();
42 if module_private_data.is_null() {
43 panic!("[auth_reply_wrapper] Module private data is null; this should not happen!");
44 }
45
46 let user_private_data = &*(module_private_data as *const BlockedClientPrivateData<T>);
47
48 let cb = match user_private_data.reply_callback.as_ref() {
49 Some(ReplyCallback::Auth(cb)) => cb,
50 None => panic!("[auth_reply_wrapper] Reply callback is null; this should not happen!"),
51 };
52
53 let data_ref = user_private_data.data.as_deref();
54
55 match cb(&context, username, password, data_ref) {
56 Ok(result) => result,
57 Err(error) => {
58 let error_msg = ValkeyString::create_and_retain(&error.to_string());
59 *err = error_msg.inner;
60 AUTH_HANDLED
61 }
62 }
63}
64
65#[allow(dead_code)]
66unsafe extern "C" fn free_callback_wrapper<T: 'static>(
67 ctx: *mut raw::RedisModuleCtx,
68 module_private_data: *mut c_void,
69) {
70 let context = Context::new(ctx);
71
72 if module_private_data.is_null() {
73 panic!("[free_callback_wrapper] Module private data is null; this should not happen!");
74 }
75
76 let user_private_data = Box::from_raw(module_private_data as *mut BlockedClientPrivateData<T>);
77
78 if let Some(free_cb) = user_private_data.free_callback {
81 if let Some(data) = user_private_data.data {
82 free_cb(&context, *data);
83 }
84 }
85}
86
87unsafe impl<T> Send for BlockedClient<T> {}
89
90impl<T> BlockedClient<T> {
91 pub(crate) fn new(inner: *mut raw::RedisModuleBlockedClient) -> Self {
92 Self {
93 inner,
94 reply_callback: None,
95 free_callback: None,
96 data: None,
97 }
98 }
99
100 #[allow(dead_code)]
101 pub(crate) fn with_auth_callback(
102 inner: *mut raw::RedisModuleBlockedClient,
103 auth_reply_callback: fn(
104 &Context,
105 ValkeyString,
106 ValkeyString,
107 Option<&T>,
108 ) -> Result<c_int, ValkeyError>,
109 free_callback: Option<FreePrivateDataCallback<T>>,
110 ) -> Self
111 where
112 T: 'static,
113 {
114 Self {
115 inner,
116 reply_callback: Some(ReplyCallback::Auth(auth_reply_callback)),
117 free_callback,
118 data: None,
119 }
120 }
121
122 pub fn set_blocked_private_data(&mut self, data: T) -> Result<(), ValkeyError> {
131 if self.free_callback.is_none() {
132 return Err(ValkeyError::Str(
133 "Cannot set private data without a free callback - this would leak memory",
134 ));
135 }
136 self.data = Some(Box::new(data));
137 Ok(())
138 }
139
140 pub fn abort(mut self) -> Result<(), ValkeyError> {
146 unsafe {
147 self.data = None;
149 self.reply_callback = None;
150 self.free_callback = None;
151
152 if raw::RedisModule_AbortBlock.unwrap()(self.inner) == raw::REDISMODULE_OK as c_int {
153 self.inner = std::ptr::null_mut();
155 Ok(())
156 } else {
157 Err(ValkeyError::Str("Failed to abort blocked client"))
158 }
159 }
160 }
161}
162
163impl<T: 'static> Drop for BlockedClient<T> {
164 fn drop(&mut self) {
165 if !self.inner.is_null() {
166 let callback_data_ptr = if self.reply_callback.is_some() || self.free_callback.is_some()
167 {
168 Box::into_raw(Box::new(BlockedClientPrivateData {
169 reply_callback: self.reply_callback.take(),
170 free_callback: self.free_callback.take(),
171 data: self.data.take(),
172 })) as *mut c_void
173 } else {
174 std::ptr::null_mut()
175 };
176
177 unsafe {
178 raw::RedisModule_UnblockClient.unwrap()(self.inner, callback_data_ptr);
179 }
180 }
181 }
182}
183
184impl Context {
185 #[must_use]
186 pub fn block_client(&self) -> BlockedClient {
187 let blocked_client = unsafe {
188 raw::RedisModule_BlockClient.unwrap()(
189 self.ctx, None, None, None, 0,
193 )
194 };
195
196 BlockedClient::new(blocked_client)
197 }
198
199 #[must_use]
211 #[cfg(all(any(
212 feature = "min-redis-compatibility-version-7-2",
213 feature = "min-valkey-compatibility-version-8-0"
214 ),))]
215 pub fn block_client_on_auth<T: 'static + Send>(
216 &self,
217 auth_reply_callback: fn(
218 &Context,
219 ValkeyString,
220 ValkeyString,
221 Option<&T>,
222 ) -> Result<c_int, ValkeyError>,
223 free_callback: Option<FreePrivateDataCallback<T>>,
224 ) -> BlockedClient<T> {
225 unsafe {
226 let blocked_client = raw::RedisModule_BlockClientOnAuth.unwrap()(
227 self.ctx,
228 Some(auth_reply_wrapper::<T>),
229 Some(free_callback_wrapper::<T>),
230 );
231
232 BlockedClient::with_auth_callback(blocked_client, auth_reply_callback, free_callback)
233 }
234 }
235
236 pub(crate) fn get_blocked_client_private_data(&self) -> *mut c_void {
246 unsafe { raw::RedisModule_GetBlockedClientPrivateData.unwrap()(self.ctx) }
247 }
248}