vrchat_osc/
lib.rs

1mod oscquery;
2mod mdns;
3mod fetch;
4
5pub use oscquery::*;
6
7use crate::fetch::fetch;
8
9use std::{
10    collections::HashMap,
11    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
12    sync::Arc,
13};
14use wildmatch::WildMatch;
15use convert_case::{Case, Casing};
16use futures::{stream, StreamExt};
17use hickory_proto::rr::Name;
18use oscquery::models::{HostInfo, OscNode, OscRootNode};
19use rosc::OscPacket;
20use tokio::{
21    net::UdpSocket,
22    sync::{mpsc, RwLock},
23    task::JoinHandle,
24};
25
26/// Defines the possible errors that can occur within the VRChatOSC library.
27#[derive(thiserror::Error, Debug)]
28pub enum Error {
29    #[error("OSC error: {0}")]
30    OscError(#[from] rosc::OscError),
31    #[error("OSCQuery error: {0}")]
32    OscQueryError(#[from] oscquery::Error),
33    #[error("mDNS error: {0}")]
34    MdnsError(#[from] mdns::Error),
35    #[error("Hickory DNS protocol error: {0}")]
36    HickoryError(#[from] hickory_proto::ProtoError),
37    #[error("I/O error: {0}")]
38    IoError(#[from] std::io::Error),
39    #[error("Fetch error: {0}")]
40    FetchError(#[from] fetch::Error),
41}
42
43/// Holds handles related to a registered OSC service.
44struct ServiceHandle {
45    /// Join handle for the OSC listening task.
46    osc: JoinHandle<()>,
47    /// The OSCQuery server instance.
48    osc_query: OscQuery,
49}
50
51pub enum ServiceType {
52    /// OSC service type.
53    Osc(Name, SocketAddr),
54    /// OSCQuery service type.
55    OscQuery(Name, SocketAddr),
56}
57
58/// Main struct for managing VRChat OSC services, discovery, and communication.
59pub struct VRChatOSC {
60    /// Socket for sending OSC messages.
61    send_socket: UdpSocket,
62    /// mDNS client instance for service discovery.
63    mdns: mdns::Mdns,
64    /// Stores registered service handles, mapping service name to its handle.
65    service_handles: Arc<RwLock<HashMap<String, ServiceHandle>>>,
66    /// Callback function to be executed when a new mDNS service is discovered.
67    /// The Name is the service instance name, and SocketAddr is its resolved address.
68    on_service_discovered_callback: Arc<RwLock<Option<Arc<dyn Fn(ServiceType) + Send + Sync + 'static>>>>,
69}
70
71impl VRChatOSC {
72    /// Creates a new `VRChatOSC` instance.
73    /// Initializes mDNS, sets up service discovery, and starts a listener task for mDNS service notifications.
74    pub async fn new() -> Result<Arc<VRChatOSC>, Error> {
75        let socket = UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)).await?;
76
77        // Create an mpsc channel for notifying about discovered mDNS services.
78        let (discover_notifier_tx, mut discover_notifier_rx) = mpsc::channel(8);
79        
80        // Initialize the mDNS client, passing the sender part of the notification channel.
81        let mdns_client = mdns::Mdns::new(discover_notifier_tx).await?;
82        
83        // Start following OSC services and OSCQuery JSON services on the local network.
84        let _ = mdns_client.follow(Name::from_ascii("_osc._udp.local.")?).await;
85        let _ = mdns_client.follow(Name::from_ascii("_oscjson._tcp.local.")?).await;
86        
87        // Prepare a shared storage for the service discovered callback.
88        let on_service_discovered_callback = Arc::new(RwLock::new(None::<Arc<dyn Fn(ServiceType) + Send + Sync + 'static>>));
89        let callback_arc_clone = on_service_discovered_callback.clone();
90
91        // Spawn a new asynchronous task to listen for service discovery notifications.
92        // This task will own the `discover_notifier_rx` (receiver end of the mpsc channel).
93        tokio::spawn(async move {
94            // Continuously try to receive messages from the discovery notification channel.
95            loop {
96                if let Some((service_name, socket_addr)) = discover_notifier_rx.recv().await {
97                    let callback_guard = callback_arc_clone.read().await;
98                    // If a callback is registered, invoke it with the service name and address.
99                    if let Some(callback) = callback_guard.as_ref() {
100                        if service_name.trim_to(3).to_ascii() == "_osc._udp.local." {
101                            callback(ServiceType::Osc(service_name.clone(), socket_addr));
102                        } else if service_name.trim_to(3).to_ascii() == "_oscjson._tcp.local." {
103                            callback(ServiceType::OscQuery(service_name.clone(), socket_addr));
104                        }
105                    }
106                }
107            }
108        });
109
110        Ok(Arc::new(VRChatOSC {
111            send_socket: socket,
112            mdns: mdns_client,
113            service_handles: Arc::new(RwLock::new(HashMap::new())),
114            on_service_discovered_callback,
115        }))
116    }
117
118    /// Registers a callback function to be invoked when an mDNS service is discovered.
119    ///
120    /// # Arguments
121    /// * `callback` - A function or closure that takes the service `Name` and `SocketAddr`
122    ///                as arguments. It must be `Send + Sync + 'static`.
123    pub async fn on_connect<F>(&self, callback: F)
124    where
125        F: Fn(ServiceType) + Send + Sync + 'static,
126    {
127        let mut callback_guard = self.on_service_discovered_callback.write().await;
128        *callback_guard = Some(Arc::new(callback));
129    }
130    
131    /// Registers a new OSC service with the local mDNS daemon and starts listening for OSC messages.
132    ///
133    /// # Arguments
134    /// * `service_name` - The name of the service to register (e.g., "MyAppOSC").
135    /// * `parameters` - The root node of the OSC address space for this service.
136    /// * `handler` - A function that will be called when an OSC packet is received for this service.
137    ///               It must be `Fn(OscPacket) + Send + 'static`.
138    pub async fn register<F>(&self, service_name: &str, parameters: OscRootNode, handler: F) -> Result<(), Error>
139    where
140        F: Fn(OscPacket) + Send + 'static,
141    {
142        // Start OSC server (UDP listener)
143        // Bind to localhost on an OS-assigned port.
144        let socket = UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?;
145        let osc_local_addr = socket.local_addr()?; // Get the actual address it bound to.
146        
147        // Spawn a task to handle incoming OSC packets.
148        let osc_handle = tokio::spawn(async move {
149            let mut buf = [0; rosc::decoder::MTU]; // Buffer for receiving OSC packets.
150            loop {
151                // Wait to receive data on the socket.
152                match socket.recv_from(&mut buf).await {
153                    Ok((len, addr)) => {
154                        // Decode the received UDP data into an OSC packet.
155                        if let Ok((_, packet)) = rosc::decoder::decode_udp(&buf[..len]) {
156                            handler(packet); // Call the provided handler with the decoded packet.
157                        } else {
158                            log::warn!("Failed to decode OSC packet from {}", addr);
159                        }
160                    }
161                    Err(e) => {
162                        if e.kind() == std::io::ErrorKind::ConnectionReset || e.kind() == std::io::ErrorKind::BrokenPipe {
163                            log::warn!("Socket connection error ({}). Task for {:?} might need to be restarted or interface is down.", e, socket.local_addr().ok());
164                            break;
165                        } else {
166                            log::error!("Failed to receive data on mDNS socket {:?}: {}", socket.local_addr().ok(), e);
167                        }
168                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
169                        continue;
170                    }
171                }
172            }
173        });
174        
175        // Start OSCQuery server (HTTP server)
176        let host_info = HostInfo::new(
177            service_name.to_string(),
178            osc_local_addr.ip(), // Use the IP of the OSC server.
179            osc_local_addr.port(), // Use the port of the OSC server.
180        );
181        let mut osc_query = OscQuery::new(host_info, parameters);
182        // Serve OSCQuery on localhost with an OS-assigned port.
183        let osc_query_local_addr = osc_query.serve(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)).await?;
184
185        // Create mDNS service announcements.
186        let service_name_upper_camel = service_name.to_case(Case::UpperCamel); // Convert service name case.
187        
188        // Register the OSC and OSCQuery services with mDNS.
189        self.mdns.register(
190            Name::from_ascii(format!("{}._osc._udp.local.", service_name_upper_camel))?,
191            osc_local_addr,
192        ).await?;
193        self.mdns.register(
194            Name::from_ascii(format!("{}._oscjson._tcp.local.", service_name_upper_camel))?,
195            osc_query_local_addr,
196        ).await?;
197
198        // Save service handles for later management (e.g., unregistering).
199        let mut handles = self.service_handles.write().await;
200        handles.insert(service_name.to_string(), ServiceHandle {
201            osc: osc_handle,
202            osc_query,
203        });
204        Ok(())
205    }
206
207    /// Unregisters an OSC service.
208    /// Stops the OSC and OSCQuery servers and removes mDNS announcements.
209    ///
210    /// # Arguments
211    /// * `service_name` - The name of the service to unregister.
212    pub async fn unregister(&self, service_name: &str) -> Result<(), Error> {
213        let service_name_upper_camel = service_name.to_case(Case::UpperCamel);
214        // Remove the service from our tracking.
215        let mut service_handles_map = self.service_handles.write().await;
216        if let Some(mut service_handle_entry) = service_handles_map.remove(service_name) {
217            // Unregister from mDNS.
218            self.mdns.unregister(Name::from_ascii(format!("{}._osc._udp.local.", service_name_upper_camel))?).await?;
219            self.mdns.unregister(Name::from_ascii(format!("{}._oscjson._tcp.local.", service_name_upper_camel))?).await?;
220            
221            // Stop the associated tasks/servers.
222            service_handle_entry.osc.abort(); // Abort the OSC listening task.
223            service_handle_entry.osc_query.shutdown(); // Gracefully shutdown the OSCQuery server.
224        }
225        Ok(())
226    }
227
228    /// Sends an OSC packet to services matching a given pattern.
229    ///
230    /// # Arguments
231    /// * `packet` - The `OscPacket` to send.
232    /// * `to` - A glob pattern (e.g., "VRChat-Client-*") to match against service names.
233    ///          This matches against the service instance name found via mDNS.
234    pub async fn send(&self, packet: OscPacket, to: &str) -> Result<(), Error> {
235        // Find services matching the pattern. The matching logic is within `find_service`.
236        // The closure provided to `find_service` determines if a service (by its Name) matches.
237        let services = self.mdns.find_service(|name, _| {
238            // `WildMatch` performs glob-style pattern matching.
239            WildMatch::new(&format!("{}._osc._udp.local.", to)).matches(&name.to_ascii())
240        }).await;
241
242        if services.is_empty() {
243            log::warn!("No mDNS services found matching the expression: {}", to);
244            return Ok(());
245        }
246
247        // Encode the OSC packet into bytes.
248        let msg_buf = rosc::encoder::encode(&packet)?;
249        // Send the packet to all found services.
250        for (_, addr) in services {
251            self.send_socket.send_to(&msg_buf, addr).await?;
252        }
253
254        Ok(())
255    }
256
257    /// Sends an OSC packet to a specific socket address.
258    ///
259    /// # Arguments
260    /// * `packet` - The `OscPacket` to send.
261    /// * `addr` - The `SocketAddr` to send the packet to.
262    pub async fn send_to_addr(&self, packet: OscPacket, addr: SocketAddr) -> Result<(), Error> {
263        let msg_buf = rosc::encoder::encode(&packet)?;
264        self.send_socket.send_to(&msg_buf, addr).await?;
265        Ok(())
266    }
267
268    /// Retrieves a specific OSC parameter (node) from services matching a pattern.
269    ///
270    /// # Arguments
271    /// * `method` - The OSC path of the parameter to fetch (e.g., "/avatar/parameters/SomeParam").
272    /// * `from` - A glob pattern (e.g., "VRChat-Client-*") to match against service names.
273    ///          This matches against the service instance name found via mDNS.
274    ///
275    /// # Returns
276    /// A `Vec` of tuples, where each tuple contains the service `Name` and the fetched `OscNode`.
277    /// Returns an empty Vec if no services match or if fetching fails for all matched services.
278    pub async fn get_parameter(&self, method: &str, from: &str) -> Result<Vec<(Name, OscNode)>, Error> {
279        // Find services matching the pattern. The matching logic is within `find_service`.
280        // The closure provided to `find_service` determines if a service (by its Name) matches.
281        let services = self.mdns.find_service(|name, _| {
282            WildMatch::new(&format!("{}._oscjson._tcp.local.", from)).matches(&name.to_ascii())
283        }).await;
284
285        if services.is_empty() {
286            log::warn!("No mDNS services found for get_parameter matching expression: {}", from);
287            return Ok(Vec::new());
288        }
289
290        // Asynchronously fetch the parameter from all matching services.
291        // `stream::iter` creates a stream from the services.
292        // `map` transforms each service into a future that fetches the parameter.
293        // `buffer_unordered(3)` allows up to 3 fetches to run concurrently.
294        // `filter_map` discards any fetches that resulted in an error.
295        // `collect` gathers all successful results into a Vec.
296        let params = stream::iter(services)
297            .map(|(name, addr)| async move {
298                fetch::<_, OscNode>(addr, method).await.map(|(param, _)| (name.clone(), param))
299            })
300            .buffer_unordered(3)
301            .filter_map(|res| async {
302                if let Err(e) = &res {
303                    log::warn!("Failed to fetch parameter: {:?}", e);
304                }
305                res.ok()
306            })
307            .collect::<Vec<_>>()
308            .await;
309
310        Ok(params)
311    }
312
313    /// Retrieves a specific OSC parameter (node) from a specific OSCQuery service address.
314    ///
315    /// # Arguments
316    /// * `method` - The OSC path of the parameter to fetch (e.g., "/avatar/parameters/SomeParam").
317    /// * `addr` - The `SocketAddr` of the OSCQuery service.
318    ///
319    /// # Returns
320    /// The fetched `OscNode`.
321    pub async fn get_parameter_from_addr(&self, method: &str, addr: SocketAddr) -> Result<OscNode, Error> {
322        let (param, _url) = fetch::<_, OscNode>(addr, method).await?;
323        Ok(param)
324    }
325
326    /// Shuts down all registered services and cleans up resources.
327    /// This method should be called before the VRChatOSC instance is dropped
328    /// to ensure graceful shutdown of asynchronous tasks and network services.
329    pub async fn shutdown(&self) -> Result<(), Error> {
330        let mut service_handles_map = self.service_handles.write().await;
331        let service_names: Vec<String> = service_handles_map.keys().cloned().collect();
332        
333        for name in service_names {
334            if let Some(mut handle) = service_handles_map.remove(&name) {
335                let service_name_upper_camel = name.to_case(Case::UpperCamel);
336                // Attempt to unregister from mDNS. Errors are logged but not propagated to allow other services to shut down.
337                if let Err(e) = self.mdns.unregister(Name::from_ascii(format!("{}._osc._udp.local.", service_name_upper_camel))?).await {
338                    log::error!("Failed to unregister OSC for {}: {}", name, e);
339                }
340                if let Err(e) = self.mdns.unregister(Name::from_ascii(format!("{}._oscjson._tcp.local.", service_name_upper_camel))?).await {
341                    log::error!("Failed to unregister OSCQuery for {}: {}", name, e);
342                }
343                
344                handle.osc.abort();
345                handle.osc_query.shutdown();
346            }
347        }
348        Ok(())
349    }
350}
351
352impl Drop for VRChatOSC {
353    fn drop(&mut self) {
354        // Best-effort synchronous cleanup.
355        // For robust cleanup, especially of async tasks and network resources,
356        // the asynchronous `shutdown` method should be called explicitly.
357        if let Ok(mut handles) = self.service_handles.try_write() {
358            let service_names: Vec<String> = handles.keys().cloned().collect();
359            for name in service_names {
360                if let Some(mut service_handle) = handles.remove(&name) {
361                    // mDNS unregistration cannot be reliably called here due to async and potential blocking.
362                    service_handle.osc.abort();
363                    // OscQuery::shutdown() is assumed to be synchronous or non-blocking here.
364                    // If it's async, it cannot be .await-ed in drop.
365                    service_handle.osc_query.shutdown();
366                }
367            }
368        } else {
369            // This might happen if the lock is poisoned or contended in a way not suitable for drop.
370            // In a real application, this should be logged or handled appropriately.
371            // Using log::error! or eprintln! here might be appropriate.
372            // For now, we acknowledge that proper async shutdown is preferred.
373            if !std::thread::panicking() { // Avoid double panic if already panicking
374                 log::warn!("VRChatOSC: Could not acquire lock on service_handles during drop. Explicitly call shutdown() for robust cleanup.");
375            }
376        }
377    }
378}