1use std::fmt;
4use std::sync::Arc;
5use std::sync::atomic::Ordering;
6use std::time::Duration;
7
8use surrealism_types::args::Args;
9use surrealism_types::err::{SurrealismError, SurrealismResult};
10use tokio::sync::OwnedSemaphorePermit;
11use wasmtime::*;
12use web_time::Instant;
13
14use crate::epoch::EPOCH_TICK_MS;
15use crate::host::{InvocationContext, NullContext};
16use crate::store::StoreData;
17
/// Picks the tighter of two optional time limits.
///
/// Returns the smaller duration when both limits are present, the only
/// present one when just one is given, and `None` when neither is set.
fn effective_timeout(
    context_remaining: Option<Duration>,
    module_limit: Option<Duration>,
) -> Option<Duration> {
    match (context_remaining, module_limit) {
        // Both limits apply: the stricter one wins.
        (Some(from_context), Some(from_module)) => Some(from_context.min(from_module)),
        // At most one side is set: pass it through (this also covers `None`/`None`).
        (from_context, None) => from_context,
        (None, from_module) => from_module,
    }
}
24
/// Drives a single instantiated guest component: owns its wasmtime store and
/// the handles to the guest exports that the methods on this type call.
pub struct Controller {
    /// Wasmtime store for this instance; its `StoreData` carries the current
    /// invocation context and the stdout/stderr callbacks.
    store: Store<StoreData>,
    /// Guest export called by `invoke` / `invoke_with_timeout`.
    invoke_fn: component::Func,
    /// Optional guest export backing `args` ("function-args").
    args_fn: Option<component::Func>,
    /// Optional guest export backing `returns` ("function-returns").
    returns_fn: Option<component::Func>,
    /// Optional guest export backing `list` ("list-functions").
    list_fn: Option<component::Func>,
    /// Optional guest export backing `writeable` ("function-writeable").
    writeable_fn: Option<component::Func>,
    /// Optional guest export backing `comment` ("function-comment").
    comment_fn: Option<component::Func>,
    /// Optional guest export run by `init`; `init` is a no-op when absent.
    init_fn: Option<component::Func>,
    /// Upper bound on the duration of a guest call; `None` means no module limit.
    module_execution_time: Option<Duration>,
    /// Shared epoch counter, read when computing the "no deadline" value.
    // NOTE(review): presumably mirrors the engine's epoch and is advanced by a
    // ticker elsewhere in the crate — confirm against `crate::epoch`.
    epoch_counter: Arc<std::sync::atomic::AtomicU64>,
    /// Semaphore permit held by this controller; `None` after
    /// `take_controller_slot` until `attach_controller_slot` is called.
    controller_slot: Option<OwnedSemaphorePermit>,
}
52
53impl fmt::Debug for Controller {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 f.debug_struct("Controller").finish_non_exhaustive()
56 }
57}
58
59impl Controller {
60 #[allow(clippy::too_many_arguments)]
61 pub(crate) fn new(
62 store: Store<StoreData>,
63 invoke_fn: component::Func,
64 args_fn: Option<component::Func>,
65 returns_fn: Option<component::Func>,
66 list_fn: Option<component::Func>,
67 writeable_fn: Option<component::Func>,
68 comment_fn: Option<component::Func>,
69 init_fn: Option<component::Func>,
70 module_execution_time: Option<Duration>,
71 epoch_counter: Arc<std::sync::atomic::AtomicU64>,
72 controller_slot: OwnedSemaphorePermit,
73 ) -> Self {
74 Self {
75 store,
76 invoke_fn,
77 args_fn,
78 returns_fn,
79 list_fn,
80 writeable_fn,
81 comment_fn,
82 init_fn,
83 module_execution_time,
84 epoch_counter,
85 controller_slot: Some(controller_slot),
86 }
87 }
88
89 pub(crate) fn attach_controller_slot(&mut self, permit: OwnedSemaphorePermit) {
91 debug_assert!(self.controller_slot.is_none(), "controller already holds a slot permit");
92 self.controller_slot = Some(permit);
93 }
94
95 pub(crate) fn take_controller_slot(&mut self) -> Option<OwnedSemaphorePermit> {
98 self.controller_slot.take()
99 }
100
101 pub fn set_context(&mut self, context: Box<dyn InvocationContext>) {
106 let data = self.store.data_mut();
107 *data.stdout_cb.lock() = context.stdout_callback();
108 *data.stderr_cb.lock() = context.stderr_callback();
109 data.context = context;
110 }
111
112 pub fn clear_context(&mut self) {
116 let data = self.store.data_mut();
117 *data.stdout_cb.lock() = Arc::new(|output| print!("{}", output));
118 *data.stderr_cb.lock() = Arc::new(|output| eprint!("{}", output));
119 data.context = Box::new(NullContext);
120 }
121
122 pub fn reset_epoch_deadline(&mut self) {
129 let epoch = self.epoch_counter.load(Ordering::Acquire);
130 self.store.set_epoch_deadline(u64::MAX.saturating_sub(epoch).saturating_sub(1));
131 }
132
133 fn apply_module_deadline(&mut self) {
136 match self.module_execution_time {
137 Some(timeout) => {
138 let ticks = (timeout.as_millis() as u64) / EPOCH_TICK_MS;
139 self.store.set_epoch_deadline(ticks.max(1));
140 }
141 None => self.reset_epoch_deadline(),
142 }
143 }
144
145 #[tracing::instrument(skip_all)]
146 pub async fn init(&mut self) -> SurrealismResult<()> {
147 let t0 = Instant::now();
148 let Some(func) = self.init_fn else {
149 tracing::debug!("controller.init(): no init_fn, skipping");
150 return Ok(());
151 };
152 self.apply_module_deadline();
153 tracing::info!(
154 module_execution_time = ?self.module_execution_time,
155 "controller.init(): calling init function..."
156 );
157 let typed = func.typed::<(), (Result<(), String>,)>(&self.store)?;
158 match typed.call_async(&mut self.store, ()).await {
159 Ok((result,)) => {
160 tracing::info!(elapsed = ?t0.elapsed(), ok = result.is_ok(), "controller.init(): completed");
161 result.map_err(SurrealismError::FunctionCallError)
162 }
163 Err(e) => {
164 if e.downcast_ref::<Trap>() == Some(&Trap::Interrupt) {
165 tracing::error!(elapsed = ?t0.elapsed(), "controller.init(): timed out");
166 return Err(SurrealismError::Timeout {
167 effective: self.module_execution_time,
168 context_timeout: None,
169 module_limit: self.module_execution_time,
170 });
171 }
172 tracing::error!(elapsed = ?t0.elapsed(), error = %e, "controller.init(): WASM TRAP");
173 Err(e.into())
174 }
175 }
176 }
177
178 #[tracing::instrument(skip_all, fields(name))]
179 pub async fn invoke<A: Args>(
180 &mut self,
181 name: Option<String>,
182 args: A,
183 ) -> SurrealismResult<surrealdb_types::Value> {
184 self.invoke_with_timeout(name, args, None).await
185 }
186
187 #[tracing::instrument(skip_all, fields(name))]
190 pub async fn invoke_with_timeout<A: Args>(
191 &mut self,
192 name: Option<String>,
193 args: A,
194 context_timeout: Option<Duration>,
195 ) -> SurrealismResult<surrealdb_types::Value> {
196 let display_name = name.as_deref().unwrap_or("<default>");
197 let effective = effective_timeout(context_timeout, self.module_execution_time);
198
199 match effective {
200 Some(timeout) => {
201 let ticks = (timeout.as_millis() as u64) / EPOCH_TICK_MS;
202 self.store.set_epoch_deadline(ticks.max(1));
203 }
204 None => {
205 self.reset_epoch_deadline();
206 }
207 }
208
209 let args_values = args.to_values();
210 let args_bytes = surrealdb_types::encode_value_list(&args_values)?;
211
212 let typed = self
213 .invoke_fn
214 .typed::<(Option<&str>, &[u8]), (Result<Vec<u8>, String>,)>(&self.store)?;
215
216 let call_result = typed.call_async(&mut self.store, (name.as_deref(), &args_bytes)).await;
217
218 if let Err(e) = &call_result {
219 tracing::error!(name = %display_name, error = %e, "invoke_with_timeout: call_async FAILED");
220 }
221
222 let (result,) = call_result.map_err(|e| {
223 if e.downcast_ref::<Trap>() == Some(&Trap::Interrupt) {
224 SurrealismError::Timeout {
225 effective,
226 context_timeout,
227 module_limit: self.module_execution_time,
228 }
229 } else {
230 SurrealismError::from(e)
231 }
232 })?;
233
234 if let Err(guest_err) = &result {
235 tracing::warn!(name = %display_name, guest_error = %guest_err, "invoke_with_timeout: guest returned Err");
236 }
237
238 let result_bytes = result.map_err(SurrealismError::FunctionCallError)?;
239 let value = surrealdb_types::decode::<surrealdb_types::Value>(&result_bytes)?;
240 Ok(value)
241 }
242
243 fn trap_to_timeout(&self, e: wasmtime::Error) -> SurrealismError {
247 if e.downcast_ref::<Trap>() == Some(&Trap::Interrupt) {
248 SurrealismError::Timeout {
249 effective: self.module_execution_time,
250 context_timeout: None,
251 module_limit: self.module_execution_time,
252 }
253 } else {
254 SurrealismError::from(e)
255 }
256 }
257
258 #[tracing::instrument(skip_all, fields(name))]
261 pub async fn args(
262 &mut self,
263 name: Option<String>,
264 ) -> SurrealismResult<Vec<(String, surrealdb_types::Kind)>> {
265 let display_name = name.as_deref().unwrap_or("<default>");
266 tracing::debug!(name = %display_name, "controller.args(): calling function-args");
267 let func = self.args_fn.ok_or_else(|| {
268 SurrealismError::Other(anyhow::anyhow!("function-args export not available"))
269 })?;
270 self.apply_module_deadline();
271 let typed = func.typed::<(Option<&str>,), (Result<Vec<u8>, String>,)>(&self.store)?;
272
273 match typed.call_async(&mut self.store, (name.as_deref(),)).await {
274 Ok((result,)) => {
275 tracing::debug!(name = %display_name, ok = result.is_ok(), "controller.args(): call_async completed");
276 let result_bytes = result.map_err(SurrealismError::FunctionCallError)?;
277 Ok(surrealdb_types::decode_argument_list(&result_bytes)?)
278 }
279 Err(e) => {
280 tracing::error!(name = %display_name, error = %e, error_debug = ?e, "controller.args(): WASM TRAP");
281 Err(self.trap_to_timeout(e))
282 }
283 }
284 }
285
286 #[tracing::instrument(skip_all, fields(name))]
289 pub async fn returns(
290 &mut self,
291 name: Option<String>,
292 ) -> SurrealismResult<surrealdb_types::Kind> {
293 let display_name = name.as_deref().unwrap_or("<default>");
294 tracing::debug!(name = %display_name, "controller.returns(): calling function-returns");
295 let func = self.returns_fn.ok_or_else(|| {
296 SurrealismError::Other(anyhow::anyhow!("function-returns export not available"))
297 })?;
298 self.apply_module_deadline();
299 let typed = func.typed::<(Option<&str>,), (Result<Vec<u8>, String>,)>(&self.store)?;
300
301 match typed.call_async(&mut self.store, (name.as_deref(),)).await {
302 Ok((result,)) => {
303 tracing::debug!(name = %display_name, ok = result.is_ok(), "controller.returns(): call_async completed");
304 let result_bytes = result.map_err(SurrealismError::FunctionCallError)?;
305 Ok(surrealdb_types::decode_kind(&result_bytes)?)
306 }
307 Err(e) => {
308 tracing::error!(name = %display_name, error = %e, error_debug = ?e, "controller.returns(): WASM TRAP");
309 Err(self.trap_to_timeout(e))
310 }
311 }
312 }
313
314 #[tracing::instrument(skip_all, fields(name))]
317 pub async fn writeable(&mut self, name: Option<String>) -> SurrealismResult<bool> {
318 let display_name = name.as_deref().unwrap_or("<default>");
319 tracing::debug!(name = %display_name, "controller.writeable(): calling function-writeable");
320 let func = self.writeable_fn.ok_or_else(|| {
321 SurrealismError::Other(anyhow::anyhow!("function-writeable export not available"))
322 })?;
323 self.apply_module_deadline();
324 let typed = func.typed::<(Option<&str>,), (Result<bool, String>,)>(&self.store)?;
325
326 match typed.call_async(&mut self.store, (name.as_deref(),)).await {
327 Ok((result,)) => {
328 tracing::debug!(name = %display_name, ok = result.is_ok(), "controller.writeable(): call_async completed");
329 result.map_err(SurrealismError::FunctionCallError)
330 }
331 Err(e) => {
332 tracing::error!(name = %display_name, error = %e, error_debug = ?e, "controller.writeable(): WASM TRAP");
333 Err(self.trap_to_timeout(e))
334 }
335 }
336 }
337
338 #[tracing::instrument(skip_all, fields(name))]
341 pub async fn comment(&mut self, name: Option<String>) -> SurrealismResult<Option<String>> {
342 let display_name = name.as_deref().unwrap_or("<default>");
343 tracing::debug!(name = %display_name, "controller.comment(): calling function-comment");
344 let func = self.comment_fn.ok_or_else(|| {
345 SurrealismError::Other(anyhow::anyhow!("function-comment export not available"))
346 })?;
347 self.apply_module_deadline();
348 let typed =
349 func.typed::<(Option<&str>,), (Result<Option<String>, String>,)>(&self.store)?;
350
351 match typed.call_async(&mut self.store, (name.as_deref(),)).await {
352 Ok((result,)) => {
353 tracing::debug!(name = %display_name, ok = result.is_ok(), "controller.comment(): call_async completed");
354 result.map_err(SurrealismError::FunctionCallError)
355 }
356 Err(e) => {
357 tracing::error!(name = %display_name, error = %e, error_debug = ?e, "controller.comment(): WASM TRAP");
358 Err(self.trap_to_timeout(e))
359 }
360 }
361 }
362
363 #[tracing::instrument(skip_all)]
366 pub async fn list(&mut self) -> SurrealismResult<Vec<Option<String>>> {
367 tracing::debug!("controller.list(): calling list-functions");
368 let func = self.list_fn.ok_or_else(|| {
369 SurrealismError::Other(anyhow::anyhow!("list-functions export not available"))
370 })?;
371 self.apply_module_deadline();
372 let typed = func.typed::<(), (Vec<Option<String>>,)>(&self.store)?;
373
374 match typed.call_async(&mut self.store, ()).await {
375 Ok((names,)) => {
376 tracing::debug!(count = names.len(), names = ?names, "controller.list(): completed");
377 Ok(names)
378 }
379 Err(e) => {
380 tracing::error!(error = %e, error_debug = ?e, "controller.list(): WASM TRAP");
381 Err(self.trap_to_timeout(e))
382 }
383 }
384 }
385}