vrchat_osc/lib.rs
1mod oscquery;
2mod mdns;
3mod fetch;
4
5pub use oscquery::*;
6
7use crate::fetch::fetch;
8
9use std::{
10 collections::HashMap,
11 net::{Ipv4Addr, SocketAddr, SocketAddrV4},
12 sync::Arc,
13};
14use wildmatch::WildMatch;
15use convert_case::{Case, Casing};
16use futures::{stream, StreamExt};
17use hickory_proto::rr::Name;
18use oscquery::models::{HostInfo, OscNode, OscRootNode};
19use rosc::OscPacket;
20use tokio::{
21 net::UdpSocket,
22 sync::{mpsc, RwLock},
23 task::JoinHandle,
24};
25
26/// Defines the possible errors that can occur within the VRChatOSC library.
27#[derive(thiserror::Error, Debug)]
28pub enum Error {
29 #[error("OSC error: {0}")]
30 OscError(#[from] rosc::OscError),
31 #[error("OSCQuery error: {0}")]
32 OscQueryError(#[from] oscquery::Error),
33 #[error("mDNS error: {0}")]
34 MdnsError(#[from] mdns::Error),
35 #[error("Hickory DNS protocol error: {0}")]
36 HickoryError(#[from] hickory_proto::ProtoError),
37 #[error("I/O error: {0}")]
38 IoError(#[from] std::io::Error),
39 #[error("Fetch error: {0}")]
40 FetchError(#[from] fetch::Error),
41}
42
43/// Holds handles related to a registered OSC service.
44struct ServiceHandle {
45 /// Join handle for the OSC listening task.
46 osc: JoinHandle<()>,
47 /// The OSCQuery server instance.
48 osc_query: OscQuery,
49}
50
51pub enum ServiceType {
52 /// OSC service type.
53 Osc(Name, SocketAddr),
54 /// OSCQuery service type.
55 OscQuery(Name, SocketAddr),
56}
57
58/// Main struct for managing VRChat OSC services, discovery, and communication.
59pub struct VRChatOSC {
60 /// Socket for sending OSC messages.
61 send_socket: UdpSocket,
62 /// mDNS client instance for service discovery.
63 mdns: mdns::Mdns,
64 /// Stores registered service handles, mapping service name to its handle.
65 service_handles: Arc<RwLock<HashMap<String, ServiceHandle>>>,
66 /// Callback function to be executed when a new mDNS service is discovered.
67 /// The Name is the service instance name, and SocketAddr is its resolved address.
68 on_service_discovered_callback: Arc<RwLock<Option<Arc<dyn Fn(ServiceType) + Send + Sync + 'static>>>>,
69}
70
71impl VRChatOSC {
72 /// Creates a new `VRChatOSC` instance.
73 /// Initializes mDNS, sets up service discovery, and starts a listener task for mDNS service notifications.
74 pub async fn new() -> Result<Arc<VRChatOSC>, Error> {
75 let socket = UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)).await?;
76
77 // Create an mpsc channel for notifying about discovered mDNS services.
78 let (discover_notifier_tx, mut discover_notifier_rx) = mpsc::channel(8);
79
80 // Initialize the mDNS client, passing the sender part of the notification channel.
81 let mdns_client = mdns::Mdns::new(discover_notifier_tx).await?;
82
83 // Start following OSC services and OSCQuery JSON services on the local network.
84 let _ = mdns_client.follow(Name::from_ascii("_osc._udp.local.")?).await;
85 let _ = mdns_client.follow(Name::from_ascii("_oscjson._tcp.local.")?).await;
86
87 // Prepare a shared storage for the service discovered callback.
88 let on_service_discovered_callback = Arc::new(RwLock::new(None::<Arc<dyn Fn(ServiceType) + Send + Sync + 'static>>));
89 let callback_arc_clone = on_service_discovered_callback.clone();
90
91 // Spawn a new asynchronous task to listen for service discovery notifications.
92 // This task will own the `discover_notifier_rx` (receiver end of the mpsc channel).
93 tokio::spawn(async move {
94 // Continuously try to receive messages from the discovery notification channel.
95 loop {
96 if let Some((service_name, socket_addr)) = discover_notifier_rx.recv().await {
97 let callback_guard = callback_arc_clone.read().await;
98 // If a callback is registered, invoke it with the service name and address.
99 if let Some(callback) = callback_guard.as_ref() {
100 if service_name.trim_to(3).to_ascii() == "_osc._udp.local." {
101 callback(ServiceType::Osc(service_name.clone(), socket_addr));
102 } else if service_name.trim_to(3).to_ascii() == "_oscjson._tcp.local." {
103 callback(ServiceType::OscQuery(service_name.clone(), socket_addr));
104 }
105 }
106 }
107 }
108 });
109
110 Ok(Arc::new(VRChatOSC {
111 send_socket: socket,
112 mdns: mdns_client,
113 service_handles: Arc::new(RwLock::new(HashMap::new())),
114 on_service_discovered_callback,
115 }))
116 }
117
118 /// Registers a callback function to be invoked when an mDNS service is discovered.
119 ///
120 /// # Arguments
121 /// * `callback` - A function or closure that takes the service `Name` and `SocketAddr`
122 /// as arguments. It must be `Send + Sync + 'static`.
123 pub async fn on_connect<F>(&self, callback: F)
124 where
125 F: Fn(ServiceType) + Send + Sync + 'static,
126 {
127 let mut callback_guard = self.on_service_discovered_callback.write().await;
128 *callback_guard = Some(Arc::new(callback));
129 }
130
131 /// Registers a new OSC service with the local mDNS daemon and starts listening for OSC messages.
132 ///
133 /// # Arguments
134 /// * `service_name` - The name of the service to register (e.g., "MyAppOSC").
135 /// * `parameters` - The root node of the OSC address space for this service.
136 /// * `handler` - A function that will be called when an OSC packet is received for this service.
137 /// It must be `Fn(OscPacket) + Send + 'static`.
138 pub async fn register<F>(&self, service_name: &str, parameters: OscRootNode, handler: F) -> Result<(), Error>
139 where
140 F: Fn(OscPacket) + Send + 'static,
141 {
142 // Start OSC server (UDP listener)
143 // Bind to localhost on an OS-assigned port.
144 let socket = UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?;
145 let osc_local_addr = socket.local_addr()?; // Get the actual address it bound to.
146
147 // Spawn a task to handle incoming OSC packets.
148 let osc_handle = tokio::spawn(async move {
149 let mut buf = [0; rosc::decoder::MTU]; // Buffer for receiving OSC packets.
150 loop {
151 // Wait to receive data on the socket.
152 match socket.recv_from(&mut buf).await {
153 Ok((len, addr)) => {
154 // Decode the received UDP data into an OSC packet.
155 if let Ok((_, packet)) = rosc::decoder::decode_udp(&buf[..len]) {
156 handler(packet); // Call the provided handler with the decoded packet.
157 } else {
158 log::warn!("Failed to decode OSC packet from {}", addr);
159 }
160 }
161 Err(e) => {
162 if e.kind() == std::io::ErrorKind::ConnectionReset || e.kind() == std::io::ErrorKind::BrokenPipe {
163 log::warn!("Socket connection error ({}). Task for {:?} might need to be restarted or interface is down.", e, socket.local_addr().ok());
164 break;
165 } else {
166 log::error!("Failed to receive data on mDNS socket {:?}: {}", socket.local_addr().ok(), e);
167 }
168 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
169 continue;
170 }
171 }
172 }
173 });
174
175 // Start OSCQuery server (HTTP server)
176 let host_info = HostInfo::new(
177 service_name.to_string(),
178 osc_local_addr.ip(), // Use the IP of the OSC server.
179 osc_local_addr.port(), // Use the port of the OSC server.
180 );
181 let mut osc_query = OscQuery::new(host_info, parameters);
182 // Serve OSCQuery on localhost with an OS-assigned port.
183 let osc_query_local_addr = osc_query.serve(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?;
184
185 // Create mDNS service announcements.
186 let service_name_upper_camel = service_name.to_case(Case::UpperCamel); // Convert service name case.
187
188 // Register the OSC and OSCQuery services with mDNS.
189 self.mdns.register(
190 Name::from_ascii(format!("{}._osc._udp.local.", service_name_upper_camel))?,
191 osc_local_addr,
192 ).await?;
193 self.mdns.register(
194 Name::from_ascii(format!("{}._oscjson._tcp.local.", service_name_upper_camel))?,
195 osc_query_local_addr,
196 ).await?;
197
198 // Save service handles for later management (e.g., unregistering).
199 let mut handles = self.service_handles.write().await;
200 handles.insert(service_name.to_string(), ServiceHandle {
201 osc: osc_handle,
202 osc_query,
203 });
204 Ok(())
205 }
206
207 /// Unregisters an OSC service.
208 /// Stops the OSC and OSCQuery servers and removes mDNS announcements.
209 ///
210 /// # Arguments
211 /// * `service_name` - The name of the service to unregister.
212 pub async fn unregister(&self, service_name: &str) -> Result<(), Error> {
213 let service_name_upper_camel = service_name.to_case(Case::UpperCamel);
214 // Remove the service from our tracking.
215 let mut service_handles_map = self.service_handles.write().await;
216 if let Some(mut service_handle_entry) = service_handles_map.remove(service_name) {
217 // Unregister from mDNS.
218 self.mdns.unregister(Name::from_ascii(format!("{}._osc._udp.local.", service_name_upper_camel))?).await?;
219 self.mdns.unregister(Name::from_ascii(format!("{}._oscjson._tcp.local.", service_name_upper_camel))?).await?;
220
221 // Stop the associated tasks/servers.
222 service_handle_entry.osc.abort(); // Abort the OSC listening task.
223 service_handle_entry.osc_query.shutdown(); // Gracefully shutdown the OSCQuery server.
224 }
225 Ok(())
226 }
227
228 /// Sends an OSC packet to services matching a given pattern.
229 ///
230 /// # Arguments
231 /// * `packet` - The `OscPacket` to send.
232 /// * `to` - A glob pattern (e.g., "VRChat-Client-*") to match against service names.
233 /// This matches against the service instance name found via mDNS.
234 pub async fn send(&self, packet: OscPacket, to: &str) -> Result<(), Error> {
235 // Find services matching the pattern. The matching logic is within `find_service`.
236 // The closure provided to `find_service` determines if a service (by its Name) matches.
237 let services = self.mdns.find_service(|name, _| {
238 // `WildMatch` performs glob-style pattern matching.
239 WildMatch::new(&format!("{}._osc._udp.local.", to)).matches(&name.to_ascii())
240 }).await;
241
242 if services.is_empty() {
243 log::warn!("No mDNS services found matching the expression: {}", to);
244 return Ok(());
245 }
246
247 // Encode the OSC packet into bytes.
248 let msg_buf = rosc::encoder::encode(&packet)?;
249 // Send the packet to all found services.
250 for (_, addr) in services {
251 self.send_socket.send_to(&msg_buf, addr).await?;
252 }
253
254 Ok(())
255 }
256
257 /// Sends an OSC packet to a specific socket address.
258 ///
259 /// # Arguments
260 /// * `packet` - The `OscPacket` to send.
261 /// * `addr` - The `SocketAddr` to send the packet to.
262 pub async fn send_to_addr(&self, packet: OscPacket, addr: SocketAddr) -> Result<(), Error> {
263 let msg_buf = rosc::encoder::encode(&packet)?;
264 self.send_socket.send_to(&msg_buf, addr).await?;
265 Ok(())
266 }
267
268 /// Retrieves a specific OSC parameter (node) from services matching a pattern.
269 ///
270 /// # Arguments
271 /// * `method` - The OSC path of the parameter to fetch (e.g., "/avatar/parameters/SomeParam").
272 /// * `from` - A glob pattern (e.g., "VRChat-Client-*") to match against service names.
273 /// This matches against the service instance name found via mDNS.
274 ///
275 /// # Returns
276 /// A `Vec` of tuples, where each tuple contains the service `Name` and the fetched `OscNode`.
277 /// Returns an empty Vec if no services match or if fetching fails for all matched services.
278 pub async fn get_parameter(&self, method: &str, from: &str) -> Result<Vec<(Name, OscNode)>, Error> {
279 // Find services matching the pattern. The matching logic is within `find_service`.
280 // The closure provided to `find_service` determines if a service (by its Name) matches.
281 let services = self.mdns.find_service(|name, _| {
282 WildMatch::new(&format!("{}._oscjson._tcp.local.", from)).matches(&name.to_ascii())
283 }).await;
284
285 if services.is_empty() {
286 log::warn!("No mDNS services found for get_parameter matching expression: {}", from);
287 return Ok(Vec::new());
288 }
289
290 // Asynchronously fetch the parameter from all matching services.
291 // `stream::iter` creates a stream from the services.
292 // `map` transforms each service into a future that fetches the parameter.
293 // `buffer_unordered(3)` allows up to 3 fetches to run concurrently.
294 // `filter_map` discards any fetches that resulted in an error.
295 // `collect` gathers all successful results into a Vec.
296 let params = stream::iter(services)
297 .map(|(name, addr)| async move {
298 fetch::<_, OscNode>(addr, method).await.map(|(param, _)| (name.clone(), param))
299 })
300 .buffer_unordered(3)
301 .filter_map(|res| async {
302 if let Err(e) = &res {
303 log::warn!("Failed to fetch parameter: {:?}", e);
304 }
305 res.ok()
306 })
307 .collect::<Vec<_>>()
308 .await;
309
310 Ok(params)
311 }
312
313 /// Retrieves a specific OSC parameter (node) from a specific OSCQuery service address.
314 ///
315 /// # Arguments
316 /// * `method` - The OSC path of the parameter to fetch (e.g., "/avatar/parameters/SomeParam").
317 /// * `addr` - The `SocketAddr` of the OSCQuery service.
318 ///
319 /// # Returns
320 /// The fetched `OscNode`.
321 pub async fn get_parameter_from_addr(&self, method: &str, addr: SocketAddr) -> Result<OscNode, Error> {
322 let (param, _url) = fetch::<_, OscNode>(addr, method).await?;
323 Ok(param)
324 }
325
326 /// Shuts down all registered services and cleans up resources.
327 /// This method should be called before the VRChatOSC instance is dropped
328 /// to ensure graceful shutdown of asynchronous tasks and network services.
329 pub async fn shutdown(&self) -> Result<(), Error> {
330 let mut service_handles_map = self.service_handles.write().await;
331 let service_names: Vec<String> = service_handles_map.keys().cloned().collect();
332
333 for name in service_names {
334 if let Some(mut handle) = service_handles_map.remove(&name) {
335 let service_name_upper_camel = name.to_case(Case::UpperCamel);
336 // Attempt to unregister from mDNS. Errors are logged but not propagated to allow other services to shut down.
337 if let Err(e) = self.mdns.unregister(Name::from_ascii(format!("{}._osc._udp.local.", service_name_upper_camel))?).await {
338 log::error!("Failed to unregister OSC for {}: {}", name, e);
339 }
340 if let Err(e) = self.mdns.unregister(Name::from_ascii(format!("{}._oscjson._tcp.local.", service_name_upper_camel))?).await {
341 log::error!("Failed to unregister OSCQuery for {}: {}", name, e);
342 }
343
344 handle.osc.abort();
345 handle.osc_query.shutdown();
346 }
347 }
348 Ok(())
349 }
350}
351
352impl Drop for VRChatOSC {
353 fn drop(&mut self) {
354 // Best-effort synchronous cleanup.
355 // For robust cleanup, especially of async tasks and network resources,
356 // the asynchronous `shutdown` method should be called explicitly.
357 if let Ok(mut handles) = self.service_handles.try_write() {
358 let service_names: Vec<String> = handles.keys().cloned().collect();
359 for name in service_names {
360 if let Some(mut service_handle) = handles.remove(&name) {
361 // mDNS unregistration cannot be reliably called here due to async and potential blocking.
362 service_handle.osc.abort();
363 // OscQuery::shutdown() is assumed to be synchronous or non-blocking here.
364 // If it's async, it cannot be .await-ed in drop.
365 service_handle.osc_query.shutdown();
366 }
367 }
368 } else {
369 // This might happen if the lock is poisoned or contended in a way not suitable for drop.
370 // In a real application, this should be logged or handled appropriately.
371 // Using log::error! or eprintln! here might be appropriate.
372 // For now, we acknowledge that proper async shutdown is preferred.
373 if !std::thread::panicking() { // Avoid double panic if already panicking
374 log::warn!("VRChatOSC: Could not acquire lock on service_handles during drop. Explicitly call shutdown() for robust cleanup.");
375 }
376 }
377 }
378}