1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::path::{Path, PathBuf};
4use std::process::Stdio;
5use std::sync::{Arc, Mutex};
6use std::time::UNIX_EPOCH;
7
8use futures_util::{Sink, SinkExt, StreamExt};
9use rmpv::Value;
10use rpc_runtime_core::{InstanceId, MethodId, ServiceGuid};
11use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
12use rpc_runtime_server::{
13 ConnectionCleanupFuture, HandlerFuture, RpcCallContext, RpcConnectionCleanupSink, RpcServer,
14 RpcServerBuilder, RpcServerSecurityConfig,
15};
16use serde::{Deserialize, Serialize};
17use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, SeekFrom};
18use tokio::net::tcp::OwnedWriteHalf;
19use tokio::net::{TcpListener, TcpStream};
20use tokio::sync::Mutex as AsyncMutex;
21use tokio::task::JoinHandle;
22use tokio_tungstenite::tungstenite::Message;
23use uuid::Uuid;
24
25mod archive;
26mod filesystem;
27pub mod generated;
28
// Names under which the built-in services are registered on the RPC server.

/// Instance name of the runtime-information service.
pub const RUNTIME_INSTANCE: &str = "tripley.native.runtime";
/// Instance name of the filesystem service.
pub const FS_INSTANCE: &str = "tripley.native.fs";
/// Instance name of the archive service.
pub const ARCHIVE_INSTANCE: &str = "tripley.native.archive";
/// Instance name of the TCP client/server service.
pub const TCP_INSTANCE: &str = "tripley.native.tcp";
/// Instance name of the WebSocket client/server service.
pub const WEBSOCKET_INSTANCE: &str = "tripley.native.websocket";
/// Instance name of the SQLite service.
pub const SQLITE_INSTANCE: &str = "tripley.native.sqlite";
/// Instance name of the system (power) service.
pub const SYSTEM_INSTANCE: &str = "tripley.native.system";

// Service GUIDs paired with the instance names above at registration time.
// They are parsed with `Uuid::parse_str`, so each must remain a valid UUID string.

/// GUID of the runtime-information service.
pub const RUNTIME_SERVICE_GUID: &str = "52c6943d-d956-4a42-b69e-51e94399c001";
/// GUID of the filesystem service.
pub const FS_SERVICE_GUID: &str = "52c6943d-d956-4a42-b69e-51e94399c002";
/// GUID of the archive service (note: does not follow the `…c00N` numbering of the others).
pub const ARCHIVE_SERVICE_GUID: &str = "1e7d1d50-721c-4b14-ad7e-7793170cea05";
/// GUID of the TCP service.
pub const TCP_SERVICE_GUID: &str = "52c6943d-d956-4a42-b69e-51e94399c003";
/// GUID of the WebSocket service.
pub const WEBSOCKET_SERVICE_GUID: &str = "52c6943d-d956-4a42-b69e-51e94399c004";
/// GUID of the SQLite service.
pub const SQLITE_SERVICE_GUID: &str = "52c6943d-d956-4a42-b69e-51e94399c005";
/// GUID of the system (power) service.
pub const SYSTEM_SERVICE_GUID: &str = "52c6943d-d956-4a42-b69e-51e94399c006";

/// Notification id used for every TCP and WebSocket event pushed to the client.
pub const EVENT_NOTIFICATION_ID: u32 = 1;
46
47pub trait NativeRpcProvider: Send + Sync {
48 fn register(&self, builder: &mut RpcServerBuilder);
49 fn capabilities(&self) -> Vec<&'static str>;
50}
51
/// Selects which built-in native services the RPC server exposes.
///
/// The runtime-information service is always available and therefore has no flag here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NativeServiceSet {
    pub fs: bool,
    pub archive: bool,
    pub tcp: bool,
    pub websocket: bool,
    pub sqlite: bool,
    pub system: bool,
}

impl NativeServiceSet {
    /// Builds a set in which every optional service has the same on/off state.
    const fn uniform(enabled: bool) -> Self {
        Self {
            fs: enabled,
            archive: enabled,
            tcp: enabled,
            websocket: enabled,
            sqlite: enabled,
            system: enabled,
        }
    }

    /// Every optional service enabled.
    pub const fn all() -> Self {
        Self::uniform(true)
    }

    /// Every optional service disabled; only the runtime service remains.
    pub const fn runtime_only() -> Self {
        Self::uniform(false)
    }

    /// Capability strings advertised for the enabled services.
    ///
    /// `"runtime.info"` always comes first, followed by the entries of each enabled service
    /// in declaration order (fs, archive, tcp, websocket, sqlite, system).
    pub fn capabilities(self) -> Vec<&'static str> {
        let groups: [(bool, &[&'static str]); 6] = [
            (self.fs, &["fs"]),
            (self.archive, &["archive"]),
            (self.tcp, &["tcp.client", "tcp.server"]),
            (self.websocket, &["websocket.client", "websocket.server"]),
            (self.sqlite, &["sqlite"]),
            (self.system, &["system.shutdown", "system.reboot"]),
        ];

        let mut advertised = Vec::from(["runtime.info"]);
        for (enabled, names) in groups {
            if enabled {
                advertised.extend_from_slice(names);
            }
        }
        advertised
    }
}

impl Default for NativeServiceSet {
    /// Defaults to [`NativeServiceSet::all`].
    fn default() -> Self {
        Self::all()
    }
}
117
118#[derive(Clone)]
119pub struct NativeRpcServerOptions {
120 pub policy: Arc<dyn NativePolicy>,
121 pub services: NativeServiceSet,
122 pub security: RpcServerSecurityConfig,
123 pub providers: Vec<Arc<dyn NativeRpcProvider>>,
124}
125
126impl Default for NativeRpcServerOptions {
127 fn default() -> Self {
128 Self {
129 policy: Arc::new(DevPermissivePolicy),
130 services: NativeServiceSet::all(),
131 security: RpcServerSecurityConfig::default(),
132 providers: Vec::new(),
133 }
134 }
135}
136
137#[derive(Clone)]
138pub struct NativeState {
139 open_files: Arc<AsyncMutex<HashMap<String, OpenFileResource>>>,
140 tcp_sockets: Arc<AsyncMutex<HashMap<String, TcpSocketResource>>>,
141 tcp_servers: Arc<AsyncMutex<HashMap<String, TaskResource>>>,
142 websocket_sockets: Arc<AsyncMutex<HashMap<String, WebSocketResource>>>,
143 websocket_servers: Arc<AsyncMutex<HashMap<String, TaskResource>>>,
144 sqlite: Arc<Mutex<HashMap<String, SqliteResource>>>,
145 policy: Arc<dyn NativePolicy>,
146 services: NativeServiceSet,
147 provider_capabilities: Arc<Vec<&'static str>>,
148}
149
150impl Default for NativeState {
151 fn default() -> Self {
152 Self {
153 open_files: Arc::new(AsyncMutex::new(HashMap::new())),
154 tcp_sockets: Arc::new(AsyncMutex::new(HashMap::new())),
155 tcp_servers: Arc::new(AsyncMutex::new(HashMap::new())),
156 websocket_sockets: Arc::new(AsyncMutex::new(HashMap::new())),
157 websocket_servers: Arc::new(AsyncMutex::new(HashMap::new())),
158 sqlite: Arc::new(Mutex::new(HashMap::new())),
159 policy: Arc::new(DevPermissivePolicy),
160 services: NativeServiceSet::all(),
161 provider_capabilities: Arc::new(Vec::new()),
162 }
163 }
164}
165
166impl NativeState {
167 async fn dispose_connection_resources(&self, connection_id: u64) {
168 let file_ids = {
169 let files = self.open_files.lock().await;
170 files
171 .iter()
172 .filter_map(|(id, resource)| {
173 (resource.owner_connection_id == connection_id).then(|| id.clone())
174 })
175 .collect::<Vec<_>>()
176 };
177 for id in file_ids {
178 self.open_files.lock().await.remove(&id);
179 }
180
181 let tcp_socket_ids = {
182 let sockets = self.tcp_sockets.lock().await;
183 sockets
184 .iter()
185 .filter_map(|(id, resource)| {
186 (resource.owner_connection_id == connection_id).then(|| id.clone())
187 })
188 .collect::<Vec<_>>()
189 };
190 for id in tcp_socket_ids {
191 self.tcp_sockets.lock().await.remove(&id);
192 }
193
194 let tcp_server_ids = {
195 let servers = self.tcp_servers.lock().await;
196 servers
197 .iter()
198 .filter_map(|(id, resource)| {
199 (resource.owner_connection_id == connection_id).then(|| id.clone())
200 })
201 .collect::<Vec<_>>()
202 };
203 for id in tcp_server_ids {
204 if let Some(resource) = self.tcp_servers.lock().await.remove(&id) {
205 resource.task.abort();
206 }
207 }
208
209 let websocket_ids = {
210 let sockets = self.websocket_sockets.lock().await;
211 sockets
212 .iter()
213 .filter_map(|(id, resource)| {
214 (resource.owner_connection_id == connection_id).then(|| id.clone())
215 })
216 .collect::<Vec<_>>()
217 };
218 for id in websocket_ids {
219 if let Some(resource) = self.websocket_sockets.lock().await.remove(&id) {
220 let _ = resource
221 .writer
222 .lock()
223 .await
224 .send(Message::Close(None))
225 .await;
226 }
227 }
228
229 let websocket_server_ids = {
230 let servers = self.websocket_servers.lock().await;
231 servers
232 .iter()
233 .filter_map(|(id, resource)| {
234 (resource.owner_connection_id == connection_id).then(|| id.clone())
235 })
236 .collect::<Vec<_>>()
237 };
238 for id in websocket_server_ids {
239 if let Some(resource) = self.websocket_servers.lock().await.remove(&id) {
240 resource.task.abort();
241 }
242 }
243
244 let sqlite_ids = {
245 let dbs = self.sqlite.lock().expect("sqlite lock");
246 dbs.iter()
247 .filter_map(|(id, resource)| {
248 (resource.owner_connection_id == connection_id).then(|| id.clone())
249 })
250 .collect::<Vec<_>>()
251 };
252 let mut dbs = self.sqlite.lock().expect("sqlite lock");
253 for id in sqlite_ids {
254 dbs.remove(&id);
255 }
256 }
257}
258
259impl RpcConnectionCleanupSink for NativeState {
260 fn cleanup_connection<'a>(&'a self, connection_id: u64) -> ConnectionCleanupFuture<'a> {
261 Box::pin(async move {
262 self.dispose_connection_resources(connection_id).await;
263 })
264 }
265}
266
267type BoxedWebSocketWriter =
268 std::pin::Pin<Box<dyn Sink<Message, Error = tokio_tungstenite::tungstenite::Error> + Send>>;
269type WebSocketWriter = Arc<AsyncMutex<BoxedWebSocketWriter>>;
270
271#[derive(Clone)]
272struct OpenFileResource {
273 owner_connection_id: u64,
274 file: Arc<AsyncMutex<tokio::fs::File>>,
275}
276
277#[derive(Clone)]
278struct TcpSocketResource {
279 owner_connection_id: u64,
280 writer: Arc<AsyncMutex<OwnedWriteHalf>>,
281}
282
283#[derive(Clone)]
284struct WebSocketResource {
285 owner_connection_id: u64,
286 writer: WebSocketWriter,
287}
288
289struct TaskResource {
290 owner_connection_id: u64,
291 task: JoinHandle<()>,
292}
293
294struct SqliteResource {
295 owner_connection_id: u64,
296 connection: rusqlite::Connection,
297}
298
299#[derive(Debug, Clone, Default, Serialize, Deserialize)]
300#[serde(default, rename_all = "camelCase")]
301pub struct NativePolicyConfig {
302 pub filesystem: FileSystemPolicyConfig,
303 pub network: NetworkPolicyConfig,
304 pub sqlite: SqlitePolicyConfig,
305 pub power: PowerPolicyConfig,
306}
307
308impl NativePolicyConfig {
309 pub fn dev_permissive() -> Self {
310 Self {
311 filesystem: FileSystemPolicyConfig {
312 read: vec![PathBuf::from("/")],
313 write: vec![PathBuf::from("/")],
314 },
315 network: NetworkPolicyConfig {
316 tcp: vec![NetworkRule::any()],
317 websocket: vec![NetworkRule::any()],
318 },
319 sqlite: SqlitePolicyConfig {
320 paths: vec![PathBuf::from("/")],
321 },
322 power: PowerPolicyConfig {
323 shutdown: true,
324 reboot: true,
325 },
326 }
327 }
328
329 pub fn from_json(value: &str) -> Result<Self, RuntimeError> {
330 serde_json::from_str(value).map_err(|error| {
331 RuntimeError::runtime(
332 RuntimeErrorCode::PayloadDecodeFailed,
333 format!("invalid native policy config JSON: {error}"),
334 )
335 })
336 }
337
338 pub fn allow_fs_read(mut self, path: impl Into<PathBuf>) -> Self {
339 self.filesystem.read.push(path.into());
340 self
341 }
342
343 pub fn allow_fs_write(mut self, path: impl Into<PathBuf>) -> Self {
344 self.filesystem.write.push(path.into());
345 self
346 }
347
348 pub fn allow_fs_read_write(mut self, path: impl Into<PathBuf>) -> Self {
349 let path = path.into();
350 self.filesystem.read.push(path.clone());
351 self.filesystem.write.push(path);
352 self
353 }
354
355 pub fn allow_tcp(mut self, host: impl Into<String>, port: Option<u16>) -> Self {
356 self.network.tcp.push(NetworkRule {
357 host: host.into(),
358 port,
359 });
360 self
361 }
362
363 pub fn allow_websocket(mut self, host: impl Into<String>, port: Option<u16>) -> Self {
364 self.network.websocket.push(NetworkRule {
365 host: host.into(),
366 port,
367 });
368 self
369 }
370
371 pub fn allow_sqlite(mut self, path: impl Into<PathBuf>) -> Self {
372 self.sqlite.paths.push(path.into());
373 self
374 }
375
376 pub fn allow_shutdown(mut self) -> Self {
377 self.power.shutdown = true;
378 self
379 }
380
381 pub fn allow_reboot(mut self) -> Self {
382 self.power.reboot = true;
383 self
384 }
385}
386
387#[derive(Debug, Clone, Default, Serialize, Deserialize)]
388#[serde(default, rename_all = "camelCase")]
389pub struct FileSystemPolicyConfig {
390 pub read: Vec<PathBuf>,
391 pub write: Vec<PathBuf>,
392}
393
394#[derive(Debug, Clone, Default, Serialize, Deserialize)]
395#[serde(default, rename_all = "camelCase")]
396pub struct NetworkPolicyConfig {
397 pub tcp: Vec<NetworkRule>,
398 pub websocket: Vec<NetworkRule>,
399}
400
401#[derive(Debug, Clone, Serialize, Deserialize)]
402#[serde(default, rename_all = "camelCase")]
403pub struct NetworkRule {
404 pub host: String,
405 pub port: Option<u16>,
406}
407
408impl Default for NetworkRule {
409 fn default() -> Self {
410 Self {
411 host: String::new(),
412 port: None,
413 }
414 }
415}
416
417impl NetworkRule {
418 pub fn any() -> Self {
419 Self {
420 host: "*".to_string(),
421 port: None,
422 }
423 }
424
425 fn matches(&self, host: &str, port: u16) -> bool {
426 (self.host == "*" || self.host.eq_ignore_ascii_case(host))
427 && self.port.is_none_or(|allowed| allowed == port)
428 }
429}
430
431#[derive(Debug, Clone, Default, Serialize, Deserialize)]
432#[serde(default, rename_all = "camelCase")]
433pub struct SqlitePolicyConfig {
434 pub paths: Vec<PathBuf>,
435}
436
437#[derive(Debug, Clone, Default, Serialize, Deserialize)]
438#[serde(default, rename_all = "camelCase")]
439pub struct PowerPolicyConfig {
440 pub shutdown: bool,
441 pub reboot: bool,
442}
443
444pub trait NativePolicy: Send + Sync {
445 fn mode(&self) -> &'static str;
446 fn allow_fs_path(&self, _operation: &str, _path: &std::path::Path) -> Result<(), RuntimeError> {
447 Ok(())
448 }
449 fn allow_network(&self, _operation: &str, _host: &str, _port: u16) -> Result<(), RuntimeError> {
450 Ok(())
451 }
452 fn allow_sqlite_path(&self, _path: &std::path::Path) -> Result<(), RuntimeError> {
453 Ok(())
454 }
455 fn allow_power(&self, _operation: &str) -> Result<(), RuntimeError> {
456 Ok(())
457 }
458}
459
460#[derive(Debug, Clone, Default)]
461pub struct ConfiguredNativePolicy {
462 config: NativePolicyConfig,
463}
464
465impl ConfiguredNativePolicy {
466 pub fn new(config: NativePolicyConfig) -> Self {
467 Self { config }
468 }
469
470 pub fn config(&self) -> &NativePolicyConfig {
471 &self.config
472 }
473}
474
475impl NativePolicy for ConfiguredNativePolicy {
476 fn mode(&self) -> &'static str {
477 "configured"
478 }
479
480 fn allow_fs_path(&self, operation: &str, path: &Path) -> Result<(), RuntimeError> {
481 let allowed = if is_fs_write_operation(operation) {
482 &self.config.filesystem.write
483 } else {
484 &self.config.filesystem.read
485 };
486 if allowed.iter().any(|root| path_is_under(path, root)) {
487 return Ok(());
488 }
489 Err(policy_denied("filesystem", operation))
490 }
491
492 fn allow_network(&self, operation: &str, host: &str, port: u16) -> Result<(), RuntimeError> {
493 let rules = if operation.starts_with("websocket_") {
494 &self.config.network.websocket
495 } else {
496 &self.config.network.tcp
497 };
498 if rules.iter().any(|rule| rule.matches(host, port)) {
499 return Ok(());
500 }
501 Err(policy_denied("network", operation))
502 }
503
504 fn allow_sqlite_path(&self, path: &Path) -> Result<(), RuntimeError> {
505 if self
506 .config
507 .sqlite
508 .paths
509 .iter()
510 .any(|root| path_is_under(path, root))
511 {
512 return Ok(());
513 }
514 Err(policy_denied("sqlite", "open"))
515 }
516
517 fn allow_power(&self, operation: &str) -> Result<(), RuntimeError> {
518 let allowed = match operation {
519 "shutdown" => self.config.power.shutdown,
520 "reboot" => self.config.power.reboot,
521 _ => false,
522 };
523 if allowed {
524 return Ok(());
525 }
526 Err(policy_denied("power", operation))
527 }
528}
529
530#[derive(Debug, Default)]
531pub struct DevPermissivePolicy;
532
533impl NativePolicy for DevPermissivePolicy {
534 fn mode(&self) -> &'static str {
535 "dev-permissive"
536 }
537}
538
539pub fn build_native_rpc_server() -> RpcServer {
540 build_native_rpc_server_with_policy(Arc::new(DevPermissivePolicy))
541}
542
543pub fn build_native_rpc_server_with_config(config: NativePolicyConfig) -> RpcServer {
544 build_native_rpc_server_with_policy(Arc::new(ConfiguredNativePolicy::new(config)))
545}
546
547pub fn build_native_rpc_server_with_policy(policy: Arc<dyn NativePolicy>) -> RpcServer {
548 build_native_rpc_server_with_options(NativeRpcServerOptions {
549 policy,
550 ..NativeRpcServerOptions::default()
551 })
552}
553
554pub fn build_native_rpc_server_with_options(options: NativeRpcServerOptions) -> RpcServer {
555 let provider_capabilities = options
556 .providers
557 .iter()
558 .flat_map(|provider| provider.capabilities())
559 .collect::<Vec<_>>();
560 let state = NativeState {
561 policy: options.policy,
562 services: options.services,
563 provider_capabilities: Arc::new(provider_capabilities),
564 ..NativeState::default()
565 };
566 let mut builder = RpcServerBuilder::new();
567 builder.set_security(options.security);
568 builder.set_connection_cleanup_sink(Arc::new(state.clone()));
569 register(
570 &mut builder,
571 RUNTIME_INSTANCE,
572 RUNTIME_SERVICE_GUID,
573 1..=3,
574 NativeHandler::new(state.clone(), ServiceKind::Runtime),
575 );
576 if options.services.fs {
577 register(
578 &mut builder,
579 FS_INSTANCE,
580 FS_SERVICE_GUID,
581 1..=17,
582 NativeHandler::new(state.clone(), ServiceKind::Fs),
583 );
584 }
585 if options.services.archive {
586 register(
587 &mut builder,
588 ARCHIVE_INSTANCE,
589 ARCHIVE_SERVICE_GUID,
590 1..=2,
591 NativeHandler::new(state.clone(), ServiceKind::Archive),
592 );
593 }
594 if options.services.tcp {
595 register(
596 &mut builder,
597 TCP_INSTANCE,
598 TCP_SERVICE_GUID,
599 1..=6,
600 NativeHandler::new(state.clone(), ServiceKind::Tcp),
601 );
602 }
603 if options.services.websocket {
604 register(
605 &mut builder,
606 WEBSOCKET_INSTANCE,
607 WEBSOCKET_SERVICE_GUID,
608 1..=6,
609 NativeHandler::new(state.clone(), ServiceKind::WebSocket),
610 );
611 }
612 if options.services.sqlite {
613 register(
614 &mut builder,
615 SQLITE_INSTANCE,
616 SQLITE_SERVICE_GUID,
617 1..=7,
618 NativeHandler::new(state.clone(), ServiceKind::Sqlite),
619 );
620 }
621 if options.services.system {
622 register(
623 &mut builder,
624 SYSTEM_INSTANCE,
625 SYSTEM_SERVICE_GUID,
626 1..=3,
627 NativeHandler::new(state.clone(), ServiceKind::System),
628 );
629 }
630 for provider in &options.providers {
631 provider.register(&mut builder);
632 }
633 builder.build()
634}
635
636pub fn register_native_named_instance(
637 builder: &mut RpcServerBuilder,
638 name: &str,
639 guid: &str,
640 methods: impl IntoIterator<Item = u32>,
641 handler: Arc<dyn rpc_runtime_server::RpcServiceHandler>,
642) -> InstanceId {
643 let guid = Uuid::parse_str(guid).expect("static service guid");
644 builder.register_named_instance(name, ServiceGuid::new(guid), methods, handler)
645}
646
647fn register(
648 builder: &mut RpcServerBuilder,
649 name: &str,
650 guid: &str,
651 methods: impl IntoIterator<Item = u32>,
652 handler: NativeHandler,
653) {
654 let _ = register_native_named_instance(builder, name, guid, methods, Arc::new(handler));
655}
656
/// Identifies which built-in service a [`NativeHandler`] dispatches to.
#[derive(Clone, Copy)]
enum ServiceKind {
    /// Runtime information and connection resource disposal.
    Runtime,
    /// Filesystem service.
    Fs,
    /// Archive service.
    Archive,
    /// TCP client/server service.
    Tcp,
    /// WebSocket client/server service.
    WebSocket,
    /// SQLite service.
    Sqlite,
    /// System (power) service.
    System,
}
667
668struct NativeHandler {
669 state: NativeState,
670 kind: ServiceKind,
671}
672
673impl NativeHandler {
674 fn new(state: NativeState, kind: ServiceKind) -> Self {
675 Self { state, kind }
676 }
677}
678
679impl rpc_runtime_server::RpcServiceHandler for NativeHandler {
680 fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
681 let state = self.state.clone();
682 let kind = self.kind;
683 Box::pin(async move {
684 match kind {
685 ServiceKind::Runtime => runtime_call(state, ctx, method_id.get(), payload).await,
686 ServiceKind::Fs => filesystem::fs_call(state, ctx, method_id.get(), payload).await,
687 ServiceKind::Archive => {
688 archive::archive_call(state, method_id.get(), payload).await
689 }
690 ServiceKind::Tcp => tcp_call(state, ctx, method_id.get(), payload).await,
691 ServiceKind::WebSocket => {
692 websocket_call(state, ctx, method_id.get(), payload).await
693 }
694 ServiceKind::Sqlite => sqlite_call(state, ctx, method_id.get(), payload).await,
695 ServiceKind::System => system_call(state, method_id.get(), payload).await,
696 }
697 })
698 }
699}
700
701async fn runtime_call(
702 state: NativeState,
703 ctx: RpcCallContext,
704 method: u32,
705 _payload: Value,
706) -> Result<Value, RuntimeError> {
707 match method {
708 1 => Ok(array(vec![
709 string(std::env::consts::OS),
710 string(std::env::consts::ARCH),
711 string(std::env::consts::FAMILY),
712 std::env::current_exe()
713 .ok()
714 .and_then(|path| path.to_str().map(string))
715 .unwrap_or(Value::Nil),
716 string_list(native_capabilities(&state)),
717 string(state.policy.mode()),
718 ])),
719 2 => Ok(string_list(native_capabilities(&state))),
720 3 => {
721 state
722 .dispose_connection_resources(ctx.connection_id())
723 .await;
724 Ok(empty())
725 }
726 _ => Err(method_not_found(method)),
727 }
728}
729
730fn native_capabilities(state: &NativeState) -> Vec<&'static str> {
731 let mut capabilities = state.services.capabilities();
732 capabilities.extend(state.provider_capabilities.iter().copied());
733 capabilities
734}
735
736async fn tcp_call(
737 state: NativeState,
738 ctx: RpcCallContext,
739 method: u32,
740 payload: Value,
741) -> Result<Value, RuntimeError> {
742 match method {
743 1 => {
744 let host = string_arg(&payload, 0)?;
745 let port = u16_arg(&payload, 1)?;
746 state.policy.allow_network("tcp_connect", &host, port)?;
747 let stream = TcpStream::connect((host.as_str(), port))
748 .await
749 .map_err(io_error)?;
750 let id = resource_id("tcp");
751 install_tcp_socket(state, ctx.clone(), id.clone(), None, stream).await;
752 Ok(array(vec![string(id)]))
753 }
754 2 => {
755 let id = string_arg(&payload, 0)?;
756 let bytes = bytes_arg(&payload, 1)?;
757 let writer = socket_writer(&state, &id).await?;
758 writer
759 .lock()
760 .await
761 .write_all(&bytes)
762 .await
763 .map_err(io_error)?;
764 Ok(empty())
765 }
766 3 => {
767 let id = string_arg(&payload, 0)?;
768 let writer = socket_writer(&state, &id).await?;
769 writer.lock().await.shutdown().await.map_err(io_error)?;
770 Ok(empty())
771 }
772 4 => {
773 let id = string_arg(&payload, 0)?;
774 state.tcp_sockets.lock().await.remove(&id);
775 Ok(empty())
776 }
777 5 => {
778 let host = string_arg(&payload, 0)?;
779 let port = u16_arg(&payload, 1)?;
780 state.policy.allow_network("tcp_listen", &host, port)?;
781 let listener = TcpListener::bind((host.as_str(), port))
782 .await
783 .map_err(io_error)?;
784 let addr = listener.local_addr().map_err(io_error)?;
785 let id = resource_id("tcp-server");
786 let server_id = id.clone();
787 let state_for_task = state.clone();
788 let ctx_for_task = ctx.clone();
789 let task = tokio::spawn(async move {
790 tcp_accept_loop(state_for_task, ctx_for_task, server_id, listener).await;
791 });
792 state.tcp_servers.lock().await.insert(
793 id.clone(),
794 TaskResource {
795 owner_connection_id: ctx.connection_id(),
796 task,
797 },
798 );
799 Ok(array(vec![string(id), string(addr.to_string())]))
800 }
801 6 => {
802 let id = string_arg(&payload, 0)?;
803 if let Some(resource) = state.tcp_servers.lock().await.remove(&id) {
804 resource.task.abort();
805 }
806 Ok(empty())
807 }
808 _ => Err(method_not_found(method)),
809 }
810}
811
812async fn install_tcp_socket(
813 state: NativeState,
814 ctx: RpcCallContext,
815 id: String,
816 parent_id: Option<String>,
817 stream: TcpStream,
818) {
819 let (mut reader, writer) = stream.into_split();
820 state.tcp_sockets.lock().await.insert(
821 id.clone(),
822 TcpSocketResource {
823 owner_connection_id: ctx.connection_id(),
824 writer: Arc::new(AsyncMutex::new(writer)),
825 },
826 );
827 let state_for_task = state.clone();
828 tokio::spawn(async move {
829 let mut buf = vec![0_u8; 8192];
830 loop {
831 match reader.read(&mut buf).await {
832 Ok(0) => {
833 let _ = notify_tcp(&ctx, "close", &id, parent_id.as_deref(), None, None).await;
834 state_for_task.tcp_sockets.lock().await.remove(&id);
835 break;
836 }
837 Ok(n) => {
838 let _ = notify_tcp(
839 &ctx,
840 "data",
841 &id,
842 parent_id.as_deref(),
843 Some(&buf[..n]),
844 None,
845 )
846 .await;
847 }
848 Err(error) => {
849 let _ = notify_tcp(
850 &ctx,
851 "error",
852 &id,
853 parent_id.as_deref(),
854 None,
855 Some(error.to_string()),
856 )
857 .await;
858 state_for_task.tcp_sockets.lock().await.remove(&id);
859 break;
860 }
861 }
862 }
863 });
864}
865
866async fn tcp_accept_loop(
867 state: NativeState,
868 ctx: RpcCallContext,
869 server_id: String,
870 listener: TcpListener,
871) {
872 loop {
873 match listener.accept().await {
874 Ok((stream, _)) => {
875 let socket_id = resource_id("tcp");
876 install_tcp_socket(
877 state.clone(),
878 ctx.clone(),
879 socket_id.clone(),
880 Some(server_id.clone()),
881 stream,
882 )
883 .await;
884 let _ =
885 notify_tcp(&ctx, "connection", &socket_id, Some(&server_id), None, None).await;
886 }
887 Err(error) => {
888 let _ = notify_tcp(
889 &ctx,
890 "error",
891 &server_id,
892 None,
893 None,
894 Some(error.to_string()),
895 )
896 .await;
897 break;
898 }
899 }
900 }
901}
902
903async fn socket_writer(
904 state: &NativeState,
905 id: &str,
906) -> Result<Arc<AsyncMutex<OwnedWriteHalf>>, RuntimeError> {
907 state
908 .tcp_sockets
909 .lock()
910 .await
911 .get(id)
912 .map(|resource| resource.writer.clone())
913 .ok_or_else(|| runtime_error(format!("TCP socket `{id}` was not found")))
914}
915
916async fn notify_tcp(
917 ctx: &RpcCallContext,
918 kind: &str,
919 id: &str,
920 parent_id: Option<&str>,
921 data: Option<&[u8]>,
922 message: Option<String>,
923) -> Result<(), RuntimeError> {
924 ctx.notify_bound(
925 EVENT_NOTIFICATION_ID,
926 array(vec![
927 string(kind),
928 string(id),
929 parent_id.map(string).unwrap_or(Value::Nil),
930 data.map(|bytes| Value::Binary(bytes.to_vec()))
931 .unwrap_or(Value::Nil),
932 message.map(string).unwrap_or(Value::Nil),
933 ]),
934 )
935 .await
936}
937
938async fn websocket_call(
939 state: NativeState,
940 ctx: RpcCallContext,
941 method: u32,
942 payload: Value,
943) -> Result<Value, RuntimeError> {
944 match method {
945 1 => {
946 let url = string_arg(&payload, 0)?;
947 let (host, port) = websocket_url_authority(&url)?;
948 state
949 .policy
950 .allow_network("websocket_connect", &host, port)?;
951 let (stream, _) = tokio_tungstenite::connect_async(&url)
952 .await
953 .map_err(ws_error)?;
954 let id = resource_id("ws");
955 install_websocket(state, ctx, id.clone(), None, stream).await;
956 Ok(array(vec![string(id)]))
957 }
958 2 => {
959 let id = string_arg(&payload, 0)?;
960 let text = string_arg(&payload, 1)?;
961 websocket_writer(&state, &id)
962 .await?
963 .lock()
964 .await
965 .send(Message::Text(text.into()))
966 .await
967 .map_err(ws_error)?;
968 Ok(empty())
969 }
970 3 => {
971 let id = string_arg(&payload, 0)?;
972 let bytes = bytes_arg(&payload, 1)?;
973 websocket_writer(&state, &id)
974 .await?
975 .lock()
976 .await
977 .send(Message::Binary(bytes.into()))
978 .await
979 .map_err(ws_error)?;
980 Ok(empty())
981 }
982 4 => {
983 let id = string_arg(&payload, 0)?;
984 if let Some(resource) = state.websocket_sockets.lock().await.remove(&id) {
985 let _ = resource
986 .writer
987 .lock()
988 .await
989 .send(Message::Close(None))
990 .await;
991 }
992 Ok(empty())
993 }
994 5 => {
995 let host = string_arg(&payload, 0)?;
996 let port = u16_arg(&payload, 1)?;
997 state
998 .policy
999 .allow_network("websocket_listen", &host, port)?;
1000 let listener = TcpListener::bind((host.as_str(), port))
1001 .await
1002 .map_err(io_error)?;
1003 let addr = listener.local_addr().map_err(io_error)?;
1004 let id = resource_id("ws-server");
1005 let server_id = id.clone();
1006 let state_for_task = state.clone();
1007 let ctx_for_task = ctx.clone();
1008 let task = tokio::spawn(async move {
1009 websocket_accept_loop(state_for_task, ctx_for_task, server_id, listener).await;
1010 });
1011 state.websocket_servers.lock().await.insert(
1012 id.clone(),
1013 TaskResource {
1014 owner_connection_id: ctx.connection_id(),
1015 task,
1016 },
1017 );
1018 Ok(array(vec![string(id), string(format!("ws://{addr}"))]))
1019 }
1020 6 => {
1021 let id = string_arg(&payload, 0)?;
1022 if let Some(resource) = state.websocket_servers.lock().await.remove(&id) {
1023 resource.task.abort();
1024 }
1025 Ok(empty())
1026 }
1027 _ => Err(method_not_found(method)),
1028 }
1029}
1030
1031async fn install_websocket<S>(
1032 state: NativeState,
1033 ctx: RpcCallContext,
1034 id: String,
1035 parent_id: Option<String>,
1036 stream: tokio_tungstenite::WebSocketStream<S>,
1037) where
1038 tokio_tungstenite::WebSocketStream<S>: futures_util::Stream<Item = Result<Message, tokio_tungstenite::tungstenite::Error>>
1039 + futures_util::Sink<Message, Error = tokio_tungstenite::tungstenite::Error>
1040 + Unpin
1041 + Send
1042 + 'static,
1043{
1044 let (writer, mut reader) = stream.split();
1045 state.websocket_sockets.lock().await.insert(
1046 id.clone(),
1047 WebSocketResource {
1048 owner_connection_id: ctx.connection_id(),
1049 writer: Arc::new(AsyncMutex::new(Box::pin(writer))),
1050 },
1051 );
1052 let state_for_task = state.clone();
1053 tokio::spawn(async move {
1054 while let Some(message) = reader.next().await {
1055 match message {
1056 Ok(Message::Text(text)) => {
1057 let _ = notify_ws(
1058 &ctx,
1059 "text",
1060 &id,
1061 parent_id.as_deref(),
1062 None,
1063 Some(&text),
1064 None,
1065 )
1066 .await;
1067 }
1068 Ok(Message::Binary(bytes)) => {
1069 let _ = notify_ws(
1070 &ctx,
1071 "binary",
1072 &id,
1073 parent_id.as_deref(),
1074 Some(&bytes),
1075 None,
1076 None,
1077 )
1078 .await;
1079 }
1080 Ok(Message::Close(_)) => {
1081 let _ =
1082 notify_ws(&ctx, "close", &id, parent_id.as_deref(), None, None, None).await;
1083 break;
1084 }
1085 Ok(_) => {}
1086 Err(error) => {
1087 let _ = notify_ws(
1088 &ctx,
1089 "error",
1090 &id,
1091 parent_id.as_deref(),
1092 None,
1093 None,
1094 Some(error.to_string()),
1095 )
1096 .await;
1097 break;
1098 }
1099 }
1100 }
1101 state_for_task.websocket_sockets.lock().await.remove(&id);
1102 });
1103}
1104
1105async fn websocket_accept_loop(
1106 state: NativeState,
1107 ctx: RpcCallContext,
1108 server_id: String,
1109 listener: TcpListener,
1110) {
1111 loop {
1112 match listener.accept().await {
1113 Ok((stream, _)) => match tokio_tungstenite::accept_async(stream).await {
1114 Ok(ws) => {
1115 let id = resource_id("ws");
1116 install_websocket(
1117 state.clone(),
1118 ctx.clone(),
1119 id.clone(),
1120 Some(server_id.clone()),
1121 ws,
1122 )
1123 .await;
1124 let _ = notify_ws(&ctx, "connection", &id, Some(&server_id), None, None, None)
1125 .await;
1126 }
1127 Err(error) => {
1128 let _ = notify_ws(
1129 &ctx,
1130 "error",
1131 &server_id,
1132 None,
1133 None,
1134 None,
1135 Some(error.to_string()),
1136 )
1137 .await;
1138 }
1139 },
1140 Err(error) => {
1141 let _ = notify_ws(
1142 &ctx,
1143 "error",
1144 &server_id,
1145 None,
1146 None,
1147 None,
1148 Some(error.to_string()),
1149 )
1150 .await;
1151 break;
1152 }
1153 }
1154 }
1155}
1156
1157async fn websocket_writer(state: &NativeState, id: &str) -> Result<WebSocketWriter, RuntimeError> {
1158 state
1159 .websocket_sockets
1160 .lock()
1161 .await
1162 .get(id)
1163 .map(|resource| resource.writer.clone())
1164 .ok_or_else(|| runtime_error(format!("WebSocket `{id}` was not found")))
1165}
1166
1167async fn notify_ws(
1168 ctx: &RpcCallContext,
1169 kind: &str,
1170 id: &str,
1171 parent_id: Option<&str>,
1172 data: Option<&[u8]>,
1173 text: Option<&str>,
1174 message: Option<String>,
1175) -> Result<(), RuntimeError> {
1176 ctx.notify_bound(
1177 EVENT_NOTIFICATION_ID,
1178 array(vec![
1179 string(kind),
1180 string(id),
1181 parent_id.map(string).unwrap_or(Value::Nil),
1182 data.map(|bytes| Value::Binary(bytes.to_vec()))
1183 .unwrap_or(Value::Nil),
1184 text.map(string).unwrap_or(Value::Nil),
1185 message.map(string).unwrap_or(Value::Nil),
1186 ]),
1187 )
1188 .await
1189}
1190
1191async fn sqlite_call(
1192 state: NativeState,
1193 ctx: RpcCallContext,
1194 method: u32,
1195 payload: Value,
1196) -> Result<Value, RuntimeError> {
1197 match method {
1198 1 => {
1199 let path = string_arg(&payload, 0)?;
1200 let path = PathBuf::from(path);
1201 state.policy.allow_sqlite_path(&path)?;
1202 let db = rusqlite::Connection::open(path).map_err(sqlite_error)?;
1203 let id = resource_id("sqlite");
1204 state.sqlite.lock().expect("sqlite lock").insert(
1205 id.clone(),
1206 SqliteResource {
1207 owner_connection_id: ctx.connection_id(),
1208 connection: db,
1209 },
1210 );
1211 Ok(array(vec![string(id)]))
1212 }
1213 2 => {
1214 let id = string_arg(&payload, 0)?;
1215 state.sqlite.lock().expect("sqlite lock").remove(&id);
1216 Ok(empty())
1217 }
1218 3 => with_db(&state, &payload, |db, payload| {
1219 db.execute_batch(&string_arg(payload, 1)?)
1220 .map_err(sqlite_error)?;
1221 Ok(empty())
1222 }),
1223 4 => with_db(&state, &payload, |db, payload| {
1224 let changes = execute_with_params(db, &string_arg(payload, 1)?, field(payload, 2))?;
1225 Ok(array(vec![
1226 Value::from(changes as u64),
1227 Value::from(db.last_insert_rowid()),
1228 ]))
1229 }),
1230 5 => with_db(&state, &payload, |db, payload| {
1231 let rows = query_rows(db, &string_arg(payload, 1)?, field(payload, 2))?;
1232 Ok(rows.into_iter().next().unwrap_or(Value::Nil))
1233 }),
1234 6 => with_db(&state, &payload, |db, payload| {
1235 Ok(Value::Array(query_rows(
1236 db,
1237 &string_arg(payload, 1)?,
1238 field(payload, 2),
1239 )?))
1240 }),
1241 7 => with_db(&state, &payload, |db, payload| {
1242 db.execute_batch("BEGIN IMMEDIATE").map_err(sqlite_error)?;
1243 let statements = array_arg(payload, 1)?;
1244 let result = (|| {
1245 for statement in statements {
1246 db.execute_batch(&value_string(statement)?)
1247 .map_err(sqlite_error)?;
1248 }
1249 Ok::<_, RuntimeError>(())
1250 })();
1251 if result.is_ok() {
1252 db.execute_batch("COMMIT").map_err(sqlite_error)?;
1253 } else {
1254 let _ = db.execute_batch("ROLLBACK");
1255 }
1256 result.map(|_| empty())
1257 }),
1258 _ => Err(method_not_found(method)),
1259 }
1260}
1261
1262fn with_db(
1263 state: &NativeState,
1264 payload: &Value,
1265 op: impl FnOnce(&rusqlite::Connection, &Value) -> Result<Value, RuntimeError>,
1266) -> Result<Value, RuntimeError> {
1267 let id = string_arg(payload, 0)?;
1268 let guard = state.sqlite.lock().expect("sqlite lock");
1269 let db = &guard
1270 .get(&id)
1271 .ok_or_else(|| runtime_error(format!("SQLite database `{id}` was not found")))?
1272 .connection;
1273 op(db, payload)
1274}
1275
1276fn execute_with_params(
1277 db: &rusqlite::Connection,
1278 sql: &str,
1279 params: Option<&Value>,
1280) -> Result<usize, RuntimeError> {
1281 let values = sqlite_params(params)?;
1282 db.execute(sql, rusqlite::params_from_iter(values))
1283 .map_err(sqlite_error)
1284}
1285
1286fn query_rows(
1287 db: &rusqlite::Connection,
1288 sql: &str,
1289 params: Option<&Value>,
1290) -> Result<Vec<Value>, RuntimeError> {
1291 let values = sqlite_params(params)?;
1292 let mut statement = db.prepare(sql).map_err(sqlite_error)?;
1293 let names: Vec<String> = statement
1294 .column_names()
1295 .into_iter()
1296 .map(ToOwned::to_owned)
1297 .collect();
1298 let rows = statement
1299 .query_map(rusqlite::params_from_iter(values), |row| {
1300 let mut columns = Vec::new();
1301 for (index, name) in names.iter().enumerate() {
1302 columns.push(array(vec![string(name), sqlite_value(row.get_ref(index)?)]));
1303 }
1304 Ok(array(vec![Value::Array(columns)]))
1305 })
1306 .map_err(sqlite_error)?;
1307 rows.collect::<Result<Vec<_>, _>>().map_err(sqlite_error)
1308}
1309
1310fn sqlite_params(input: Option<&Value>) -> Result<Vec<rusqlite::types::Value>, RuntimeError> {
1311 let Some(Value::Array(values)) = input else {
1312 return Ok(Vec::new());
1313 };
1314 values.iter().map(sqlite_param).collect()
1315}
1316
1317fn sqlite_param(value: &Value) -> Result<rusqlite::types::Value, RuntimeError> {
1318 if let Value::Array(fields) = value {
1319 let kind = string_arg(value, 0)?;
1320 return match kind.as_str() {
1321 "null" => Ok(rusqlite::types::Value::Null),
1322 "integer" => Ok(rusqlite::types::Value::Integer(i64_arg_from_fields(
1323 fields, 1,
1324 )?)),
1325 "real" => Ok(rusqlite::types::Value::Real(f64_arg_from_fields(
1326 fields, 2,
1327 )?)),
1328 "text" => Ok(rusqlite::types::Value::Text(string_arg(value, 3)?)),
1329 "blob" => Ok(rusqlite::types::Value::Blob(bytes_arg(value, 4)?)),
1330 "boolean" => Ok(rusqlite::types::Value::Integer(i64::from(
1331 bool_arg(value, 5).unwrap_or(false),
1332 ))),
1333 _ => Err(decode_error(format!(
1334 "unsupported SQLite tagged parameter kind `{kind}`"
1335 ))),
1336 };
1337 }
1338 match value {
1339 Value::Nil => Ok(rusqlite::types::Value::Null),
1340 Value::Boolean(value) => Ok(rusqlite::types::Value::Integer(i64::from(*value))),
1341 Value::Integer(value) => Ok(rusqlite::types::Value::Integer(
1342 value
1343 .as_i64()
1344 .ok_or_else(|| decode_error("integer out of range"))?,
1345 )),
1346 Value::F32(value) => Ok(rusqlite::types::Value::Real((*value).into())),
1347 Value::F64(value) => Ok(rusqlite::types::Value::Real(*value)),
1348 Value::String(value) => Ok(rusqlite::types::Value::Text(
1349 value.as_str().unwrap_or_default().to_string(),
1350 )),
1351 Value::Binary(value) => Ok(rusqlite::types::Value::Blob(value.clone())),
1352 _ => Err(decode_error("unsupported SQLite parameter")),
1353 }
1354}
1355
1356fn sqlite_value(value: rusqlite::types::ValueRef<'_>) -> Value {
1357 match value {
1358 rusqlite::types::ValueRef::Null => array(vec![
1359 string("null"),
1360 Value::Nil,
1361 Value::Nil,
1362 Value::Nil,
1363 Value::Nil,
1364 Value::Nil,
1365 ]),
1366 rusqlite::types::ValueRef::Integer(value) => array(vec![
1367 string("integer"),
1368 Value::from(value),
1369 Value::Nil,
1370 Value::Nil,
1371 Value::Nil,
1372 Value::Nil,
1373 ]),
1374 rusqlite::types::ValueRef::Real(value) => array(vec![
1375 string("real"),
1376 Value::Nil,
1377 Value::from(value),
1378 Value::Nil,
1379 Value::Nil,
1380 Value::Nil,
1381 ]),
1382 rusqlite::types::ValueRef::Text(value) => array(vec![
1383 string("text"),
1384 Value::Nil,
1385 Value::Nil,
1386 string(String::from_utf8_lossy(value)),
1387 Value::Nil,
1388 Value::Nil,
1389 ]),
1390 rusqlite::types::ValueRef::Blob(value) => array(vec![
1391 string("blob"),
1392 Value::Nil,
1393 Value::Nil,
1394 Value::Nil,
1395 Value::Binary(value.to_vec()),
1396 Value::Nil,
1397 ]),
1398 }
1399}
1400
1401async fn system_call(
1402 state: NativeState,
1403 method: u32,
1404 payload: Value,
1405) -> Result<Value, RuntimeError> {
1406 match method {
1407 1 => Ok(string_list(vec!["shutdown", "reboot"])),
1408 2 => {
1409 state.policy.allow_power("shutdown")?;
1410 run_power_command("shutdown", payload).await
1411 }
1412 3 => {
1413 state.policy.allow_power("reboot")?;
1414 run_power_command("reboot", payload).await
1415 }
1416 _ => Err(method_not_found(method)),
1417 }
1418}
1419
1420async fn run_power_command(kind: &str, payload: Value) -> Result<Value, RuntimeError> {
1421 let delay = u64_arg(&payload, 0).unwrap_or(0);
1422 let mut command = if cfg!(target_os = "windows") {
1423 let mut command = tokio::process::Command::new("shutdown");
1424 command.arg(if kind == "reboot" { "/r" } else { "/s" });
1425 command.arg("/t").arg(delay.to_string());
1426 command
1427 } else if cfg!(target_os = "macos") {
1428 let mut command = tokio::process::Command::new("osascript");
1429 command.arg("-e").arg(if kind == "reboot" {
1430 "tell app \"System Events\" to restart"
1431 } else {
1432 "tell app \"System Events\" to shut down"
1433 });
1434 command
1435 } else {
1436 let mut command = tokio::process::Command::new("systemctl");
1437 if delay > 0 {
1438 return Err(runtime_error(
1439 "delayed shutdown/reboot is not supported by the Linux provider",
1440 ));
1441 }
1442 command.arg(if kind == "reboot" {
1443 "reboot"
1444 } else {
1445 "poweroff"
1446 });
1447 command
1448 };
1449 command
1450 .stdin(Stdio::null())
1451 .stdout(Stdio::null())
1452 .stderr(Stdio::null());
1453 command.spawn().map_err(io_error)?;
1454 Ok(empty())
1455}
1456
1457fn empty() -> Value {
1458 Value::Array(Vec::new())
1459}
1460
1461fn array(values: Vec<Value>) -> Value {
1462 Value::Array(values)
1463}
1464
1465fn string(value: impl AsRef<str>) -> Value {
1466 Value::String(value.as_ref().to_string().into())
1467}
1468
1469fn string_list(values: Vec<impl AsRef<str>>) -> Value {
1470 Value::Array(values.into_iter().map(string).collect())
1471}
1472
1473fn field(value: &Value, index: usize) -> Option<&Value> {
1474 match value {
1475 Value::Array(values) => values.get(index),
1476 _ => None,
1477 }
1478}
1479
1480fn array_arg(value: &Value, index: usize) -> Result<&[Value], RuntimeError> {
1481 match field(value, index) {
1482 Some(Value::Array(values)) => Ok(values),
1483 _ => Err(decode_error(format!("field {index} must be an array"))),
1484 }
1485}
1486
1487fn string_arg(value: &Value, index: usize) -> Result<String, RuntimeError> {
1488 field(value, index)
1489 .map(value_string)
1490 .transpose()?
1491 .ok_or_else(|| decode_error(format!("field {index} must be a string")))
1492}
1493
1494fn value_string(value: &Value) -> Result<String, RuntimeError> {
1495 match value {
1496 Value::String(value) => Ok(value.as_str().unwrap_or_default().to_string()),
1497 _ => Err(decode_error("expected string")),
1498 }
1499}
1500
1501fn path_arg(value: &Value, index: usize) -> Result<PathBuf, RuntimeError> {
1502 Ok(PathBuf::from(string_arg(value, index)?))
1503}
1504
1505fn bytes_arg(value: &Value, index: usize) -> Result<Vec<u8>, RuntimeError> {
1506 match field(value, index) {
1507 Some(Value::Binary(value)) => Ok(value.clone()),
1508 Some(Value::String(value)) => Ok(value.as_str().unwrap_or_default().as_bytes().to_vec()),
1509 _ => Err(decode_error(format!("field {index} must be bytes"))),
1510 }
1511}
1512
1513fn bool_arg(value: &Value, index: usize) -> Option<bool> {
1514 match field(value, index) {
1515 Some(Value::Boolean(value)) => Some(*value),
1516 _ => None,
1517 }
1518}
1519
1520fn bool_field(fields: &[Value], index: usize) -> Option<bool> {
1521 match fields.get(index) {
1522 Some(Value::Boolean(value)) => Some(*value),
1523 _ => None,
1524 }
1525}
1526
1527fn u16_arg(value: &Value, index: usize) -> Result<u16, RuntimeError> {
1528 let value = u64_arg(value, index)?;
1529 u16::try_from(value).map_err(|_| decode_error(format!("field {index} is out of u16 range")))
1530}
1531
1532fn i64_arg(value: &Value, index: usize) -> Result<i64, RuntimeError> {
1533 match field(value, index) {
1534 Some(Value::Integer(value)) => value
1535 .as_i64()
1536 .ok_or_else(|| decode_error(format!("field {index} must be an integer"))),
1537 _ => Err(decode_error(format!("field {index} must be an integer"))),
1538 }
1539}
1540
1541fn u64_arg(value: &Value, index: usize) -> Result<u64, RuntimeError> {
1542 match field(value, index) {
1543 Some(Value::Integer(value)) => value
1544 .as_u64()
1545 .ok_or_else(|| decode_error(format!("field {index} must be a non-negative integer"))),
1546 _ => Err(decode_error(format!("field {index} must be an integer"))),
1547 }
1548}
1549
1550fn i64_arg_from_fields(fields: &[Value], index: usize) -> Result<i64, RuntimeError> {
1551 match fields.get(index) {
1552 Some(Value::Integer(value)) => value
1553 .as_i64()
1554 .ok_or_else(|| decode_error(format!("field {index} must be an integer"))),
1555 _ => Err(decode_error(format!("field {index} must be an integer"))),
1556 }
1557}
1558
1559fn f64_arg_from_fields(fields: &[Value], index: usize) -> Result<f64, RuntimeError> {
1560 match fields.get(index) {
1561 Some(Value::F64(value)) => Ok(*value),
1562 Some(Value::F32(value)) => Ok((*value).into()),
1563 Some(Value::Integer(value)) => value
1564 .as_f64()
1565 .ok_or_else(|| decode_error(format!("field {index} must be a number"))),
1566 _ => Err(decode_error(format!("field {index} must be a number"))),
1567 }
1568}
1569
1570fn resource_id(prefix: &str) -> String {
1571 format!("{prefix}-{}", Uuid::new_v4())
1572}
1573
1574async fn open_file(
1575 state: &NativeState,
1576 id: &str,
1577) -> Result<Arc<AsyncMutex<tokio::fs::File>>, RuntimeError> {
1578 state
1579 .open_files
1580 .lock()
1581 .await
1582 .get(id)
1583 .map(|resource| resource.file.clone())
1584 .ok_or_else(|| runtime_error(format!("file resource `{id}` was not found")))
1585}
1586
/// Reports whether `operation` names a filesystem operation that modifies
/// data (as opposed to a read-only one). The comparison is exact and
/// case-sensitive.
fn is_fs_write_operation(operation: &str) -> bool {
    const WRITE_OPERATIONS: [&str; 9] = [
        "write_file",
        "append_file",
        "mkdir",
        "remove",
        "rename_to",
        "copy_to",
        "open_file_write",
        "archive_zip_write",
        "archive_unzip_write",
    ];
    WRITE_OPERATIONS.contains(&operation)
}
1601
1602fn path_is_under(path: &Path, root: &Path) -> bool {
1603 let Ok(path) = policy_path(path) else {
1604 return false;
1605 };
1606 let Ok(root) = policy_path(root) else {
1607 return false;
1608 };
1609 path == root || path.starts_with(root)
1610}
1611
/// Resolves `path` to the absolute form used for policy comparisons.
///
/// * An existing path is fully canonicalized.
/// * Otherwise the path is made absolute (relative to the current directory),
///   its deepest existing ancestor is canonicalized, and the remaining
///   components, which do not exist on disk, are applied lexically: `.` is
///   skipped and `..` removes the previous component.
///
/// Normalizing the missing tail matters for containment checks: without it
/// a path such as `root/missing/../../elsewhere` keeps `root` as a textual
/// prefix although it points outside `root`.
///
/// # Errors
/// Returns the I/O error from reading the current directory or from
/// canonicalizing an existing ancestor.
fn policy_path(path: &Path) -> Result<PathBuf, std::io::Error> {
    use std::path::Component;

    if path.exists() {
        return path.canonicalize();
    }
    let absolute = if path.is_absolute() {
        path.to_path_buf()
    } else {
        std::env::current_dir()?.join(path)
    };
    let components: Vec<Component<'_>> = absolute.components().collect();

    // Find the longest leading run of components that exists on disk and let
    // the OS resolve it (symlinks and `..` included).
    let mut resolved = PathBuf::new();
    let mut tail_start = 0;
    for end in (1..=components.len()).rev() {
        let prefix: PathBuf = components[..end].iter().collect();
        if prefix.exists() {
            resolved = prefix.canonicalize()?;
            tail_start = end;
            break;
        }
    }

    // Everything after that prefix does not exist, so it cannot contain
    // symlinks and can safely be normalized lexically.
    for component in &components[tail_start..] {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                resolved.pop();
            }
            other => resolved.push(other.as_os_str()),
        }
    }
    Ok(resolved)
}
1634
1635fn websocket_url_authority(url: &str) -> Result<(String, u16), RuntimeError> {
1636 let (default_port, rest) = if let Some(rest) = url.strip_prefix("ws://") {
1637 (80, rest)
1638 } else if let Some(rest) = url.strip_prefix("wss://") {
1639 (443, rest)
1640 } else {
1641 return Err(decode_error(
1642 "websocket URL must start with ws:// or wss://",
1643 ));
1644 };
1645 let authority = rest
1646 .split(['/', '?', '#'])
1647 .next()
1648 .filter(|value| !value.is_empty())
1649 .ok_or_else(|| decode_error("websocket URL must include a host"))?;
1650 let (host, port) = if authority.starts_with('[') {
1651 let end = authority
1652 .find(']')
1653 .ok_or_else(|| decode_error("websocket URL has invalid IPv6 host"))?;
1654 let host = authority[1..end].to_string();
1655 let port = authority[end + 1..]
1656 .strip_prefix(':')
1657 .map(parse_port)
1658 .transpose()?
1659 .unwrap_or(default_port);
1660 (host, port)
1661 } else {
1662 match authority.rsplit_once(':') {
1663 Some((host, port)) => (host.to_string(), parse_port(port)?),
1664 None => (authority.to_string(), default_port),
1665 }
1666 };
1667 if host.is_empty() {
1668 return Err(decode_error("websocket URL must include a host"));
1669 }
1670 Ok((host, port))
1671}
1672
1673fn parse_port(value: &str) -> Result<u16, RuntimeError> {
1674 value
1675 .parse()
1676 .map_err(|_| decode_error("websocket URL port must be a valid u16"))
1677}
1678
1679fn method_not_found(method: u32) -> RuntimeError {
1680 RuntimeError::runtime(
1681 RuntimeErrorCode::MethodNotFound,
1682 format!("native method id `{method}` was not found"),
1683 )
1684}
1685
1686fn decode_error(message: impl Into<String>) -> RuntimeError {
1687 RuntimeError::runtime(RuntimeErrorCode::PayloadDecodeFailed, message)
1688}
1689
1690fn runtime_error(message: impl Into<String>) -> RuntimeError {
1691 RuntimeError::runtime(RuntimeErrorCode::InternalRuntimeError, message)
1692}
1693
1694fn policy_denied(scope: &str, operation: &str) -> RuntimeError {
1695 RuntimeError::runtime(
1696 RuntimeErrorCode::AccessDenied,
1697 format!("native policy denied {scope} operation `{operation}`"),
1698 )
1699}
1700
1701fn io_error(error: std::io::Error) -> RuntimeError {
1702 runtime_error(error.to_string())
1703}
1704
1705fn ws_error(error: tokio_tungstenite::tungstenite::Error) -> RuntimeError {
1706 runtime_error(error.to_string())
1707}
1708
1709fn sqlite_error(error: rusqlite::Error) -> RuntimeError {
1710 runtime_error(error.to_string())
1711}
1712
1713#[allow(dead_code)]
1714fn parse_addr(host: &str, port: u16) -> Result<SocketAddr, RuntimeError> {
1715 format!("{host}:{port}")
1716 .parse()
1717 .map_err(|error| runtime_error(format!("invalid socket address: {error}")))
1718}